├── .gitignore ├── downstream_tasks ├── i2b2_preprocessing │ ├── i2b2_2006_deid │ │ ├── split_train_dev.py │ │ └── to_conll.py │ ├── i2b2_2012 │ │ └── Reformat.ipynb │ └── i2b2_2014_deid_hf_risk │ │ └── Reformat.ipynb ├── run_classifier.sh ├── ner_eval │ ├── score_i2b2.py │ ├── ner_detokenize.py │ ├── format_for_i2b2_eval.py │ └── conlleval.pl ├── run_i2b2.sh ├── run_classifier.py └── run_ner.py ├── lm_pretraining ├── create_pretrain_data.sh ├── finetune_lm_tf.sh ├── format_mimic_for_BERT.py ├── heuristic_tokenize.py ├── create_pretraining_data.py └── run_pretraining.py ├── LICENSE ├── README.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] -------------------------------------------------------------------------------- /downstream_tasks/i2b2_preprocessing/i2b2_2006_deid/split_train_dev.py: -------------------------------------------------------------------------------- 1 | with open('train_dev.conll', 'r') as f: 2 | sents = f.read().strip().split('\n\n') 3 | 4 | print(len(sents)) 5 | import random 6 | 7 | random.seed(555) 8 | 9 | n = len(sents) 10 | ind = int(0.7*n) 11 | 12 | train_sents = sents[:ind] 13 | dev_sents = sents[ind:] 14 | 15 | print(len(train_sents)) 16 | with open('train.conll', 'w') as f: 17 | f.write('\n\n'.join(train_sents)) 18 | 19 | print(len(dev_sents)) 20 | with open('dev.conll', 'w') as f: 21 | f.write('\n\n'.join(dev_sents)) 22 | -------------------------------------------------------------------------------- /lm_pretraining/create_pretrain_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BERT_BASE_DIR=/PATH/TO/BERT/VOCAB/FILE #modify this to bert or biobert folder containing a vocab.txt file 4 | DATA_DIR=/PATH/TO/TOKENIZED/NOTES #modify this to be the path to the tokenized data 5 | OUTPUT_DIR=/PATH/TO/OUTPUT/DIR # modify this to be your output directory path 6 | 7 | 8 | #modify this to be the note type that you want to create pretraining data for - e.g. ecg, echo, radiology, physician, nursing, etc. 9 | # Note that you can also specify multiple input files & output files below 10 | DATA_FILE=nursing_other 11 | 12 | 13 | # Note that create_pretraining_data.py is unmodified from the script in the original BERT repo. 14 | # Refer to the BERT repo for the most up to date version of this code. 
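# For example (illustrative file names), several note types can be packed into a single run by
# passing comma-separated lists to create_pretraining_data.py:
#   --input_file=$DATA_DIR/nursing_other.txt,$DATA_DIR/discharge_summary.txt \
#   --output_file=$OUTPUT_DIR/nursing_other.tfrecord,$OUTPUT_DIR/discharge_summary.tfrecord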
15 | python create_pretraining_data.py \ 16 | --input_file=$DATA_DIR/$DATA_FILE.txt \ 17 | --output_file=$OUTPUT_DIR/$DATA_FILE.tfrecord \ 18 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 19 | --do_lower_case=False \ 20 | --max_seq_length=128 \ 21 | --max_predictions_per_seq=22 \ 22 | --masked_lm_prob=0.15 \ 23 | --random_seed=12345 \ 24 | --dupe_factor=5 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Emily Alsentzer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /downstream_tasks/run_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Example script for running run_classifier.py 4 | 5 | 6 | for EPOCHS in 3 4 5 ; do 7 | for LR in 2e-5 3e-5 5e-5; do 8 | for BATCH_SZ in 16 32 ; do 9 | MAX_SEQ_LEN=150 10 | 11 | DATA_DIR=PATH/TO/MEDNLI/DATA/ #Modify this to be the path to the MedNLI data 12 | OUTPUT_DIR=PATH/TO/OUTPUT/DIR/ #Modify this to be the path to your output directory 13 | CLINICAL_BERT_LOC=PATH/TO/CLINICAL/BERT/MODEL #Modify this to be the path to the clinical BERT model 14 | 15 | echo $OUTPUT_DIR 16 | 17 | BERT_MODEL=clinical_bert # You can change this to biobert or bert-base-cased 18 | 19 | mkdir -p $OUTPUT_DIR 20 | 21 | python run_classifier.py \ 22 | --data_dir=$DATA_DIR \ 23 | --bert_model=$BERT_MODEL \ 24 | --model_loc $CLINICAL_BERT_LOC \ 25 | --task_name mednli \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_test \ 29 | --output_dir=$OUTPUT_DIR \ 30 | --num_train_epochs $EPOCHS \ 31 | --learning_rate $LR \ 32 | --train_batch_size $BATCH_SZ \ 33 | --max_seq_length $MAX_SEQ_LEN \ 34 | --gradient_accumulation_steps 2 35 | done 36 | done 37 | done 38 | 39 | -------------------------------------------------------------------------------- /lm_pretraining/finetune_lm_tf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # location of bert or biobert model 4 | # Note that if you use biobert as your base model, you'll need to change init_checkpoint to be biobert_model.ckpt 5 | BERT_BASE_DIR=/PATH/TO/BERT/MODEL 6 | 7 | # folder where you want to save your clinical BERT model 8 | OUTPUT_DIR=/PATH/TO/CLINICAL/BERT/OUTPUT/DIR 9 | 10 | # folder that contains the tfrecords - this will 
be the output directory from create_pretrain_data.sh 11 | INPUT_FILES_DIR=/PATH/TO/TFRECORDS 12 | 13 | NUM_TRAIN_STEPS=100000 14 | NUM_WARMUP_STEPS=10000 15 | LR=5e-5 16 | 17 | # This example illustrates the training of Bio+Discharge Summary BERT. If you change the input_file 18 | # to the tfrecords for all MIMIC sections - e.g. 19 | # --input_file=../data/tf_records/discharge_summary.tfrecord,../data/tf_records/physician.tfrecord,../data/tf_records/nursing.tfrecord,../data/tf_records/nursing_other.tfrecord,../data/tf_records/radiology.tfrecord,../data/tf_records/general.tfrecord,../data/tf_records/respiratory.tfrecord,../data/tf_records/consult.tfrecord,../data/tf_records/nutrition.tfrecord,../data/tf_records/case_management.tfrecord,../data/tf_records/pharmacy.tfrecord,../data/tf_records/rehab_services.tfrecord,../data/tf_records/social_work.tfrecord,../data/tf_records/ecg.tfrecord,../data/tf_records/echo.tfrecord \ 20 | 21 | python run_pretraining.py \ 22 | --output_dir=$OUTPUT_DIR \ 23 | --input_file=$INPUT_FILES_DIR/discharge_summary.tfrecord \ 24 | --do_train=True \ 25 | --do_eval=True \ 26 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 27 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 28 | --train_batch_size=32 \ 29 | --max_seq_length=128 \ 30 | --max_predictions_per_seq=20 \ 31 | --num_train_steps=$NUM_TRAIN_STEPS \ 32 | --num_warmup_steps=$NUM_WARMUP_STEPS \ 33 | --learning_rate=$LR \ 34 | --save_checkpoints_steps=50000 \ 35 | --keep_checkpoint_max=15 36 | 37 | -------------------------------------------------------------------------------- /downstream_tasks/i2b2_preprocessing/i2b2_2006_deid/to_conll.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import xml.etree.ElementTree as ET 4 | 5 | 6 | i = 0 7 | with open(sys.argv[1]) as f: 8 | for line in f.readlines(): 9 | # filter out non-lines 10 | if line.startswith( '') or \ 11 | line.startswith(''): 12 | continue 13 | 14 | # Parse PHI type/tokens 15 | regex = '((.*?))' 16 | phi_tags = re.findall(regex, line) 17 | for tag in phi_tags: 18 | line = line.replace(tag[0], '__phi__').strip() 19 | 20 | # Walk through sentence 21 | phi_ind = 0 22 | for w in line.split(): 23 | if w == '__phi__': 24 | phi = phi_tags[phi_ind] 25 | tag = phi[1] 26 | toks = phi[2].split() 27 | print(toks[0], 'B-%s'%tag) 28 | for t in toks[1:]: 29 | print(t, 'I-%s'%tag) 30 | phi_ind += 1 31 | # Two elif statements check for edge cases with Dates 32 | elif w.startswith('__phi__'): 33 | # examples like following format: 34 | # 01/01/1995 or 01-01-95 35 | phi = phi_tags[phi_ind] 36 | tag = phi[1] 37 | toks = phi[2].split() 38 | print(toks[0], 'B-%s'%tag) 39 | if w[7:8] == '/' or w[7:8] == '-': 40 | print(w[8:], 'O') # remove the / or - in the year 41 | else: 42 | print(w[7:], 'O') 43 | phi_ind += 1 44 | elif w.endswith('__phi__'): 45 | # 19950101 46 | phi = phi_tags[phi_ind] 47 | tag = phi[1] 48 | toks = phi[2].split() 49 | print(w[:-7], 'O') 50 | print(toks[0], 'B-%s'%tag) 51 | phi_ind += 1 52 | else: 53 | print(w, 'O') 54 | print() 55 | i+=1 56 | 57 | 58 | -------------------------------------------------------------------------------- /downstream_tasks/ner_eval/score_i2b2.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | import re 5 | 6 | 7 | parser = argparse.ArgumentParser(description='') 8 | parser.add_argument('--input_pred_dir', type=str, help='Location of input i2b2 prediction formatted files') 9 | 
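# Illustrative invocation (directory names are hypothetical; run_i2b2.sh later in this repo passes
# the per-CV-fold gold/pred directories it creates):
#   python score_i2b2.py --input_pred_dir out/0/test/pred/ --input_gold_dir out/0/test/gold/ --output_dir out/0/test/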
parser.add_argument('--input_gold_dir', type=str, help='Location of input i2b2 gold formatted files') 10 | parser.add_argument('--output_dir', type=str, help='Location of output files') 11 | 12 | args = parser.parse_args() 13 | 14 | def format_line(line): 15 | text, s_line, s_word, e_line, e_word, label = re.findall('c="(.*)" ([0-9]+):([0-9]+) ([0-9]+):([0-9]+)\|\|t="(.*)"', line)[0] 16 | return ({'text':text, 's_line':s_line, 's_word':s_word, 'e_line':e_line, 'e_word':e_word, 'label':label}) 17 | # Example format: c="cortical-type symptoms" 820:6 820:7||t="Problem" 18 | 19 | 20 | pred_lines = [] 21 | with open(os.path.join(args.input_pred_dir, "i2b2.con"),'r') as pred_f: 22 | for line in pred_f: 23 | pred_lines.append(format_line(line)) 24 | 25 | gold_lines = [] 26 | with open(os.path.join(args.input_gold_dir, "i2b2.con"),'r') as gold_f: 27 | for line in gold_f: 28 | gold_lines.append(format_line(line)) 29 | 30 | def in_gold(pred, gold_lines): 31 | for gold in gold_lines: 32 | if gold == pred: 33 | return True 34 | return False 35 | 36 | def in_pred(gold, pred_lines): 37 | for pred in pred_lines: 38 | if gold == pred: 39 | return True 40 | return False 41 | 42 | 43 | precCount = 0 44 | for pred in pred_lines: 45 | if in_gold(pred, gold_lines): 46 | precCount += 1 47 | 48 | 49 | recallCount = 0 50 | for gold in gold_lines: 51 | if in_pred(gold, pred_lines): 52 | recallCount += 1 53 | 54 | # totalEvents: total number of Events in the first file 55 | # recall total number of Events in the gold file that can be found in the pred file 56 | # precision total number of Events in the pred file that can be found in the gold file 57 | 58 | systemEventCount = len(pred_lines) 59 | goldEventCount = len(gold_lines) 60 | print('Predicted events: %d, Gold events: %d' %(systemEventCount, goldEventCount)) 61 | precision=float(precCount)/systemEventCount 62 | recall=float(recallCount)/goldEventCount 63 | fScore=2*(precision*recall)/(precision+recall) 64 | print('Exact Precision: %0.5f, Recall: %0.5f, F1: %0.5f' %(precision, recall, fScore)) 65 | 66 | with open(os.path.join(args.output_dir, "final_results.txt"),'w') as writer: 67 | writer.write('Precision\tRecall\tF1\n') 68 | writer.write('%0.5f\t%0.5f\t%0.5f\n' %(precision, recall, fScore)) 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # clinicalBERT 2 | Repository for [Publicly Available Clinical BERT Embeddings](https://www.aclweb.org/anthology/W19-1909/) (NAACL Clinical NLP Workshop 2019) 3 | 4 | ## Using Clinical BERT 5 | 6 | UPDATE: You can now use ClinicalBERT directly through the [transformers](https://github.com/huggingface/transformers) library. Check out the [Bio+Clinical BERT](https://huggingface.co/emilyalsentzer/Bio_ClinicalBERT) and [Bio+Discharge Summary BERT](https://huggingface.co/emilyalsentzer/Bio_Discharge_Summary_BERT) model pages for instructions on how to use the models within the Transformers library. 7 | 8 | ## Download Clinical BERT 9 | 10 | The Clinical BERT models can also be downloaded [here](https://www.dropbox.com/s/8armk04fu16algz/pretrained_bert_tf.tar.gz?dl=0), or via 11 | 12 | ``` 13 | wget -O pretrained_bert_tf.tar.gz https://www.dropbox.com/s/8armk04fu16algz/pretrained_bert_tf.tar.gz?dl=1 14 | ``` 15 | 16 | `biobert_pretrain_output_all_notes_150000` corresponds to Bio+Clinical BERT, and `biobert_pretrain_output_disch_100000` corresponds to Bio+Discharge Summary BERT. 
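These are the TensorFlow checkpoints; the same two BioBERT-initialized models can also be loaded straight from the Hugging Face Hub. A minimal sketch for Bio+Clinical BERT (model ID as on the page linked above; the example sentence is illustrative):

```
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Encode one clinical sentence and pull out its contextual embeddings
inputs = tokenizer("The patient was admitted with chest pain.", return_tensors="pt")
outputs = model(**inputs)  # outputs.last_hidden_state: one vector per wordpiece token
```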
Both models are finetuned from [BioBERT](https://arxiv.org/abs/1901.08746). We specifically use the [BioBERT-Base v1.0 (+ PubMed 200K + PMC 270K)](https://github.com/naver/biobert-pretrained) version of BioBERT. 17 | 18 | `bert_pretrain_output_all_notes_150000` corresponds to Clinical BERT, and `bert_pretrain_output_disch_100000` corresponds to Discharge Summary BERT. Both models are finetuned from the cased version of BERT, specifically cased_L-12_H-768_A-12. 19 | 20 | ## Reproduce Clinical BERT 21 | #### Pretraining 22 | To reproduce the steps necessary to finetune BERT or BioBERT on MIMIC data, follow the following steps: 23 | 1. Run `format_mimic_for_BERT.py` - Note you'll need to change the file paths at the top of the file. 24 | 2. Run `create_pretrain_data.sh` 25 | 3. Run `finetune_lm_tf.sh` 26 | 27 | Note: See issue [#4](https://github.com/EmilyAlsentzer/clinicalBERT/issues/4) for ways to improve section splitting code. 28 | 29 | #### Downstream Tasks 30 | To see an example of how to use clinical BERT for the Med NLI tasks, go to the `run_classifier.sh` script in the downstream_tasks folder. To see an example for NER tasks, go to the `run_i2b2.sh` script. 31 | 32 | ## Contact 33 | Please post a Github issue or contact emilya@mit.edu if you have any questions. 34 | 35 | ## Citation 36 | Please acknowledge the following work in papers or derivative software: 37 | 38 | Emily Alsentzer, John Murphy, William Boag, Wei-Hung Weng, Di Jin, Tristan Naumann, and Matthew McDermott. 2019. Publicly available clinical BERT embeddings. In Proceedings of the 2nd Clinical Natural Language Processing Workshop, pages 72-78, Minneapolis, Minnesota, USA. Association for Computational Linguistics. 39 | 40 | ``` 41 | @inproceedings{alsentzer-etal-2019-publicly, 42 | title = "Publicly Available Clinical {BERT} Embeddings", 43 | author = "Alsentzer, Emily and 44 | Murphy, John and 45 | Boag, William and 46 | Weng, Wei-Hung and 47 | Jin, Di and 48 | Naumann, Tristan and 49 | McDermott, Matthew", 50 | booktitle = "Proceedings of the 2nd Clinical Natural Language Processing Workshop", 51 | month = jun, 52 | year = "2019", 53 | address = "Minneapolis, Minnesota, USA", 54 | publisher = "Association for Computational Linguistics", 55 | url = "https://www.aclweb.org/anthology/W19-1909", 56 | doi = "10.18653/v1/W19-1909", 57 | pages = "72--78" 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /downstream_tasks/run_i2b2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #update these to the path of the bert model and the path to the NER data 4 | BERT_DIR=/PATH/TO/BERT/MODEL #clinical bert, biobert, or bert 5 | 6 | #path to NER data. Make sure to preprocess data according to BIO format first. 
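#In BIO/CoNLL format each line holds a "token label" pair (labels are O, B-<type>, or I-<type>) and
#sentences are separated by blank lines, e.g. (illustrative i2b2 2006 de-id tags):
#   John      B-DOCTOR
#   Smith     I-DOCTOR
#   examined  O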
7 | #You could use the scripts in the i2b2_preprocessing folder to preprocess the i2b2 data 8 | NER_DIR=/PATH/TO/NER/DATA/DIRECTORY 9 | 10 | 11 | for EPOCHS in 2 3 4 ; do 12 | for LEARN_RATE in 2e-5 3e-5 5e-5 ; do 13 | for BATCH_SZ in 16 32 ; do 14 | 15 | 16 | OUTPUT_DIR=/PATH/TO/OUTPUT/DIRECTORY #update this to the output directory you want 17 | mkdir -p $OUTPUT_DIR 18 | 19 | # You can change the task_name to 'i2b2_2014', 'i2b2_2010', 'i2b2_2006', or 'i2b2_2012' 20 | # Note that you may need to modify the DataProcessor code in `run_ner.py` to adapt to the format of your input 21 | # If you want to use biobert, change the init_checkpoint to biobert_model.ckpt 22 | # run_ner.py is adapted from kyzhouhzau's BERT-NER github and the BioBERT repo 23 | python run_ner.py \ 24 | --do_train=True \ 25 | --do_eval=True \ 26 | --do_predict=True \ 27 | --task_name='i2b2_2006' \ 28 | --vocab_file=$BERT_DIR/vocab.txt \ 29 | --bert_config_file=$BERT_DIR/bert_config.json \ 30 | --init_checkpoint=$BERT_DIR/bert_model.ckpt \ 31 | --num_train_epochs=$EPOCHS \ 32 | --learning_rate=$LEARN_RATE \ 33 | --train_batch_size=$BATCH_SZ \ 34 | --max_seq_length=150 \ 35 | --data_dir=$NER_DIR \ 36 | --output_dir=$OUTPUT_DIR \ 37 | --save_checkpoints_steps=2000 38 | 39 | 40 | # Note here we're performing 10 fold CV, but if you want to recover the original train, val, test split, use CV iter = 9 41 | # Also go to run_ner.py & modify line 738-739 so that you only run the last CV iteration. 42 | for CV_ITER in 0 1 2 3 4 5 6 7 8 9 ; do 43 | for MODE in eval test ; do 44 | EVAL_OUTPUT_DIR=$OUTPUT_DIR/$CV_ITER #creates a new folder for each CV iteration 45 | mkdir -p $EVAL_OUTPUT_DIR 46 | mkdir -p $EVAL_OUTPUT_DIR/$MODE 47 | mkdir -p $EVAL_OUTPUT_DIR/$MODE/gold/ 48 | mkdir -p $EVAL_OUTPUT_DIR/$MODE/pred/ 49 | 50 | OUTPUT_FILE=${EVAL_OUTPUT_DIR}/NER_result_conll_${MODE}.txt 51 | 52 | #convert word-piece BERT NER results to CoNLL eval format 53 | #Code is adapted from the BioBERT github 54 | python ner_eval/ner_detokenize.py \ 55 | --token_test_path=${OUTPUT_DIR}/${CV_ITER}_token_${MODE}.txt \ 56 | --label_test_path=${OUTPUT_DIR}/${CV_ITER}_label_${MODE}.txt \ 57 | --answer_path=${NER_DIR}/${CV_ITER}_${MODE} \ 58 | --tok_to_orig_map_path=${OUTPUT_DIR}/${CV_ITER}_tok_to_orig_map_${MODE}.txt \ 59 | --output_file=$OUTPUT_FILE 60 | 61 | #convert to i2b2 evaluation format (adapted from Cliner Repo) 62 | python ner_eval/format_for_i2b2_eval.py \ 63 | --results_file $OUTPUT_FILE \ 64 | --output_gold_dir $EVAL_OUTPUT_DIR/$MODE/gold/ \ 65 | --output_pred_dir $EVAL_OUTPUT_DIR/$MODE/pred/ 66 | 67 | # evaluate performance on i2b2 tasks 68 | python ner_eval/score_i2b2.py \ 69 | --input_gold_dir $EVAL_OUTPUT_DIR/$MODE/gold/ \ 70 | --input_pred_dir $EVAL_OUTPUT_DIR/$MODE/pred/ \ 71 | --output_dir $EVAL_OUTPUT_DIR/$MODE 72 | done 73 | done 74 | done 75 | done 76 | done 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /lm_pretraining/format_mimic_for_BERT.py: -------------------------------------------------------------------------------- 1 | import psycopg2 2 | import pandas as pd 3 | import sys 4 | import spacy 5 | import re 6 | import stanfordnlp 7 | import time 8 | import scispacy 9 | from tqdm import tqdm 10 | from heuristic_tokenize import sent_tokenize_rules 11 | 12 | 13 | # update these constants to run this script 14 | OUTPUT_DIR = '/PATH/TO/OUTPUT/DIR' #this path will contain tokenized notes. 
This dir will be the input dir for create_pretrain_data.sh 15 | MIMIC_NOTES_FILE = 'PATH/TO/MIMIC/DATA' #this is the path to mimic data if you're reading from a csv. Else uncomment the code to read from database below 16 | 17 | 18 | #setting sentence boundaries 19 | def sbd_component(doc): 20 | for i, token in enumerate(doc[:-2]): 21 | # define sentence start if period + titlecase token 22 | if token.text == '.' and doc[i+1].is_title: 23 | doc[i+1].sent_start = True 24 | if token.text == '-' and doc[i+1].text != '-': 25 | doc[i+1].sent_start = True 26 | return doc 27 | 28 | #convert de-identification text into one token 29 | def fix_deid_tokens(text, processed_text): 30 | deid_regex = r"\[\*\*.{0,15}.*?\*\*\]" 31 | if text: 32 | indexes = [m.span() for m in re.finditer(deid_regex,text,flags=re.IGNORECASE)] 33 | else: 34 | indexes = [] 35 | for start,end in indexes: 36 | processed_text.merge(start_idx=start,end_idx=end) 37 | return processed_text 38 | 39 | 40 | def process_section(section, note, processed_sections): 41 | # perform spacy processing on section 42 | processed_section = nlp(section['sections']) 43 | processed_section = fix_deid_tokens(section['sections'], processed_section) 44 | processed_sections.append(processed_section) 45 | 46 | def process_note_helper(note): 47 | # split note into sections 48 | note_sections = sent_tokenize_rules(note) 49 | processed_sections = [] 50 | section_frame = pd.DataFrame({'sections':note_sections}) 51 | section_frame.apply(process_section, args=(note,processed_sections,), axis=1) 52 | return(processed_sections) 53 | 54 | def process_text(sent, note): 55 | sent_text = sent['sents'].text 56 | if len(sent_text) > 0 and sent_text.strip() != '\n': 57 | if '\n' in sent_text: 58 | sent_text = sent_text.replace('\n', ' ') 59 | note['text'] += sent_text + '\n' 60 | 61 | def get_sentences(processed_section, note): 62 | # get sentences from spacy processing 63 | sent_frame = pd.DataFrame({'sents': list(processed_section['sections'].sents)}) 64 | sent_frame.apply(process_text, args=(note,), axis=1) 65 | 66 | def process_note(note): 67 | try: 68 | note_text = note['text'] #unicode(note['text']) 69 | note['text'] = '' 70 | processed_sections = process_note_helper(note_text) 71 | ps = {'sections': processed_sections} 72 | ps = pd.DataFrame(ps) 73 | ps.apply(get_sentences, args=(note,), axis=1) 74 | return note 75 | except Exception as e: 76 | pass 77 | #print ('error', e) 78 | 79 | 80 | 81 | if len(sys.argv) < 2: 82 | print('Please specify the note category.') 83 | sys.exit() 84 | 85 | category = sys.argv[1] 86 | 87 | 88 | start = time.time() 89 | tqdm.pandas() 90 | 91 | print('Begin reading notes') 92 | 93 | 94 | # Uncomment this to use postgres to query mimic instead of reading from a file 95 | # con = psycopg2.connect(dbname='mimic', host="/var/run/postgresql") 96 | # notes_query = "(select * from mimiciii.noteevents);" 97 | # notes = pd.read_sql_query(notes_query, con) 98 | notes = pd.read_csv(MIMIC_NOTES_FILE, index_col = 0) 99 | #print(set(notes['category'])) # all categories 100 | 101 | 102 | notes = notes[notes['category'] == category] 103 | print('Number of notes: %d' %len(notes.index)) 104 | notes['ind'] = list(range(len(notes.index))) 105 | 106 | # NOTE: `disable=['tagger', 'ner'] was added after paper submission to make this process go faster 107 | # our time estimate in the paper did not include the code to skip spacy's NER & tagger 108 | nlp = spacy.load('en_core_sci_md', disable=['tagger','ner']) 109 | nlp.add_pipe(sbd_component, before='parser') 
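# Illustrative note on what the pipeline set up above does: fix_deid_tokens merges MIMIC-style
# de-identification placeholders such as "[**Hospital1 18**]" or "[**2101-10-31**]" into single tokens,
# and sbd_component forces a sentence break after "." + Titlecase word and after bullet-style "-" tokens.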
110 | 111 | 112 | formatted_notes = notes.progress_apply(process_note, axis=1) 113 | with open(OUTPUT_DIR + category + '.txt','w') as f: 114 | for text in formatted_notes['text']: 115 | if text != None and len(text) != 0 : 116 | f.write(text) 117 | f.write('\n') 118 | 119 | end = time.time() 120 | print (end-start) 121 | print ("Done formatting notes") 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | absl-py=0.7.0=pypi_0 5 | asn1crypto=0.24.0=py36_0 6 | astor=0.7.1=pypi_0 7 | awscli=1.16.111=pypi_0 8 | backcall=0.1.0=py36_0 9 | blas=1.0=mkl 10 | bleach=3.1.0=py36_0 11 | boto3=1.9.86=pypi_0 12 | botocore=1.12.101=pypi_0 13 | ca-certificates=2018.11.29=ha4d7672_0 14 | certifi=2018.11.29=py36_1000 15 | cffi=1.11.5=py36he75722e_1 16 | chardet=3.0.4=py36_1 17 | colorama=0.3.9=pypi_0 18 | conllu=1.2.2=pypi_0 19 | cryptography=2.4.2=py36h1ba5d50_0 20 | cymem=2.0.2=py36hfd86e86_0 21 | cytoolz=0.9.0.1=py36h14c3975_1 22 | dbus=1.13.6=h746ee38_0 23 | decorator=4.3.0=py36_0 24 | dill=0.2.8.2=py36_0 25 | docutils=0.14=pypi_0 26 | en-core-sci-md=0.1.0=pypi_0 27 | en-core-sci-sm=0.1.0=pypi_0 28 | en-core-web-sm=2.0.0=pypi_0 29 | entrypoints=0.3=py36_0 30 | expat=2.2.6=he6710b0_0 31 | fontconfig=2.13.0=h9420a91_0 32 | freetype=2.9.1=h8a8886c_1 33 | gast=0.2.2=pypi_0 34 | glib=2.56.2=hd408876_0 35 | gmp=6.1.2=h6c8ec71_1 36 | grpcio=1.18.0=pypi_0 37 | gst-plugins-base=1.14.0=hbbd80ab_1 38 | gstreamer=1.14.0=hb453b48_1 39 | h5py=2.9.0=pypi_0 40 | icu=58.2=h9c2bf20_1 41 | idna=2.8=py36_0 42 | intel-openmp=2019.1=144 43 | ipykernel=5.1.0=py36h39e3cac_0 44 | ipython=7.2.0=py36h39e3cac_0 45 | ipython_genutils=0.2.0=py36_0 46 | ipywidgets=7.4.2=py36_0 47 | jedi=0.13.2=py36_0 48 | jinja2=2.10=py36_0 49 | jmespath=0.9.3=pypi_0 50 | jpeg=9b=h024ee3a_2 51 | jsonschema=2.6.0=py36_0 52 | jupyter=1.0.0=py36_7 53 | jupyter-contrib-core=0.3.3=pypi_0 54 | jupyter-highlight-selected-word=0.2.0=pypi_0 55 | jupyter-latex-envs=1.4.6=pypi_0 56 | jupyter-nbextensions-configurator=0.4.1=pypi_0 57 | jupyter_client=5.2.4=py36_0 58 | jupyter_console=6.0.0=py36_0 59 | jupyter_contrib_core=0.3.3=py_2 60 | jupyter_contrib_nbextensions=0.5.1=py36_0 61 | jupyter_core=4.4.0=py36_0 62 | jupyter_highlight_selected_word=0.2.0=py36_1000 63 | jupyter_latex_envs=1.4.4=py36_1000 64 | jupyter_nbextensions_configurator=0.4.1=py36_0 65 | keras-applications=1.0.7=pypi_0 66 | keras-preprocessing=1.0.9=pypi_0 67 | krb5=1.16.1=h173b8e3_7 68 | libedit=3.1.20181209=hc058e9b_0 69 | libffi=3.2.1=hd88cf55_4 70 | libgcc-ng=8.2.0=hdf63c60_1 71 | libgfortran-ng=7.3.0=hdf63c60_0 72 | libpng=1.6.36=hbc83047_0 73 | libpq=11.1=h20c2e04_0 74 | libsodium=1.0.16=h1bed415_0 75 | libstdcxx-ng=8.2.0=hdf63c60_1 76 | libuuid=1.0.3=h1bed415_2 77 | libxcb=1.13=h1bed415_1 78 | libxml2=2.9.9=he19cac6_0 79 | libxslt=1.1.32=h4785a14_1002 80 | lxml=4.3.0=pypi_0 81 | markdown=3.0.1=pypi_0 82 | markupsafe=1.1.0=py36h7b6447c_0 83 | mistune=0.8.4=py36h7b6447c_0 84 | mkl=2018.0.3=1 85 | mkl_fft=1.0.6=py36h7dd41cf_0 86 | mkl_random=1.0.1=py36h4414c95_1 87 | msgpack-numpy=0.4.3.2=py36_0 88 | msgpack-python=0.5.6=py36h6bb024c_1 89 | murmurhash=1.0.1=py36he6710b0_0 90 | nbconvert=5.3.1=py36_0 91 | nbformat=4.4.0=py36_0 92 | ncurses=6.1=he6710b0_1 
93 | nltk=3.4=py36_1 94 | notebook=5.7.4=py36_0 95 | numpy=1.15.4=py36h1d66e8a_0 96 | numpy-base=1.15.4=py36h81de0dd_0 97 | openssl=1.1.1a=h14c3975_1000 98 | pandas=0.24.0=py36he6710b0_0 99 | pandoc=2.2.3.2=0 100 | pandocfilters=1.4.2=py36_1 101 | parso=0.3.1=py36_0 102 | pcre=8.42=h439df22_0 103 | pexpect=4.6.0=py36_0 104 | pickleshare=0.7.5=py36_0 105 | pip=18.1=py36_0 106 | plac=0.9.6=py36_0 107 | preshed=2.0.1=py36he6710b0_0 108 | prometheus_client=0.5.0=py36_0 109 | prompt_toolkit=2.0.7=py36_0 110 | protobuf=3.6.1=pypi_0 111 | psutil=5.5.0=pypi_0 112 | psycopg2=2.7.6.1=py36h1ba5d50_0 113 | ptyprocess=0.6.0=py36_0 114 | pyasn1=0.4.5=pypi_0 115 | pycparser=2.19=py36_0 116 | pygments=2.3.1=py36_0 117 | pyopenssl=18.0.0=py36_0 118 | pyqt=5.9.2=py36h05f1152_2 119 | pysocks=1.6.8=py36_0 120 | python=3.6.8=h0371630_0 121 | python-dateutil=2.7.5=py36_0 122 | pytorch-pretrained-bert=0.4.0=pypi_0 123 | pytz=2018.9=py36_0 124 | pyyaml=3.13=pypi_0 125 | pyzmq=17.1.2=py36h14c3975_0 126 | qt=5.9.7=h5867ecd_1 127 | qtconsole=4.4.3=py36_0 128 | readline=7.0=h7b6447c_5 129 | regex=2018.1.10=pypi_0 130 | requests=2.21.0=py36_0 131 | rsa=3.4.2=pypi_0 132 | s3transfer=0.1.13=pypi_0 133 | scipy=1.2.1=pypi_0 134 | scispacy=0.1.0=pypi_0 135 | send2trash=1.5.0=py36_0 136 | setuptools=40.6.3=py36_0 137 | sip=4.19.8=py36hf484d3e_0 138 | six=1.12.0=py36_0 139 | spacy=2.0.18=pypi_0 140 | sqlite=3.26.0=h7b6447c_0 141 | stanfordcorenlp=3.9.1.1=pypi_0 142 | stanfordnlp=0.1.0=pypi_0 143 | statistics=1.0.3.5=pypi_0 144 | tensorboard=1.12.2=pypi_0 145 | tensorflow-gpu=1.12.0=pypi_0 146 | termcolor=1.1.0=pypi_0 147 | terminado=0.8.1=py36_1 148 | testpath=0.4.2=py36_0 149 | thinc=6.12.1=py36h4989274_0 150 | tk=8.6.8=hbc83047_0 151 | toolz=0.9.0=py36_0 152 | torch=1.0.0=pypi_0 153 | tornado=5.1.1=py36h7b6447c_0 154 | tqdm=4.29.1=py_0 155 | traitlets=4.3.2=py36_0 156 | ujson=1.35=py36h14c3975_0 157 | urllib3=1.24.1=py36_0 158 | wcwidth=0.1.7=py36_0 159 | webencodings=0.5.1=py36_1 160 | werkzeug=0.14.1=pypi_0 161 | wheel=0.32.3=py36_0 162 | widgetsnbextension=3.4.2=py36_0 163 | wrapt=1.10.11=py36h14c3975_2 164 | xz=5.2.4=h14c3975_4 165 | yaml=0.1.7=h14c3975_1001 166 | zeromq=4.2.5=hf484d3e_1 167 | zlib=1.2.11=h7b6447c_3 168 | -------------------------------------------------------------------------------- /downstream_tasks/ner_eval/ner_detokenize.py: -------------------------------------------------------------------------------- 1 | # Note that this code is adapted from the BioBERT github repo: 2 | # https://github.com/guidoajansen/biobert/tree/87e70a4dfb0dcc1e29ef9d6562f87c4854504e97/biobert/biocodes 3 | 4 | import argparse 5 | import itertools 6 | 7 | parser = argparse.ArgumentParser(description='') 8 | parser.add_argument('--token_test_path', type=str, help='') 9 | parser.add_argument('--label_test_path', type=str, help='') 10 | parser.add_argument('--answer_path', type=str, help='') 11 | parser.add_argument('--output_file', type=str, help='') 12 | parser.add_argument('--tok_to_orig_map_path', type=str, help='') 13 | 14 | args = parser.parse_args() 15 | 16 | def detokenize(golden_path, pred_token_test_path, pred_label_test_path, tok_to_orig_map_path, output_file): 17 | 18 | """convert word-piece BERT-NER results to original words (CoNLL eval format) 19 | 20 | Args: 21 | golden_path: path to golden dataset. ex) NCBI-disease/test.tsv 22 | pred_token_test_path: path to token_test.txt from output folder. ex) output/token_test.txt 23 | pred_label_test_path: path to label_test.txt from output folder. 
ex) output/label_test.txt 24 | output_file: file where result will write to. ex) output/conll_format.txt 25 | 26 | Outs: 27 | NER_result_conll.txt 28 | """ 29 | # read golden 30 | ans = dict({'toks':[], 'labels':[]}) 31 | with open(golden_path,'r') as in_: 32 | ans_toks, ans_labels = [],[] 33 | for line in in_: 34 | line = line.strip() 35 | if line == '': 36 | if len(ans_toks) == 0: # there must be extra empty lines 37 | continue 38 | #ans['toks'].append('[SEP]') 39 | ans['toks'].append(ans_toks) 40 | ans['labels'].append(ans_labels) 41 | ans_toks =[] 42 | ans_labels=[] 43 | continue 44 | tmp = line.split() 45 | ans_toks.append(tmp[0]) 46 | ans_labels.append(tmp[1]) 47 | if len(ans_toks) > 0: #don't forget the last sentence if there's no final empty line 48 | ans['toks'].append(ans_toks) 49 | ans['labels'].append(ans_labels) 50 | 51 | # read predicted 52 | pred = dict({'toks':[], 'labels':[], 'tok_to_orig':[]}) # dictionary for predicted tokens and labels. 53 | with open(pred_token_test_path,'r') as in_: #'token_test.txt' 54 | pred_toks = [] 55 | for line in in_: 56 | line = line.strip() 57 | if line =='': 58 | pred['toks'].append(pred_toks) 59 | pred_toks = [] 60 | continue 61 | pred_toks.append(line) 62 | if len(pred_toks) > 0: #don't forget the last sentence if there's no final empty line 63 | pred['toks'].append(pred_toks) 64 | 65 | with open(tok_to_orig_map_path,'r') as in_: #'tok_to_orig_map_test.txt' 66 | pred_tok_to_orig = [] 67 | for line in in_: 68 | line = line.strip() 69 | if line =='': 70 | pred['tok_to_orig'].append(pred_tok_to_orig) 71 | pred_tok_to_orig=[] 72 | continue 73 | pred_tok_to_orig.append(int(line)) 74 | if len(pred_tok_to_orig) > 0: #don't forget the last sentence if there's no final empty line 75 | pred['tok_to_orig'].append(pred_tok_to_orig) 76 | 77 | 78 | with open(pred_label_test_path,'r') as in_: 79 | pred_labels = [] 80 | for line in in_: 81 | line = line.strip() 82 | if line in ['[CLS]','[SEP]', 'X']: # replace non-text tokens with O. This will not be evaluated. 83 | pred_labels.append('O') 84 | continue 85 | if line == '': 86 | pred['labels'].append(pred_labels) 87 | pred_labels = [] 88 | continue 89 | pred_labels.append(line) 90 | if len(pred_labels) > 0: #don't forget the last sentence if there's no final empty line 91 | pred['labels'].append(pred_labels) 92 | 93 | 94 | 95 | print(len(pred['toks']), len(pred['labels']), len(ans['labels']), len(ans['toks'])) 96 | 97 | 98 | 99 | if (len(pred['toks']) != len(pred['labels'])): # Sanity check 100 | print("Error! : len(pred['toks']) != len(pred['labels']) : Please report us") 101 | print(len(pred['toks']), len(pred['labels'])) 102 | raise 103 | 104 | if (len(ans['labels']) != len(pred['labels'])): # Sanity check 105 | print(len(ans['labels']), len(pred['labels'])) 106 | print("Error! 
: len(ans['labels']) != len(bert_pred['labels']) : Please report us") 107 | raise 108 | 109 | bert_pred = dict({'toks':[], 'labels':[]}) 110 | num_too_short = 0 111 | for t, l, tok_to_orig_map, ans_toks in zip(pred['toks'],pred['labels'], pred['tok_to_orig'], ans['toks']): 112 | #remove first and last from each list, which are just buffers 113 | t.pop(0) 114 | t.pop() 115 | l.pop(0) 116 | l.pop() 117 | tok_to_orig_map.pop(0) 118 | tok_to_orig_map.pop() 119 | 120 | if (len(t)!= len(tok_to_orig_map)): 121 | num_too_short += 1 122 | print('Sentence of length %d was truncated' %len(tok_to_orig_map)) 123 | 124 | 125 | bert_pred_toks, bert_pred_labs = [],[] 126 | for ind_into_orig in range(int(tok_to_orig_map[len(t)-1]) + 1): #indexing into t here to deal with issue of truncated tokens 127 | tok_indices = [i for i, x in enumerate(tok_to_orig_map) if x == ind_into_orig] 128 | if len(t) in tok_indices: #skip that token and label because part of the word was truncated during eval 129 | continue 130 | wordpiece_toks = [t[ind][2:] if t[ind][:2] == '##' else t[ind] for ind in tok_indices] 131 | wordpiece_labs = [l[ind] for ind in tok_indices] 132 | bert_pred_toks.append(''.join(wordpiece_toks)) 133 | bert_pred_labs.append(wordpiece_labs[0]) 134 | if len(ans_toks) != len(bert_pred_toks): #if sentence was truncated assume remaining toks were predicted as 0 135 | n_missing_labs = len(ans_toks) - len(bert_pred_toks) 136 | bert_pred_labs.extend(['O'] * n_missing_labs) 137 | bert_pred['toks'].append(bert_pred_toks) 138 | bert_pred['labels'].append(bert_pred_labs) 139 | 140 | 141 | print('Number of sentences that were truncated: %d' %num_too_short) 142 | 143 | 144 | flattened_pred_toks = [item for sublist in bert_pred['toks'] for item in sublist] 145 | flattened_pred_labs = [item for sublist in bert_pred['labels'] for item in sublist] 146 | flattened_ans_labs = [item for sublist in ans['labels'] for item in sublist] 147 | flattened_ans_toks= [item for sublist in ans['toks'] for item in sublist] 148 | 149 | print(len(flattened_pred_toks), len(flattened_pred_labs), len(flattened_ans_labs), len(flattened_ans_toks)) 150 | 151 | 152 | if (len(bert_pred['toks']) != len(bert_pred['labels'])): # Sanity check 153 | print("Error! : len(bert_pred['toks']) != len(bert_pred['labels']) : Please report us") 154 | raise 155 | 156 | if (len(ans['labels']) != len(bert_pred['labels'])): # Sanity check 157 | print("Error! 
: len(ans['labels']) != len(bert_pred['labels']) : Please report us") 158 | raise 159 | 160 | with open(output_file, 'w') as out_: 161 | for ans_toks, ans_labs, pred_labs in zip(ans['toks'], ans['labels'], bert_pred['labels']): 162 | for ans_t, ans_l, pred_l in zip(ans_toks, ans_labs, pred_labs): 163 | out_.write("%s %s %s\n"%(ans_t, ans_l, pred_l)) 164 | out_.write('\n') 165 | 166 | 167 | detokenize(args.answer_path, args.token_test_path, args.label_test_path, args.tok_to_orig_map_path, args.output_file) 168 | -------------------------------------------------------------------------------- /downstream_tasks/ner_eval/format_for_i2b2_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | # Functions are all taken from Willie Boag's cliner code `documents.py` at 5 | # https://github.com/text-machine-lab/CliNER/blob/5e1599fb2a2209fa0183f623308516816a033d4f/code/notes/documents.py 6 | 7 | def convert_to_i2b2_format(tok_sents, pred_labels, mode=None): 8 | """ 9 | Purpose: Return the given concept label predictions in i2b2 format 10 | 11 | @param tokenized_sents. of tokenized sentences 12 | @param pred_labels. of predicted_labels 13 | @return of i2b2-concept-file-formatted data 14 | """ 15 | 16 | # Return value 17 | retStr = '' 18 | 19 | concept_tuples = tok_labels_to_concepts(tok_sents, pred_labels, mode) 20 | 21 | # For each classification 22 | for classification in concept_tuples: 23 | 24 | # Ensure 'none' classifications are skipped 25 | if classification[0] == 'none': 26 | raise('Classification label "none" should never happen') 27 | 28 | concept = classification[0] 29 | lineno = classification[1] 30 | start = classification[2] 31 | end = classification[3] 32 | 33 | # A list of words (corresponding line from the text file) 34 | text = tok_sents[lineno-1] 35 | 36 | #print "\n" + "-" * 80 37 | #print "classification: ", classification 38 | #print "lineno: ", lineno 39 | #print "start: ", start 40 | #print "end ", end 41 | #print "text: ", text 42 | #print 'len(text): ', len(text) 43 | #print "text[start]: ", text[start] 44 | #print "concept: ", concept 45 | 46 | datum = text[start] 47 | for j in range(start, end): 48 | datum += " " + text[j+1] 49 | datum = datum.lower() 50 | 51 | #print 'datum: ', datum 52 | 53 | # Line:TokenNumber of where the concept starts and ends 54 | idx1 = "%d:%d" % (lineno, start) 55 | idx2 = "%d:%d" % (lineno, end) 56 | 57 | # Classification 58 | label = concept.capitalize() 59 | 60 | 61 | # Print format 62 | retStr += "c=\"%s\" %s %s||t=\"%s\"\n" % (datum, idx1, idx2, label) 63 | 64 | # return formatted data 65 | return retStr.strip() 66 | 67 | 68 | def tok_labels_to_concepts(tokenized_sents, tok_labels, mode): 69 | 70 | #print tok_labels 71 | ''' 72 | for gold,sent in zip(tok_labels, tokenized_sents): 73 | print gold 74 | print sent 75 | print 76 | ''' 77 | 78 | # convert 'B-treatment' into ('B','treatment') and 'O' into ('O',None) 79 | def split_label(label): 80 | if label == 'O': 81 | iob,tag = 'O', None 82 | else: 83 | #print(label) 84 | if 'LOCATION-OTHER' in label: 85 | label = label.replace('LOCATION-OTHER', 'LOCATION_OTHER') 86 | if len(label.split('-')) != 2: 87 | print(label.split('-')) 88 | iob,tag = label.split('-') 89 | return iob, tag 90 | 91 | 92 | 93 | # preprocess predictions to "correct" starting Is into Bs 94 | corrected = [] 95 | for lineno,(toks, labels) in enumerate(zip(tokenized_sents, tok_labels)): 96 | corrected_line = [] 97 | for i in range(len(labels)): 98 | #''' 
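                # e.g. (illustrative) a predicted sequence  O I-problem I-problem  becomes
                #  O B-problem I-problem  here, since a chunk should not begin with an I- tag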
99 | # is this a candidate for error? 100 | 101 | iob,tag = split_label(labels[i]) 102 | if iob == 'I': 103 | # beginning of line has no previous 104 | if i == 0: 105 | print(mode, 'CORRECTING! A') 106 | if mode == 'gold': 107 | print(toks, labels) 108 | new_label = 'B' + labels[i][1:] 109 | else: 110 | # ensure either its outside OR mismatch type 111 | prev_iob,prev_tag = split_label(labels[i-1]) 112 | if prev_iob == 'O' or prev_tag != tag: 113 | print(mode, 'CORRECTING! B') 114 | new_label = 'B' + labels[i][1:] 115 | else: 116 | new_label = labels[i] 117 | else: 118 | new_label = labels[i] 119 | #''' 120 | corrected_line.append(new_label) 121 | corrected.append( corrected_line ) 122 | 123 | tok_labels = corrected 124 | 125 | concepts = [] 126 | for i,labs in enumerate(tok_labels): 127 | 128 | N = len(labs) 129 | begins = [ j for j,lab in enumerate(labs) if (lab[0] == 'B') ] 130 | 131 | for start in begins: 132 | # "B-test" --> "-test" 133 | label = labs[start][1:] 134 | 135 | # get ending token index 136 | end = start 137 | while (end < N-1) and tok_labels[i][end+1].startswith('I') and tok_labels[i][end+1][1:] == label: 138 | end += 1 139 | 140 | # concept tuple 141 | concept_tuple = (label[1:], i+1, start, end) 142 | concepts.append(concept_tuple) 143 | 144 | ''' 145 | # test it out 146 | for i in range(len(tokenized_sents)): 147 | assert len(tokenized_sents[i]) == len(tok_labels[i]) 148 | for tok,lab in zip(tokenized_sents[i],tok_labels[i]): 149 | if lab != 'O': print '\t', 150 | print lab, tok 151 | print 152 | exit() 153 | ''' 154 | 155 | # test it out 156 | test_tok_labels = tok_concepts_to_labels(tokenized_sents, concepts) 157 | #''' 158 | for lineno,(test,gold,sent) in enumerate(zip(test_tok_labels, tok_labels, tokenized_sents)): 159 | for i,(a,b) in enumerate(zip(test,gold)): 160 | #''' 161 | if not ((a == b)or(a[0]=='B' and b[0]=='I' and a[1:]==b[1:])): 162 | print() 163 | print('lineno: ', lineno) 164 | print() 165 | print('generated: ', test[i-3:i+4]) 166 | print('predicted: ', gold[i-3:i+4]) 167 | print(sent[i-3:i+4]) 168 | print('a[0]: ', a[0]) 169 | print('b[0]: ', b[0]) 170 | print('a[1:]: ', a[1:]) 171 | print('b[1:]: ', b[1:]) 172 | print('a[1:] == b[a:]: ', a[1:] == b[1:]) 173 | print() 174 | #''' 175 | assert (a == b) or (a[0]=='B' and b[0]=='I' and a[1:]==b[1:]) 176 | i += 1 177 | #''' 178 | assert test_tok_labels == tok_labels 179 | 180 | return concepts 181 | 182 | 183 | def tok_concepts_to_labels(tokenized_sents, tok_concepts): 184 | # parallel to tokens 185 | labels = [ ['O' for tok in sent] for sent in tokenized_sents ] 186 | 187 | # fill each concept's tokens appropriately 188 | for concept in tok_concepts: 189 | label,lineno,start_tok,end_tok = concept 190 | labels[lineno-1][start_tok] = 'B-%s' % label 191 | for i in range(start_tok+1,end_tok+1): 192 | labels[lineno-1][i] = 'I-%s' % label 193 | 194 | # test it out 195 | ''' 196 | for i in range(len(tokenized_sents)): 197 | assert len(tokenized_sents[i]) == len(labels[i]) 198 | for tok,lab in zip(tokenized_sents[i],labels[i]): 199 | if lab != 'O': print '\t', 200 | print lab, tok 201 | print 202 | exit() 203 | ''' 204 | 205 | return labels 206 | 207 | 208 | 209 | parser = argparse.ArgumentParser(description='') 210 | parser.add_argument('--results_file', type=str, help='Location of results file in conll format') 211 | parser.add_argument('--output_pred_dir', type=str, help='Location of where to output prediction formatted files') 212 | parser.add_argument('--output_gold_dir', type=str, help='Location of where to 
output gold formatted files') 213 | 214 | 215 | args = parser.parse_args() 216 | 217 | results_dict = {'tokens': [], 'gold_labels': [], 'predicted_labels': []} 218 | with open(args.results_file, 'r') as results: #+'/NER_result_conll.txt' 219 | toks, gold_labs, pred_labs = [],[],[] 220 | for line in results: 221 | line = line.strip() 222 | if line == '': 223 | results_dict['tokens'].append(toks) 224 | results_dict['gold_labels'].append(gold_labs) 225 | results_dict['predicted_labels'].append(pred_labs) 226 | toks, gold_labs, pred_labs = [],[],[] 227 | else: 228 | tok, gold_lab, pred_lab = line.split() 229 | toks.append(tok) 230 | gold_labs.append(gold_lab) 231 | pred_labs.append(pred_lab) 232 | if len(toks) > 0: 233 | results_dict['tokens'].append(toks) 234 | results_dict['gold_labels'].append(gold_labs) 235 | results_dict['predicted_labels'].append(pred_labs) 236 | 237 | i2b2_format_gold = convert_to_i2b2_format(results_dict['tokens'], results_dict['gold_labels'], 'gold') 238 | i2b2_format_predicted = convert_to_i2b2_format(results_dict['tokens'], results_dict['predicted_labels'], 'pred') 239 | 240 | with open(os.path.join(args.output_pred_dir, "i2b2.con"),'w') as writer: 241 | writer.write(i2b2_format_predicted) 242 | 243 | with open(os.path.join(args.output_gold_dir, "i2b2.con"),'w') as writer: 244 | writer.write(i2b2_format_gold) 245 | 246 | 247 | -------------------------------------------------------------------------------- /lm_pretraining/heuristic_tokenize.py: -------------------------------------------------------------------------------- 1 | # NOTE: this code is taken directly from Willie Boag's mimic-tokenize github repository 2 | # https://github.com/wboag/mimic-tokenize/blob/master/heuristic-tokenize.py commit e953d271bbb4c53aee5cc9a7b8be870a6b007604 3 | # The code was modified in two ways: 4 | # (1) to make the script compatible with Python 3 5 | # (2) to remove the main and discharge_tokenize methods which we don't directly use 6 | 7 | # There are two known issues with this code. We have not yet fixed them because we want to maintain the reproducibility of 8 | # our code for the work that was published. However, anyone wanting to extend this work should make the following changes: 9 | # (1) fix a bug on line #168 where . should be replaced with \. i.e. should be `while re.search('\n\s*%d\.'%n,segment):` 10 | # (2) add else statement (`else: new_segments.append(segments[i])`) to the if statement at line 287 11 | # `if (i == N-1) or is_title(segments[i+1]):` 12 | 13 | 14 | 15 | import sys 16 | import re, nltk 17 | import os 18 | 19 | def strip(s): 20 | return s.strip() 21 | 22 | def is_inline_title(text): 23 | m = re.search('^([a-zA-Z ]+:) ', text) 24 | if not m: 25 | return False 26 | return is_title(m.groups()[0]) 27 | 28 | stopwords = set(['of', 'on', 'or']) 29 | def is_title(text): 30 | if not text.endswith(':'): 31 | return False 32 | text = text[:-1] 33 | 34 | # be a little loose here... can tighten if it causes errors 35 | text = re.sub('(\([^\)]*?\))', '', text) 36 | 37 | # Are all non-stopwords capitalized? 
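    # e.g. (illustrative) "Past Medical History:" and "Medications on Admission:" pass this check
    # ('on' is a stopword), while "discharge condition:" fails because its words are not capitalized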
38 | for word in text.split(): 39 | if word in stopwords: continue 40 | if not word[0].isupper(): 41 | return False 42 | 43 | # I noticed this is a common issue (non-title aapears at beginning of line) 44 | if text == 'Disp': 45 | return False 46 | 47 | # optionally: could assert that it is less than 6 tokens 48 | return True 49 | 50 | 51 | 52 | def sent_tokenize_rules(text): 53 | 54 | # long sections are OBVIOUSLY different sentences 55 | text = re.sub('---+', '\n\n-----\n\n', text) 56 | text = re.sub('___+', '\n\n_____\n\n', text) 57 | text = re.sub('\n\n+', '\n\n', text) 58 | 59 | segments = text.split('\n\n') 60 | 61 | # strategy: break down segments and chip away structure until just prose. 62 | # once you have prose, use nltk.sent_tokenize() 63 | 64 | ### Separate section headers ### 65 | new_segments = [] 66 | 67 | # deal with this one edge case (multiple headers per line) up front 68 | m1 = re.match('(Admission Date:) (.*) (Discharge Date:) (.*)', segments[0]) 69 | if m1: 70 | new_segments += list(map(strip,m1.groups())) 71 | segments = segments[1:] 72 | 73 | m2 = re.match('(Date of Birth:) (.*) (Sex:) (.*)' , segments[0]) 74 | if m2: 75 | new_segments += list(map(strip,m2.groups())) 76 | segments = segments[1:] 77 | 78 | for segment in segments: 79 | # find all section headers 80 | possible_headers = re.findall('\n([A-Z][^\n:]+:)', '\n'+segment) 81 | #assert len(possible_headers) < 2, str(possible_headers) 82 | headers = [] 83 | for h in possible_headers: 84 | #print 'cand=[%s]' % h 85 | if is_title(h.strip()): 86 | #print '\tYES=[%s]' % h 87 | headers.append(h.strip()) 88 | 89 | # split text into new segments, delimiting on these headers 90 | for h in headers: 91 | h = h.strip() 92 | 93 | # split this segment into 3 smaller segments 94 | ind = segment.index(h) 95 | prefix = segment[:ind].strip() 96 | rest = segment[ ind+len(h):].strip() 97 | 98 | # add the prefix (potentially empty) 99 | if len(prefix) > 0: 100 | new_segments.append(prefix.strip()) 101 | 102 | # add the header 103 | new_segments.append(h) 104 | 105 | # remove the prefix from processing (very unlikely to be empty) 106 | segment = rest.strip() 107 | 108 | # add the final piece (aka what comes after all headers are processed) 109 | if len(segment) > 0: 110 | new_segments.append(segment.strip()) 111 | 112 | # copy over the new list of segments (further segmented than original segments) 113 | segments = list(new_segments) 114 | new_segments = [] 115 | 116 | 117 | ### Low-hanging fruit: "_____" is a delimiter 118 | for segment in segments: 119 | subsections = segment.split('\n_____\n') 120 | new_segments.append(subsections[0]) 121 | for ss in subsections[1:]: 122 | new_segments.append('_____') 123 | new_segments.append(ss) 124 | 125 | segments = list(new_segments) 126 | new_segments = [] 127 | 128 | 129 | ### Low-hanging fruit: "-----" is a delimiter 130 | for segment in segments: 131 | subsections = segment.split('\n-----\n') 132 | new_segments.append(subsections[0]) 133 | for ss in subsections[1:]: 134 | new_segments.append('-----') 135 | new_segments.append(ss) 136 | 137 | segments = list(new_segments) 138 | new_segments = [] 139 | 140 | ''' 141 | for segment in segments: 142 | print '------------START------------' 143 | print segment 144 | print '-------------END-------------' 145 | print 146 | exit() 147 | ''' 148 | 149 | ### Separate enumerated lists ### 150 | for segment in segments: 151 | if not re.search('\n\s*\d+\.', '\n'+segment): 152 | new_segments.append(segment) 153 | continue 154 | 155 | ''' 156 | print 
'------------START------------' 157 | print segment 158 | print '-------------END-------------' 159 | print 160 | ''' 161 | 162 | # generalizes in case the list STARTS this section 163 | segment = '\n'+segment 164 | 165 | # determine whether this segment contains a bulleted list (assumes i,i+1,...,n) 166 | start = int(re.search('\n\s*(\d+)\.', '\n'+segment).groups()[0]) 167 | n = start 168 | while re.search('\n\s*%d.'%n,segment): # SHOULD CHANGE TO: while re.search('\n\s*%d\.'%n,segment): #(CHANGED . to \.) 169 | n += 1 170 | n -= 1 171 | 172 | # no bulleted list 173 | if n < 1: 174 | new_segments.append(segment) 175 | continue 176 | 177 | ''' 178 | print '------------START------------' 179 | print segment 180 | print '-------------END-------------' 181 | 182 | print start,n 183 | print 184 | ''' 185 | 186 | # break each list into its own line 187 | # challenge: not clear how to tell when the list ends if more text happens next 188 | for i in range(start,n+1): 189 | matching_text = re.search('(\n\s*\d+\.)',segment).groups()[0] 190 | prefix = segment[:segment.index(matching_text) ].strip() 191 | segment = segment[ segment.index(matching_text):].strip() 192 | if len(prefix)>0: 193 | new_segments.append(prefix) 194 | 195 | if len(segment)>0: 196 | new_segments.append(segment) 197 | 198 | segments = list(new_segments) 199 | new_segments = [] 200 | 201 | ''' 202 | TODO: Big Challenge 203 | 204 | There is so much variation in what makes a list. Intuitively, I can tell it's a 205 | list because it shows repeated structure (often following a header) 206 | 207 | Examples of some lists (with numbers & symptoms changed around to noise) 208 | 209 | Past Medical History: 210 | -- Hyperlipidemia 211 | -- lactose intolerance 212 | -- Hypertension 213 | 214 | 215 | Physical Exam: 216 | Vitals - T 82.2 BP 123/23 HR 73 R 21 75% on 2L NC 217 | General - well appearing male, sitting up in chair in NAD 218 | Neck - supple, JVP elevated to angle of jaw 219 | CV - distant heart sounds, RRR, faint __PHI_43__ murmur at 220 | 221 | 222 | Labs: 223 | __PHI_10__ 12:00PM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 224 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 225 | __PHI_14__ 04:54AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 226 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 227 | __PHI_23__ 03:33AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 228 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 229 | __PHI_109__ 03:06AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 230 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 231 | __PHI_1__ 05:09AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 232 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 233 | __PHI_26__ 04:53AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 234 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 235 | __PHI_301__ 05:30AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8* 236 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888 237 | 238 | 239 | Medications on Admission: 240 | Allopurinol 100 mg DAILY 241 | Aspirin 250 mg DAILY 242 | Atorvastatin 10 mg DAILY 243 | Glimepiride 1 mg once a week. 
244 | Hexavitamin DAILY 245 | Lasix 50mg M-W-F; 60mg T-Th-Sat-Sun 246 | Metoprolol 12.5mg TID 247 | Prilosec OTC 20 mg once a day 248 | Verapamil 120 mg SR DAILY 249 | ''' 250 | 251 | ### Remove lines with inline titles from larger segments (clearly nonprose) 252 | for segment in segments: 253 | ''' 254 | With: __PHI_6__, MD __PHI_5__ 255 | Building: De __PHI_45__ Building (__PHI_32__ Complex) __PHI_87__ 256 | Campus: WEST 257 | ''' 258 | 259 | lines = segment.split('\n') 260 | 261 | buf = [] 262 | for i in range(len(lines)): 263 | if is_inline_title(lines[i]): 264 | if len(buf) > 0: 265 | new_segments.append('\n'.join(buf)) 266 | buf = [] 267 | buf.append(lines[i]) 268 | if len(buf) > 0: 269 | new_segments.append('\n'.join(buf)) 270 | 271 | segments = list(new_segments) 272 | new_segments = [] 273 | 274 | 275 | # Going to put one-liner answers with their sections 276 | # (aka A A' B B' C D D' --> AA' BB' C DD' ) 277 | N = len(segments) 278 | for i in range(len(segments)): 279 | # avoid segfaults 280 | if i==0: 281 | new_segments.append(segments[i]) 282 | continue 283 | 284 | if segments[i].count('\n') == 0 and \ 285 | is_title(segments[i-1]) and \ 286 | not is_title(segments[i ]): 287 | if (i == N-1) or is_title(segments[i+1]): 288 | new_segments = new_segments[:-1] 289 | new_segments.append(segments[i-1] + ' ' + segments[i]) 290 | #else: new_segments.append(segments[i]) #ADD TO FIX BUG 291 | # currently If the code sees a segment that doesn't have any new lines and the prior line is a title 292 | # *but* it is not the last segment and the next segment is not a title then that segment is just dropped 293 | # so lists that have a title header will lose their first entry 294 | else: 295 | new_segments.append(segments[i]) 296 | 297 | segments = list(new_segments) 298 | new_segments = [] 299 | 300 | ''' 301 | Should do some kind of regex to find "TEST: value" in segments? 302 | 303 | Indication: Source of embolism. 304 | BP (mm Hg): 145/89 305 | HR (bpm): 80 306 | 307 | Note: I made a temporary hack that fixes this particular problem. 308 | We'll see how it shakes out 309 | ''' 310 | 311 | 312 | ''' 313 | Separate ALL CAPS lines (Warning... is there ever prose that can be all caps?) 314 | ''' 315 | 316 | 317 | 318 | 319 | return segments 320 | 321 | 322 | 323 | 324 | -------------------------------------------------------------------------------- /downstream_tasks/ner_eval/conlleval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://www.clips.uantwerpen.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | # note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-). 
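# example input line (illustrative); the NER_result_conll_*.txt files written by ner_detokenize.py
# above follow this "token correct-tag guessed-tag" layout:
#   aspirin B-treatment B-treatment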
17 | # url: http://www.clips.uantwerpen.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.) 32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 86 | elsif ($nbrOfFeatures != $#features and @features != 0) { 87 | printf STDERR "unexpected number of features: %d (%d)\n", 88 | $#features+1,$nbrOfFeatures+1; 89 | exit(1); 90 | } 91 | if (@features == 0 or 92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 93 | if (@features < 2) { 94 | die "conlleval: unexpected number of features in line $line\n"; 95 | } 96 | if ($raw) { 97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 99 | if ($features[$#features] ne "O") { 100 | $features[$#features] = "B-$features[$#features]"; 101 | } 102 | if 
($features[$#features-1] ne "O") { 103 | $features[$#features-1] = "B-$features[$#features-1]"; 104 | } 105 | } 106 | # 20040126 ET code which allows hyphens in the types 107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 108 | $guessed = $1; 109 | $guessedType = $2; 110 | } else { 111 | $guessed = $features[$#features]; 112 | $guessedType = ""; 113 | } 114 | pop(@features); 115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 116 | $correct = $1; 117 | $correctType = $2; 118 | } else { 119 | $correct = $features[$#features]; 120 | $correctType = ""; 121 | } 122 | pop(@features); 123 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 124 | # ($correct,$correctType) = split(/-/,pop(@features)); 125 | $guessedType = $guessedType ? $guessedType : ""; 126 | $correctType = $correctType ? $correctType : ""; 127 | $firstItem = shift(@features); 128 | 129 | # 1999-06-26 sentence breaks should always be counted as out of chunk 130 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 131 | 132 | if ($inCorrect) { 133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 135 | $lastGuessedType eq $lastCorrectType) { 136 | $inCorrect=$false; 137 | $correctChunk++; 138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 139 | $correctChunk{$lastCorrectType}+1 : 1; 140 | } elsif ( 141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 143 | $guessedType ne $correctType ) { 144 | $inCorrect=$false; 145 | } 146 | } 147 | 148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 150 | $guessedType eq $correctType) { $inCorrect = $true; } 151 | 152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 153 | $foundCorrect++; 154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 155 | $foundCorrect{$correctType}+1 : 1; 156 | } 157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 158 | $foundGuessed++; 159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 160 | $foundGuessed{$guessedType}+1 : 1; 161 | } 162 | if ( $firstItem ne $boundary ) { 163 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 164 | $correctTags++; 165 | } 166 | $tokenCounter++; 167 | } 168 | 169 | $lastGuessed = $guessed; 170 | $lastCorrect = $correct; 171 | $lastGuessedType = $guessedType; 172 | $lastCorrectType = $correctType; 173 | } 174 | if ($inCorrect) { 175 | $correctChunk++; 176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
177 | $correctChunk{$lastCorrectType}+1 : 1; 178 | } 179 | 180 | if (not $latex) { 181 | # compute overall precision, recall and FB1 (default values are 0.0) 182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 184 | $FB1 = 2*$precision*$recall/($precision+$recall) 185 | if ($precision+$recall > 0); 186 | 187 | # print overall performance 188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 190 | if ($tokenCounter>0) { 191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 192 | printf "precision: %6.2f%%; ",$precision; 193 | printf "recall: %6.2f%%; ",$recall; 194 | printf "FB1: %6.2f\n",$FB1; 195 | } 196 | } 197 | 198 | # sort chunk type names 199 | undef($lastType); 200 | @sortedTypes = (); 201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 202 | if (not($lastType) or $lastType ne $i) { 203 | push(@sortedTypes,($i)); 204 | } 205 | $lastType = $i; 206 | } 207 | # print performance per chunk type 208 | if (not $latex) { 209 | for $i (@sortedTypes) { 210 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 213 | if (not($foundCorrect{$i})) { $recall = 0.0; } 214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 217 | printf "%17s: ",$i; 218 | printf "precision: %6.2f%%; ",$precision; 219 | printf "recall: %6.2f%%; ",$recall; 220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 221 | } 222 | } else { 223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 224 | for $i (@sortedTypes) { 225 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 226 | if (not($foundGuessed{$i})) { $precision = 0.0; } 227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 228 | if (not($foundCorrect{$i})) { $recall = 0.0; } 229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 233 | $i,$precision,$recall,$FB1; 234 | } 235 | print "\\hline\n"; 236 | $precision = 0.0; 237 | $recall = 0; 238 | $FB1 = 0.0; 239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 241 | $FB1 = 2*$precision*$recall/($precision+$recall) 242 | if ($precision+$recall > 0); 243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 244 | $precision,$recall,$FB1; 245 | } 246 | 247 | exit 0; 248 | 249 | # endOfChunk: checks if a chunk ended between the previous and current word 250 | # arguments: previous and current chunk tags, previous and current types 251 | # note: this code is capable of handling other chunk representations 252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 254 | 255 | sub endOfChunk { 256 | my $prevTag = shift(@_); 257 | my $tag = shift(@_); 258 | my $prevType = shift(@_); 259 | my $type = shift(@_); 260 | my $chunkEnd = $false; 261 | 262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 266 | 267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 271 | 272 | if ($prevTag ne "O" and $prevTag ne "." 
and $prevType ne $type) { 273 | $chunkEnd = $true; 274 | } 275 | 276 | # corrected 1998-12-22: these chunks are assumed to have length 1 277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 279 | 280 | return($chunkEnd); 281 | } 282 | 283 | # startOfChunk: checks if a chunk started between the previous and current word 284 | # arguments: previous and current chunk tags, previous and current types 285 | # note: this code is capable of handling other chunk representations 286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 288 | 289 | sub startOfChunk { 290 | my $prevTag = shift(@_); 291 | my $tag = shift(@_); 292 | my $prevType = shift(@_); 293 | my $type = shift(@_); 294 | my $chunkStart = $false; 295 | 296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 300 | 301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 305 | 306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) { 307 | $chunkStart = $true; 308 | } 309 | 310 | # corrected 1998-12-22: these chunks are assumed to have length 1 311 | if ( $tag eq "[" ) { $chunkStart = $true; } 312 | if ( $tag eq "]" ) { $chunkStart = $true; } 313 | 314 | return($chunkStart); 315 | } 316 | -------------------------------------------------------------------------------- /lm_pretraining/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import tokenization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("input_file", None, 31 | "Input raw text file (or comma-separated list of files).") 32 | 33 | flags.DEFINE_string( 34 | "output_file", None, 35 | "Output TF example file (or comma-separated list of files).") 36 | 37 | flags.DEFINE_string("vocab_file", None, 38 | "The vocabulary file that the BERT model was trained on.") 39 | 40 | flags.DEFINE_bool( 41 | "do_lower_case", True, 42 | "Whether to lower case the input text. 
Should be True for uncased " 43 | "models and False for cased models.") 44 | 45 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 46 | 47 | flags.DEFINE_integer("max_predictions_per_seq", 20, 48 | "Maximum number of masked LM predictions per sequence.") 49 | 50 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 51 | 52 | flags.DEFINE_integer( 53 | "dupe_factor", 10, 54 | "Number of times to duplicate the input data (with different masks).") 55 | 56 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 57 | 58 | flags.DEFINE_float( 59 | "short_seq_prob", 0.1, 60 | "Probability of creating sequences which are shorter than the " 61 | "maximum length.") 62 | 63 | 64 | class TrainingInstance(object): 65 | """A single training instance (sentence pair).""" 66 | 67 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 68 | is_random_next): 69 | self.tokens = tokens 70 | self.segment_ids = segment_ids 71 | self.is_random_next = is_random_next 72 | self.masked_lm_positions = masked_lm_positions 73 | self.masked_lm_labels = masked_lm_labels 74 | 75 | def __str__(self): 76 | s = "" 77 | s += "tokens: %s\n" % (" ".join( 78 | [tokenization.printable_text(x) for x in self.tokens])) 79 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 80 | s += "is_random_next: %s\n" % self.is_random_next 81 | s += "masked_lm_positions: %s\n" % (" ".join( 82 | [str(x) for x in self.masked_lm_positions])) 83 | s += "masked_lm_labels: %s\n" % (" ".join( 84 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 85 | s += "\n" 86 | return s 87 | 88 | def __repr__(self): 89 | return self.__str__() 90 | 91 | 92 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 93 | max_predictions_per_seq, output_files): 94 | """Create TF example files from `TrainingInstance`s.""" 95 | writers = [] 96 | for output_file in output_files: 97 | writers.append(tf.python_io.TFRecordWriter(output_file)) 98 | 99 | writer_index = 0 100 | 101 | total_written = 0 102 | for (inst_index, instance) in enumerate(instances): 103 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 104 | input_mask = [1] * len(input_ids) 105 | segment_ids = list(instance.segment_ids) 106 | assert len(input_ids) <= max_seq_length 107 | 108 | while len(input_ids) < max_seq_length: 109 | input_ids.append(0) 110 | input_mask.append(0) 111 | segment_ids.append(0) 112 | 113 | assert len(input_ids) == max_seq_length 114 | assert len(input_mask) == max_seq_length 115 | assert len(segment_ids) == max_seq_length 116 | 117 | masked_lm_positions = list(instance.masked_lm_positions) 118 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 119 | masked_lm_weights = [1.0] * len(masked_lm_ids) 120 | 121 | while len(masked_lm_positions) < max_predictions_per_seq: 122 | masked_lm_positions.append(0) 123 | masked_lm_ids.append(0) 124 | masked_lm_weights.append(0.0) 125 | 126 | next_sentence_label = 1 if instance.is_random_next else 0 127 | 128 | features = collections.OrderedDict() 129 | features["input_ids"] = create_int_feature(input_ids) 130 | features["input_mask"] = create_int_feature(input_mask) 131 | features["segment_ids"] = create_int_feature(segment_ids) 132 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 133 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 134 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 135 | 
features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 136 | 137 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 138 | 139 | writers[writer_index].write(tf_example.SerializeToString()) 140 | writer_index = (writer_index + 1) % len(writers) 141 | 142 | total_written += 1 143 | 144 | if inst_index < 20: 145 | tf.logging.info("*** Example ***") 146 | tf.logging.info("tokens: %s" % " ".join( 147 | [tokenization.printable_text(x) for x in instance.tokens])) 148 | 149 | for feature_name in features.keys(): 150 | feature = features[feature_name] 151 | values = [] 152 | if feature.int64_list.value: 153 | values = feature.int64_list.value 154 | elif feature.float_list.value: 155 | values = feature.float_list.value 156 | tf.logging.info( 157 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 158 | 159 | for writer in writers: 160 | writer.close() 161 | 162 | tf.logging.info("Wrote %d total instances", total_written) 163 | 164 | 165 | def create_int_feature(values): 166 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 167 | return feature 168 | 169 | 170 | def create_float_feature(values): 171 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 172 | return feature 173 | 174 | 175 | def create_training_instances(input_files, tokenizer, max_seq_length, 176 | dupe_factor, short_seq_prob, masked_lm_prob, 177 | max_predictions_per_seq, rng): 178 | """Create `TrainingInstance`s from raw text.""" 179 | all_documents = [[]] 180 | 181 | # Input file format: 182 | # (1) One sentence per line. These should ideally be actual sentences, not 183 | # entire paragraphs or arbitrary spans of text. (Because we use the 184 | # sentence boundaries for the "next sentence prediction" task). 185 | # (2) Blank lines between documents. Document boundaries are needed so 186 | # that the "next sentence prediction" task doesn't span between documents. 187 | for input_file in input_files: 188 | with tf.gfile.GFile(input_file, "r") as reader: 189 | while True: 190 | line = tokenization.convert_to_unicode(reader.readline()) 191 | if not line: 192 | break 193 | line = line.strip() 194 | 195 | # Empty lines are used as document delimiters 196 | if not line: 197 | all_documents.append([]) 198 | tokens = tokenizer.tokenize(line) 199 | if tokens: 200 | all_documents[-1].append(tokens) 201 | 202 | # Remove empty documents 203 | all_documents = [x for x in all_documents if x] 204 | rng.shuffle(all_documents) 205 | 206 | vocab_words = list(tokenizer.vocab.keys()) 207 | instances = [] 208 | for _ in range(dupe_factor): 209 | for document_index in range(len(all_documents)): 210 | instances.extend( 211 | create_instances_from_document( 212 | all_documents, document_index, max_seq_length, short_seq_prob, 213 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 214 | 215 | rng.shuffle(instances) 216 | return instances 217 | 218 | 219 | def create_instances_from_document( 220 | all_documents, document_index, max_seq_length, short_seq_prob, 221 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 222 | """Creates `TrainingInstance`s for a single document.""" 223 | document = all_documents[document_index] 224 | 225 | # Account for [CLS], [SEP], [SEP] 226 | max_num_tokens = max_seq_length - 3 227 | 228 | # We *usually* want to fill up the entire sequence since we are padding 229 | # to `max_seq_length` anyways, so short sequences are generally wasted 230 | # computation. 
However, we *sometimes* 231 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 232 | # sequences to minimize the mismatch between pre-training and fine-tuning. 233 | # The `target_seq_length` is just a rough target however, whereas 234 | # `max_seq_length` is a hard limit. 235 | target_seq_length = max_num_tokens 236 | if rng.random() < short_seq_prob: 237 | target_seq_length = rng.randint(2, max_num_tokens) 238 | 239 | # We DON'T just concatenate all of the tokens from a document into a long 240 | # sequence and choose an arbitrary split point because this would make the 241 | # next sentence prediction task too easy. Instead, we split the input into 242 | # segments "A" and "B" based on the actual "sentences" provided by the user 243 | # input. 244 | instances = [] 245 | current_chunk = [] 246 | current_length = 0 247 | i = 0 248 | while i < len(document): 249 | segment = document[i] 250 | current_chunk.append(segment) 251 | current_length += len(segment) 252 | if i == len(document) - 1 or current_length >= target_seq_length: 253 | if current_chunk: 254 | # `a_end` is how many segments from `current_chunk` go into the `A` 255 | # (first) sentence. 256 | a_end = 1 257 | if len(current_chunk) >= 2: 258 | a_end = rng.randint(1, len(current_chunk) - 1) 259 | 260 | tokens_a = [] 261 | for j in range(a_end): 262 | tokens_a.extend(current_chunk[j]) 263 | 264 | tokens_b = [] 265 | # Random next 266 | is_random_next = False 267 | if len(current_chunk) == 1 or rng.random() < 0.5: 268 | is_random_next = True 269 | target_b_length = target_seq_length - len(tokens_a) 270 | 271 | # This should rarely go for more than one iteration for large 272 | # corpora. However, just to be careful, we try to make sure that 273 | # the random document is not the same as the document 274 | # we're processing. 275 | for _ in range(10): 276 | random_document_index = rng.randint(0, len(all_documents) - 1) 277 | if random_document_index != document_index: 278 | break 279 | 280 | random_document = all_documents[random_document_index] 281 | random_start = rng.randint(0, len(random_document) - 1) 282 | for j in range(random_start, len(random_document)): 283 | tokens_b.extend(random_document[j]) 284 | if len(tokens_b) >= target_b_length: 285 | break 286 | # We didn't actually use these segments so we "put them back" so 287 | # they don't go to waste. 
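          # (Only the first `a_end` sentences of `current_chunk` were consumed as
          # segment A in this branch, so rewinding `i` by the unused count lets the
          # outer while loop revisit the remaining sentences on its next pass.)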
288 | num_unused_segments = len(current_chunk) - a_end 289 | i -= num_unused_segments 290 | # Actual next 291 | else: 292 | is_random_next = False 293 | for j in range(a_end, len(current_chunk)): 294 | tokens_b.extend(current_chunk[j]) 295 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 296 | 297 | assert len(tokens_a) >= 1 298 | assert len(tokens_b) >= 1 299 | 300 | tokens = [] 301 | segment_ids = [] 302 | tokens.append("[CLS]") 303 | segment_ids.append(0) 304 | for token in tokens_a: 305 | tokens.append(token) 306 | segment_ids.append(0) 307 | 308 | tokens.append("[SEP]") 309 | segment_ids.append(0) 310 | 311 | for token in tokens_b: 312 | tokens.append(token) 313 | segment_ids.append(1) 314 | tokens.append("[SEP]") 315 | segment_ids.append(1) 316 | 317 | (tokens, masked_lm_positions, 318 | masked_lm_labels) = create_masked_lm_predictions( 319 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 320 | instance = TrainingInstance( 321 | tokens=tokens, 322 | segment_ids=segment_ids, 323 | is_random_next=is_random_next, 324 | masked_lm_positions=masked_lm_positions, 325 | masked_lm_labels=masked_lm_labels) 326 | instances.append(instance) 327 | current_chunk = [] 328 | current_length = 0 329 | i += 1 330 | 331 | return instances 332 | 333 | 334 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 335 | ["index", "label"]) 336 | 337 | 338 | def create_masked_lm_predictions(tokens, masked_lm_prob, 339 | max_predictions_per_seq, vocab_words, rng): 340 | """Creates the predictions for the masked LM objective.""" 341 | 342 | cand_indexes = [] 343 | for (i, token) in enumerate(tokens): 344 | if token == "[CLS]" or token == "[SEP]": 345 | continue 346 | cand_indexes.append(i) 347 | 348 | rng.shuffle(cand_indexes) 349 | 350 | output_tokens = list(tokens) 351 | 352 | num_to_predict = min(max_predictions_per_seq, 353 | max(1, int(round(len(tokens) * masked_lm_prob)))) 354 | 355 | masked_lms = [] 356 | covered_indexes = set() 357 | for index in cand_indexes: 358 | if len(masked_lms) >= num_to_predict: 359 | break 360 | if index in covered_indexes: 361 | continue 362 | covered_indexes.add(index) 363 | 364 | masked_token = None 365 | # 80% of the time, replace with [MASK] 366 | if rng.random() < 0.8: 367 | masked_token = "[MASK]" 368 | else: 369 | # 10% of the time, keep original 370 | if rng.random() < 0.5: 371 | masked_token = tokens[index] 372 | # 10% of the time, replace with random word 373 | else: 374 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 375 | 376 | output_tokens[index] = masked_token 377 | 378 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 379 | 380 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 381 | 382 | masked_lm_positions = [] 383 | masked_lm_labels = [] 384 | for p in masked_lms: 385 | masked_lm_positions.append(p.index) 386 | masked_lm_labels.append(p.label) 387 | 388 | return (output_tokens, masked_lm_positions, masked_lm_labels) 389 | 390 | 391 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 392 | """Truncates a pair of sequences to a maximum sequence length.""" 393 | while True: 394 | total_length = len(tokens_a) + len(tokens_b) 395 | if total_length <= max_num_tokens: 396 | break 397 | 398 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 399 | assert len(trunc_tokens) >= 1 400 | 401 | # We want to sometimes truncate from the front and sometimes from the 402 | # back to add more randomness and avoid biases. 
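    # (Illustrative arithmetic: with max_seq_length=128, max_num_tokens is 125, so a
    # 70 + 70 token pair must shed 15 tokens, removed one per loop iteration from the
    # longer of the two sequences.)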
403 | if rng.random() < 0.5: 404 | del trunc_tokens[0] 405 | else: 406 | trunc_tokens.pop() 407 | 408 | 409 | def main(_): 410 | tf.logging.set_verbosity(tf.logging.INFO) 411 | 412 | tokenizer = tokenization.FullTokenizer( 413 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 414 | 415 | input_files = [] 416 | for input_pattern in FLAGS.input_file.split(","): 417 | input_files.extend(tf.gfile.Glob(input_pattern)) 418 | 419 | tf.logging.info("*** Reading from input files ***") 420 | for input_file in input_files: 421 | tf.logging.info(" %s", input_file) 422 | 423 | rng = random.Random(FLAGS.random_seed) 424 | instances = create_training_instances( 425 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 426 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 427 | rng) 428 | 429 | output_files = FLAGS.output_file.split(",") 430 | tf.logging.info("*** Writing to output files ***") 431 | for output_file in output_files: 432 | tf.logging.info(" %s", output_file) 433 | 434 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 435 | FLAGS.max_predictions_per_seq, output_files) 436 | 437 | 438 | if __name__ == "__main__": 439 | flags.mark_flag_as_required("input_file") 440 | flags.mark_flag_as_required("output_file") 441 | flags.mark_flag_as_required("vocab_file") 442 | tf.app.run() 443 | -------------------------------------------------------------------------------- /lm_pretraining/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. 
Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | scaffold_fn = None 154 | if init_checkpoint: 155 | (assignment_map, initialized_variable_names 156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 157 | if use_tpu: 158 | 159 | def tpu_scaffold(): 160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 161 | return tf.train.Scaffold() 162 | 163 | scaffold_fn = tpu_scaffold 164 | else: 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | 167 | tf.logging.info("**** Trainable Variables ****") 168 | for var in tvars: 169 | init_string = "" 170 | if var.name in initialized_variable_names: 171 | init_string = ", *INIT_FROM_CKPT*" 172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 173 | init_string) 174 | 175 | output_spec = None 176 | if mode == tf.estimator.ModeKeys.TRAIN: 177 | train_op = optimization.create_optimizer( 178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 179 | 180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | train_op=train_op, 184 | scaffold_fn=scaffold_fn) 185 | elif mode == tf.estimator.ModeKeys.EVAL: 186 | 187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 188 | masked_lm_weights, next_sentence_example_loss, 189 | next_sentence_log_probs, next_sentence_labels): 190 | """Computes the loss and accuracy of the model.""" 191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 192 | [-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax( 194 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 197 | masked_lm_weights = 
tf.reshape(masked_lm_weights, [-1]) 198 | masked_lm_accuracy = tf.metrics.accuracy( 199 | labels=masked_lm_ids, 200 | predictions=masked_lm_predictions, 201 | weights=masked_lm_weights) 202 | masked_lm_mean_loss = tf.metrics.mean( 203 | values=masked_lm_example_loss, weights=masked_lm_weights) 204 | 205 | next_sentence_log_probs = tf.reshape( 206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 207 | next_sentence_predictions = tf.argmax( 208 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 210 | next_sentence_accuracy = tf.metrics.accuracy( 211 | labels=next_sentence_labels, predictions=next_sentence_predictions) 212 | next_sentence_mean_loss = tf.metrics.mean( 213 | values=next_sentence_example_loss) 214 | 215 | return { 216 | "masked_lm_accuracy": masked_lm_accuracy, 217 | "masked_lm_loss": masked_lm_mean_loss, 218 | "next_sentence_accuracy": next_sentence_accuracy, 219 | "next_sentence_loss": next_sentence_mean_loss, 220 | } 221 | 222 | eval_metrics = (metric_fn, [ 223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 224 | masked_lm_weights, next_sentence_example_loss, 225 | next_sentence_log_probs, next_sentence_labels 226 | ]) 227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 228 | mode=mode, 229 | loss=total_loss, 230 | eval_metrics=eval_metrics, 231 | scaffold_fn=scaffold_fn) 232 | else: 233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 234 | 235 | return output_spec 236 | 237 | return model_fn 238 | 239 | 240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 241 | label_ids, label_weights): 242 | """Get loss and log probs for the masked LM.""" 243 | input_tensor = gather_indexes(input_tensor, positions) 244 | 245 | with tf.variable_scope("cls/predictions"): 246 | # We apply one more non-linear transformation before the output layer. 247 | # This matrix is not used after pre-training. 248 | with tf.variable_scope("transform"): 249 | input_tensor = tf.layers.dense( 250 | input_tensor, 251 | units=bert_config.hidden_size, 252 | activation=modeling.get_activation(bert_config.hidden_act), 253 | kernel_initializer=modeling.create_initializer( 254 | bert_config.initializer_range)) 255 | input_tensor = modeling.layer_norm(input_tensor) 256 | 257 | # The output weights are the same as the input embeddings, but there is 258 | # an output-only bias for each token. 259 | output_bias = tf.get_variable( 260 | "output_bias", 261 | shape=[bert_config.vocab_size], 262 | initializer=tf.zeros_initializer()) 263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 264 | logits = tf.nn.bias_add(logits, output_bias) 265 | log_probs = tf.nn.log_softmax(logits, axis=-1) 266 | 267 | label_ids = tf.reshape(label_ids, [-1]) 268 | label_weights = tf.reshape(label_weights, [-1]) 269 | 270 | one_hot_labels = tf.one_hot( 271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 272 | 273 | # The `positions` tensor might be zero-padded (if the sequence is too 274 | # short to have the maximum number of predictions). The `label_weights` 275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 276 | # padding predictions. 
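    # In other words,
    #   loss = sum(label_weights * cross_entropy) / (sum(label_weights) + 1e-5)
    # so padded prediction slots contribute zero to both sums, and the 1e-5 keeps
    # the division defined even if a batch had no real masked positions.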
277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 278 | numerator = tf.reduce_sum(label_weights * per_example_loss) 279 | denominator = tf.reduce_sum(label_weights) + 1e-5 280 | loss = numerator / denominator 281 | 282 | return (loss, per_example_loss, log_probs) 283 | 284 | 285 | def get_next_sentence_output(bert_config, input_tensor, labels): 286 | """Get loss and log probs for the next sentence prediction.""" 287 | 288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 289 | # "random sentence". This weight matrix is not used after pre-training. 290 | with tf.variable_scope("cls/seq_relationship"): 291 | output_weights = tf.get_variable( 292 | "output_weights", 293 | shape=[2, bert_config.hidden_size], 294 | initializer=modeling.create_initializer(bert_config.initializer_range)) 295 | output_bias = tf.get_variable( 296 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 297 | 298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 299 | logits = tf.nn.bias_add(logits, output_bias) 300 | log_probs = tf.nn.log_softmax(logits, axis=-1) 301 | labels = tf.reshape(labels, [-1]) 302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 304 | loss = tf.reduce_mean(per_example_loss) 305 | return (loss, per_example_loss, log_probs) 306 | 307 | 308 | def gather_indexes(sequence_tensor, positions): 309 | """Gathers the vectors at the specific positions over a minibatch.""" 310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 311 | batch_size = sequence_shape[0] 312 | seq_length = sequence_shape[1] 313 | width = sequence_shape[2] 314 | 315 | flat_offsets = tf.reshape( 316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 317 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 318 | flat_sequence_tensor = tf.reshape(sequence_tensor, 319 | [batch_size * seq_length, width]) 320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 321 | return output_tensor 322 | 323 | 324 | def input_fn_builder(input_files, 325 | max_seq_length, 326 | max_predictions_per_seq, 327 | is_training, 328 | num_cpu_threads=4): 329 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 330 | 331 | def input_fn(params): 332 | """The actual input function.""" 333 | batch_size = params["batch_size"] 334 | 335 | name_to_features = { 336 | "input_ids": 337 | tf.FixedLenFeature([max_seq_length], tf.int64), 338 | "input_mask": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "segment_ids": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "masked_lm_positions": 343 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 344 | "masked_lm_ids": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_weights": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 348 | "next_sentence_labels": 349 | tf.FixedLenFeature([1], tf.int64), 350 | } 351 | 352 | # For training, we want a lot of parallel reading and shuffling. 353 | # For eval, we want no shuffling and parallel reading doesn't matter. 354 | if is_training: 355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 356 | d = d.repeat() 357 | d = d.shuffle(buffer_size=len(input_files)) 358 | 359 | # `cycle_length` is the number of parallel files that get read. 
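      # (With a single input .tfrecord this is 1 and the interleave below is
      # effectively a sequential read; with several shards, up to `num_cpu_threads`
      # files are read in parallel.)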
360 | cycle_length = min(num_cpu_threads, len(input_files)) 361 | 362 | # `sloppy` mode means that the interleaving is not exact. This adds 363 | # even more randomness to the training pipeline. 364 | d = d.apply( 365 | tf.contrib.data.parallel_interleave( 366 | tf.data.TFRecordDataset, 367 | sloppy=is_training, 368 | cycle_length=cycle_length)) 369 | d = d.shuffle(buffer_size=100) 370 | else: 371 | d = tf.data.TFRecordDataset(input_files) 372 | # Since we evaluate for a fixed number of steps we don't want to encounter 373 | # out-of-range exceptions. 374 | d = d.repeat() 375 | 376 | # We must `drop_remainder` on training because the TPU requires fixed 377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 378 | # and we *don't* want to drop the remainder, otherwise we wont cover 379 | # every sample. 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size, 384 | num_parallel_batches=num_cpu_threads, 385 | drop_remainder=True)) 386 | return d 387 | 388 | return input_fn 389 | 390 | 391 | def _decode_record(record, name_to_features): 392 | """Decodes a record to a TensorFlow example.""" 393 | example = tf.parse_single_example(record, name_to_features) 394 | 395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 396 | # So cast all int64 to int32. 397 | for name in list(example.keys()): 398 | t = example[name] 399 | if t.dtype == tf.int64: 400 | t = tf.to_int32(t) 401 | example[name] = t 402 | 403 | return example 404 | 405 | 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | if not FLAGS.do_train and not FLAGS.do_eval: 410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 411 | 412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 413 | 414 | tf.gfile.MakeDirs(FLAGS.output_dir) 415 | 416 | input_files = [] 417 | for input_pattern in FLAGS.input_file.split(","): 418 | input_files.extend(tf.gfile.Glob(input_pattern)) 419 | 420 | tf.logging.info("*** Input Files ***") 421 | for input_file in input_files: 422 | tf.logging.info(" %s" % input_file) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | model_fn = model_fn_builder( 441 | bert_config=bert_config, 442 | init_checkpoint=FLAGS.init_checkpoint, 443 | learning_rate=FLAGS.learning_rate, 444 | num_train_steps=FLAGS.num_train_steps, 445 | num_warmup_steps=FLAGS.num_warmup_steps, 446 | use_tpu=FLAGS.use_tpu, 447 | use_one_hot_embeddings=FLAGS.use_tpu) 448 | 449 | # If TPU is not available, this will fall back to normal Estimator on CPU 450 | # or GPU. 
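  # (Even in that fallback, TPUEstimator still supplies the batch size to the
  # input_fn via params["batch_size"], which is why input_fn_builder reads it from
  # params rather than from a flag.)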
451 | estimator = tf.contrib.tpu.TPUEstimator( 452 | use_tpu=FLAGS.use_tpu, 453 | model_fn=model_fn, 454 | config=run_config, 455 | train_batch_size=FLAGS.train_batch_size, 456 | eval_batch_size=FLAGS.eval_batch_size) 457 | 458 | if FLAGS.do_train: 459 | tf.logging.info("***** Running training *****") 460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 461 | train_input_fn = input_fn_builder( 462 | input_files=input_files, 463 | max_seq_length=FLAGS.max_seq_length, 464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 465 | is_training=True) 466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 467 | 468 | if FLAGS.do_eval: 469 | tf.logging.info("***** Running evaluation *****") 470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 471 | 472 | eval_input_fn = input_fn_builder( 473 | input_files=input_files, 474 | max_seq_length=FLAGS.max_seq_length, 475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 476 | is_training=False) 477 | 478 | result = estimator.evaluate( 479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 480 | 481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 482 | with tf.gfile.GFile(output_eval_file, "w") as writer: 483 | tf.logging.info("***** Eval results *****") 484 | for key in sorted(result.keys()): 485 | tf.logging.info(" %s = %s", key, str(result[key])) 486 | writer.write("%s = %s\n" % (key, str(result[key]))) 487 | 488 | 489 | if __name__ == "__main__": 490 | flags.mark_flag_as_required("input_file") 491 | flags.mark_flag_as_required("bert_config_file") 492 | flags.mark_flag_as_required("output_dir") 493 | tf.app.run() 494 | -------------------------------------------------------------------------------- /downstream_tasks/i2b2_preprocessing/i2b2_2012/Reformat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 20, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, xml, xml.etree.ElementTree as ET, numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "*******This is a brief decsription of the i2b2 sample data*********\r\n", 22 | "1. Data files:\r\n", 23 | " The group of 3 files with file extentions: txt, extent and tlink\r\n", 24 | "are in the format that are consistent with the previous i2b2 challenges.\r\n", 25 | " In addition, an xml file is provided for each record. The XML \r\n", 26 | "file is browsable using the MAE tool found in the directory.\r\n", 27 | "\r\n", 28 | "2. Usage of MAE tool\r\n", 29 | " - Download and unzip the MAE.zip \r\n", 30 | " - Run the MAE_v0.9.3.jar (please make sure that you have Java installed)\r\n", 31 | " - In menu bar, select File -> Load DTD, and in the popup window,\r\n", 32 | " navigate to the sample folder in the unzipped MAE folder, load the\r\n", 33 | " TemporalAnnotation.dtd\r\n", 34 | " - Then, select File -> Load File, and load one of the records, e.g. 
\r\n", 35 | " 357.xml\r\n", 36 | " - You will be able to browse the annotations by clicking on different\r\n", 37 | " tabs.\r\n", 38 | "\r\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "cat readme.txt" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 54, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "START_CDATA = \"\"\n", 54 | "\n", 55 | "TAGS = ['MEDICATION', 'OBSEE', 'SMOKER', 'HYPERTENSION', 'event', 'FAMILY_HIST']\n", 56 | "\n", 57 | "def read_xml_file(xml_path, event_tag_type='ALL_CHILDREN', match_text=True):\n", 58 | " with open(xml_path, mode='r') as f:\n", 59 | " lines = f.readlines()\n", 60 | " text, in_text = [], False\n", 61 | " for i, l in enumerate(lines):\n", 62 | " if START_CDATA in l:\n", 63 | " text.append(list(l[l.find(START_CDATA) + len(START_CDATA):]))\n", 64 | " in_text = True\n", 65 | " elif END_CDATA in l:\n", 66 | " text.append(list(l[:l.find(END_CDATA)]))\n", 67 | " break\n", 68 | " elif in_text:\n", 69 | "# if xml_path.endswith('180-03.xml') and '0808' in l and 'Effingham' in l:\n", 70 | "# print(\"Adjusting known error\")\n", 71 | "# l = l[:9] + ' ' * 4 + l[9:]\n", 72 | "# # elif xml_path.endswith('188-05.xml') and 'Johnson & Johnson' in l:\n", 73 | "# # print(\"Adjusting known error\")\n", 74 | "# # l = l.replace('&', 'and')\n", 75 | " text.append(list(l))\n", 76 | " \n", 77 | " pos_transformer = {}\n", 78 | " \n", 79 | " linear_pos = 1\n", 80 | " for line, sentence in enumerate(text):\n", 81 | " for char_pos, char in enumerate(sentence):\n", 82 | " pos_transformer[linear_pos] = (line, char_pos)\n", 83 | " linear_pos += 1\n", 84 | " \n", 85 | " try: xml_parsed = ET.parse(xml_path)\n", 86 | " except:\n", 87 | " print(xml_path)\n", 88 | " raise\n", 89 | " \n", 90 | " tag_containers = xml_parsed.findall('TAGS')\n", 91 | " assert len(tag_containers) == 1, \"Found multiple tag sets!\"\n", 92 | " tag_container = tag_containers[0]\n", 93 | " \n", 94 | "# event_tags = tag_container.getchildren() if event_tag_type == 'ALL_CHILDREN' else tag_container.findall('event')\n", 95 | " event_tags = tag_container.findall('EVENT')\n", 96 | " event_labels = [['O'] * len(sentence) for sentence in text]\n", 97 | " for event_tag in event_tags:\n", 98 | " base_label = event_tag.attrib['type']\n", 99 | " start_pos, end_pos, event_text = event_tag.attrib['start'], event_tag.attrib['end'], event_tag.attrib['text']\n", 100 | " start_pos, end_pos = int(start_pos)+1, int(end_pos)\n", 101 | " event_text = ' '.join(event_text.split())\n", 102 | "# if event_text == \"0808 O’neil’s Court\":\n", 103 | "# print(\"Adjusting known error\")\n", 104 | "# end_pos -= 4\n", 105 | "# if event_text == 'Johnson and Johnson' and xml_path.endswith('188-05.xml'):\n", 106 | "# print(\"Adjusting known error\")\n", 107 | "# event_text = 'Johnson & Johnson'\n", 108 | " \n", 109 | "\n", 110 | " (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]\n", 111 | " \n", 112 | " obs_text = []\n", 113 | " for line in range(start_line, end_line+1):\n", 114 | " t = text[line]\n", 115 | " s = start_char if line == start_line else 0\n", 116 | " e = end_char if line == end_line else len(t)\n", 117 | " obs_text.append(''.join(t[s:e+1]).strip())\n", 118 | " obs_text = ' '.join(obs_text)\n", 119 | " obs_text = ' '.join(obs_text.split())\n", 120 | " \n", 121 | " if ''' in obs_text and ''' not in event_text: event_text = event_text.replace(\"'\", \"'\")\n", 122 | " if '"' in obs_text and '"' not in event_text: event_text = 
event_text.replace('\"', '"')\n", 123 | " \n", 124 | " if match_text: assert obs_text == event_text, (\n", 125 | " (\"Texts don't match! %s v %s\" % (event_text, obs_text)) + '\\n' + str((\n", 126 | " start_pos, end_pos, line, s, e, t, xml_path\n", 127 | " ))\n", 128 | " )\n", 129 | " \n", 130 | " if base_label.strip() == '': continue\n", 131 | " \n", 132 | " event_labels[end_line][end_char] = 'I-%s' % base_label\n", 133 | " event_labels[start_line][start_char] = 'B-%s' % base_label\n", 134 | " \n", 135 | " for line in range(start_line, end_line+1):\n", 136 | " t = text[line]\n", 137 | " s = start_char+1 if line == start_line else 0\n", 138 | " e = end_char-1 if line == end_line else len(t)-1\n", 139 | " for i in range(s, e+1): event_labels[line][i] = 'I-%s' % base_label\n", 140 | "\n", 141 | " return text, event_labels\n", 142 | " \n", 143 | "def merge_into_words(text_by_char, all_labels_by_char):\n", 144 | " assert len(text_by_char) == len(all_labels_by_char), \"Incorrect # of sentences!\"\n", 145 | " \n", 146 | " N = len(text_by_char)\n", 147 | " \n", 148 | " text_by_word, all_labels_by_word = [], []\n", 149 | " \n", 150 | " for sentence_num in range(N):\n", 151 | " sentence_by_char = text_by_char[sentence_num]\n", 152 | " labels_by_char = all_labels_by_char[sentence_num]\n", 153 | " \n", 154 | " assert len(sentence_by_char) == len(labels_by_char), \"Incorrect # of chars in sentence!\"\n", 155 | " S = len(sentence_by_char)\n", 156 | " \n", 157 | " if labels_by_char == (['O'] * len(sentence_by_char)):\n", 158 | " sentence_by_word = ''.join(sentence_by_char).split()\n", 159 | " labels_by_word = ['O'] * len(sentence_by_word)\n", 160 | " else: \n", 161 | " sentence_by_word, labels_by_word = [], []\n", 162 | " text_chunks, labels_chunks = [], []\n", 163 | " s = 0\n", 164 | " for i in range(S):\n", 165 | " if i == S-1:\n", 166 | " text_chunks.append(sentence_by_char[s:])\n", 167 | " labels_chunks.append(labels_by_char[s:])\n", 168 | " elif labels_by_char[i] == 'O': continue\n", 169 | " else:\n", 170 | " if i > 0 and labels_by_char[i-1] == 'O':\n", 171 | " text_chunks.append(sentence_by_char[s:i])\n", 172 | " labels_chunks.append(labels_by_char[s:i])\n", 173 | " s = i\n", 174 | " if labels_by_char[i+1] == 'O' or labels_by_char[i+1][2:] != labels_by_char[i][2:]:\n", 175 | " text_chunks.append(sentence_by_char[s:i+1])\n", 176 | " labels_chunks.append(labels_by_char[s:i+1])\n", 177 | " s = i+1\n", 178 | " \n", 179 | " for text_chunk, labels_chunk in zip(text_chunks, labels_chunks):\n", 180 | " assert len(text_chunk) == len(labels_chunk), \"Bad Chunking (len)\"\n", 181 | " assert len(text_chunk) > 0, \"Bad chunking (len 0)\" + str(text_chunks) + str(labels_chunks)\n", 182 | " \n", 183 | " labels_set = set(labels_chunk)\n", 184 | " assert labels_set == set(['O']) or (len(labels_set) <= 3 and 'O' not in labels_set), (\n", 185 | " (\"Bad chunking (contents) %s\" % ', '.join(labels_set))+ str(text_chunks) + str(labels_chunks)\n", 186 | " )\n", 187 | " \n", 188 | " text_chunk_by_word = ''.join(text_chunk).split()\n", 189 | " W = len(text_chunk_by_word)\n", 190 | " if W == 0: \n", 191 | "# assert labels_set == set(['O']), \"0-word chunking and non-0 label!\" + str(\n", 192 | "# text_chunks) + str(labels_chunks\n", 193 | "# )\n", 194 | " continue\n", 195 | " \n", 196 | " if labels_chunk[0] == 'O': labels_chunk_by_word = ['O'] * W\n", 197 | " elif W == 1: labels_chunk_by_word = [labels_chunk[0]]\n", 198 | " elif W == 2: labels_chunk_by_word = [labels_chunk[0], labels_chunk[-1]]\n", 199 | " else: 
labels_chunk_by_word = [\n", 200 | " labels_chunk[0]\n", 201 | " ] + [labels_chunk[1]] * (W - 2) + [\n", 202 | " labels_chunk[-1]\n", 203 | " ]\n", 204 | " \n", 205 | " sentence_by_word.extend(text_chunk_by_word)\n", 206 | " labels_by_word.extend(labels_chunk_by_word)\n", 207 | "\n", 208 | " assert len(sentence_by_word) == len(labels_by_word), \"Incorrect # of words in sentence!\" \n", 209 | " \n", 210 | " if len(sentence_by_word) == 0: continue\n", 211 | " \n", 212 | " text_by_word.append(sentence_by_word)\n", 213 | " all_labels_by_word.append(labels_by_word)\n", 214 | " return text_by_word, all_labels_by_word\n", 215 | "\n", 216 | "def reprocess_event_labels(folders, base_path='.', event_tag_type='event', match_text=True, dev_set_size=None):\n", 217 | " all_texts_by_patient, all_labels_by_patient = {}, {}\n", 218 | "\n", 219 | " for folder in folders:\n", 220 | " folder_dir = os.path.join(base_path, folder)\n", 221 | " xml_filenames = [x for x in os.listdir(folder_dir) if x.endswith('xml')]\n", 222 | " for xml_filename in xml_filenames:\n", 223 | " patient_num = int(xml_filename[:-4])\n", 224 | " xml_filepath = os.path.join(folder_dir, xml_filename)\n", 225 | " \n", 226 | " text_by_char, labels_by_char = read_xml_file(\n", 227 | " xml_filepath,\n", 228 | " event_tag_type=event_tag_type,\n", 229 | " match_text=match_text\n", 230 | " )\n", 231 | " text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)\n", 232 | " \n", 233 | " if patient_num not in all_texts_by_patient:\n", 234 | " all_texts_by_patient[patient_num] = []\n", 235 | " all_labels_by_patient[patient_num] = []\n", 236 | " \n", 237 | " all_texts_by_patient[patient_num].extend(text_by_word)\n", 238 | " all_labels_by_patient[patient_num].extend(labels_by_word)\n", 239 | " \n", 240 | " patients = set(all_texts_by_patient.keys())\n", 241 | " \n", 242 | " if dev_set_size is None: train_patients, dev_patients = list(patients), []\n", 243 | " else:\n", 244 | " N_train = int(len(patients) * (1-dev_set_size))\n", 245 | " patients_random = np.random.permutation(list(patients))\n", 246 | " train_patients = list(patients_random[:N_train])\n", 247 | " dev_patients = list(patients_random[N_train:])\n", 248 | " \n", 249 | " train_texts, train_labels = [], []\n", 250 | " dev_texts, dev_labels = [], []\n", 251 | " \n", 252 | " for patient_num in train_patients:\n", 253 | " train_texts.extend(all_texts_by_patient[patient_num])\n", 254 | " train_labels.extend(all_labels_by_patient[patient_num])\n", 255 | "\n", 256 | " for patient_num in dev_patients:\n", 257 | " dev_texts.extend(all_texts_by_patient[patient_num])\n", 258 | " dev_labels.extend(all_labels_by_patient[patient_num])\n", 259 | "\n", 260 | "\n", 261 | " train_out_text_by_sentence = []\n", 262 | " for text, labels in zip(train_texts, train_labels):\n", 263 | " train_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n", 264 | " dev_out_text_by_sentence = []\n", 265 | " for text, labels in zip(dev_texts, dev_labels):\n", 266 | " dev_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n", 267 | "\n", 268 | " return '\\n\\n'.join(train_out_text_by_sentence), '\\n\\n'.join(dev_out_text_by_sentence)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 55, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "final_train_text, final_dev_text = reprocess_event_labels(\n", 278 | " ['2012-07-15.original-annotation.release'], dev_set_size=0.1, match_text=True\n", 279 | ")" 280 | ] 
281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 56, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "test_text, _ = reprocess_event_labels(\n", 289 | " ['2012-08-08.test-data.event-timex-groundtruth/xml'], match_text=False, dev_set_size=None\n", 290 | ")" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 57, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "ADMISSION B-OCCURRENCE\n", 303 | "DATE O\n", 304 | ": O\n", 305 | "\n", 306 | "12/13/2002 O\n", 307 | "\n", 308 | "DISCHARGE B-OCCURRENCE\n", 309 | "DATE O\n", 310 | ": O\n", 311 | "\n", 312 | "12/14/2002 O\n", 313 | "\n", 314 | "HISTORY O\n", 315 | "OF O\n", 316 | "PRESENT O\n", 317 | "ILLNESS O\n", 318 | ": O\n", 319 | "\n", 320 | "The O\n", 321 | "patient O\n", 322 | "is O\n", 323 | "a O\n", 324 | "30 O\n", 325 | "-year-old O\n", 326 | "Gravida B-OCCURRENCE\n", 327 | "III O\n", 328 | ", O\n", 329 | "Para B-OCCURRENCE\n", 330 | "II O\n", 331 | "who O\n", 332 | "presented B-EVIDENTIAL\n", 333 | "desiring O\n", 334 | "definitive B-TREATMENT\n", 335 | "surgical I-TREATMENT\n", 336 | "sterilization I-TREATMENT\n", 337 | ". O\n", 338 | "\n", 339 | "She O\n", 340 | "was O\n", 341 | "extensively B-OCCURRENCE\n", 342 | "counseled I-OCCURRENCE\n", 343 | "regarding O\n", 344 | "other B-TREATMENT\n", 345 | "methods I-TREATMENT\n", 346 | "of O\n", 347 | "birth B-TREATMENT\n", 348 | "control I-TREATMEN\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "print(final_train_text[:500])" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 58, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "Admission B-OCCURRENCE\n", 366 | "Date O\n", 367 | ": O\n", 368 | "\n", 369 | "2010-02-05 O\n", 370 | "\n", 371 | "Discharge B-OCCURRENCE\n", 372 | "Date O\n", 373 | ": O\n", 374 | "\n", 375 | "2010-02-06 O\n", 376 | "\n", 377 | "Service O\n", 378 | ": O\n", 379 | "\n", 380 | "TRAUMA O\n", 381 | "\n", 382 | "HISTORY O\n", 383 | "OF O\n", 384 | "PRESENT O\n", 385 | "ILLNESS O\n", 386 | ": O\n", 387 | "\n", 388 | "The O\n", 389 | "patient O\n", 390 | "is O\n", 391 | "a O\n", 392 | "23 O\n", 393 | "year O\n", 394 | "old O\n", 395 | "female O\n", 396 | ", O\n", 397 | "status O\n", 398 | "post O\n", 399 | "fall B-PROBLEM\n", 400 | "from O\n", 401 | "standing O\n", 402 | "position O\n", 403 | "after O\n", 404 | "slipping O\n", 405 | "on O\n", 406 | "ice O\n", 407 | ". O\n", 408 | "\n", 409 | "She O\n", 410 | "had O\n", 411 | "no O\n", 412 | "loss B-PROBLEM\n", 413 | "of I-PROBLEM\n", 414 | "consciousness I-PROBLEM\n", 415 | ". 
O\n", 416 | "\n", 417 | "She O\n", 418 | "recall\n" 419 | ] 420 | } 421 | ], 422 | "source": [ 423 | "print(final_dev_text[:400])" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 59, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "ADMISSION B-OCCURRENCE\n", 436 | "DATE O\n", 437 | ": O\n", 438 | "\n", 439 | "10/14/96 O\n", 440 | "\n", 441 | "DISCHARGE B-OCCURRENCE\n", 442 | "DATE O\n", 443 | ": O\n", 444 | "\n", 445 | "10/27/96 O\n", 446 | "date O\n", 447 | "of O\n", 448 | "birth B-OCCURRENCE\n", 449 | "; O\n", 450 | "September O\n", 451 | "30 O\n", 452 | ", O\n", 453 | "1917 O\n", 454 | "\n", 455 | "THER O\n", 456 | "PROCEDURES O\n", 457 | ": O\n", 458 | "\n", 459 | "arterial B-TEST\n", 460 | "catheterization I-TEST\n", 461 | "on O\n", 462 | "10/14/96 O\n", 463 | ", O\n", 464 | "head B-TEST\n", 465 | "CT I-TEST\n", 466 | "scan I-TEST\n", 467 | "on O\n", 468 | "10/14/96 O\n", 469 | "\n", 470 | "HISTORY O\n", 471 | "AND O\n", 472 | "REASON O\n", 473 | "FOR O\n", 474 | "HOSPITALIZATION O\n", 475 | ": O\n", 476 | "\n", 477 | "Granrivern O\n", 478 | "Call O\n", 479 | "is O\n", 480 | "a O\n", 481 | "79-year-old O\n", 482 | "right O\n", 483 | "han\n" 484 | ] 485 | } 486 | ], 487 | "source": [ 488 | "print(test_text[:400])" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 60, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "labels = {}\n", 498 | "for s in final_train_text, final_dev_text, test_text:\n", 499 | " for line in s.split('\\n'):\n", 500 | " if line == '': continue\n", 501 | " label = line.split()[-1]\n", 502 | " assert label == 'O' or label.startswith('B-') or label.startswith('I-'), \"label wrong! 
%s\" % label\n", 503 | " if label not in labels: labels[label] = 1\n", 504 | " else: labels[label] += 1" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 61, 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "data": { 514 | "text/plain": [ 515 | "{'B-OCCURRENCE': 5774,\n", 516 | " 'O': 114910,\n", 517 | " 'B-EVIDENTIAL': 1334,\n", 518 | " 'B-TREATMENT': 7098,\n", 519 | " 'I-TREATMENT': 6748,\n", 520 | " 'I-OCCURRENCE': 3590,\n", 521 | " 'B-CLINICAL_DEPT': 1724,\n", 522 | " 'I-CLINICAL_DEPT': 3253,\n", 523 | " 'B-PROBLEM': 9319,\n", 524 | " 'I-PROBLEM': 13543,\n", 525 | " 'B-TEST': 4762,\n", 526 | " 'I-TEST': 5931,\n", 527 | " 'I-EVIDENTIAL': 84}" 528 | ] 529 | }, 530 | "execution_count": 61, 531 | "metadata": {}, 532 | "output_type": "execute_result" 533 | } 534 | ], 535 | "source": [ 536 | "labels" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 62, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "with open('./processed/train.tsv', mode='w') as f:\n", 546 | " f.write(final_train_text)\n", 547 | "with open('./processed/dev.tsv', mode='w') as f:\n", 548 | " f.write(final_dev_text)\n", 549 | "with open('./processed/test.tsv', mode='w') as f:\n", 550 | " f.write(test_text)" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "metadata": {}, 557 | "outputs": [], 558 | "source": [] 559 | } 560 | ], 561 | "metadata": { 562 | "kernelspec": { 563 | "display_name": "Python 3", 564 | "language": "python", 565 | "name": "python3" 566 | }, 567 | "language_info": { 568 | "codemirror_mode": { 569 | "name": "ipython", 570 | "version": 3 571 | }, 572 | "file_extension": ".py", 573 | "mimetype": "text/x-python", 574 | "name": "python", 575 | "nbconvert_exporter": "python", 576 | "pygments_lexer": "ipython3", 577 | "version": "3.6.8" 578 | } 579 | }, 580 | "nbformat": 4, 581 | "nbformat_minor": 2 582 | } 583 | -------------------------------------------------------------------------------- /downstream_tasks/i2b2_preprocessing/i2b2_2014_deid_hf_risk/Reformat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, xml.etree.ElementTree as ET, numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "This folder contains a sample of the data that will be released as a part \r\n", 22 | "of the two 2014 i2b2 challenge tracks. \r\n", 23 | "\r\n", 24 | "Each track has its own folder of sample data; the contents are outlined \r\n", 25 | "below.\r\n", 26 | "\r\n", 27 | "----------------------------------------------\r\n", 28 | "\r\n", 29 | "Track 1: De-identification \r\n", 30 | "\r\n", 31 | "Folder contents:\r\n", 32 | "\r\n", 33 | "PHI/\r\n", 34 | "\tThis folder contains only PHI-related annotations for the text \r\n", 35 | "in each document. 
These files are the gold standard for the de-id track.\r\n", 36 | "\r\n", 37 | "De-identification_guidelines_2014_distribution.pdf\r\n", 38 | "\tThis file contains the guidelines for PHI identification\r\n", 39 | "\r\n", 40 | "de-idi2b2_distribution.dtd\r\n", 41 | "\tThis file is the DTD that describes the valid tags for the \r\n", 42 | "de-identification annotation task.\r\n", 43 | "\r\n", 44 | "---------------------------------------------------\r\n", 45 | "\r\n", 46 | "Track 2: Identifying risk factors for heart disease over time\r\n", 47 | "\r\n", 48 | "Folder contents:\r\n", 49 | "\r\n", 50 | "gold/\r\n", 51 | "\tThis folder contains the gold standard annotations that will be used for\r\n", 52 | "evaluation. These are document-level tags which are defined in the annotation\r\n", 53 | "guidelines. Valid tags and attributes are outlined in the provided cardiac risk DTD\r\n", 54 | "\r\n", 55 | "complete/\r\n", 56 | "\tThis folder contains complete annotations for the entire text. It contains \r\n", 57 | "document level annotations. Each document level annotation is supplemented with tags \r\n", 58 | "that show the the evidence found by our annotators for a particular document level \r\n", 59 | "tag. These annotator tags include position information and are included with the hope \r\n", 60 | "that they will ease system development and error analysis.\r\n", 61 | "\r\n", 62 | "i2b2_2014_annotation_guidelines_distribution.pdf\r\n", 63 | "\tThis file contains the guidelines for the risk factor annotation \r\n", 64 | "\r\n", 65 | "cardiacRisk_distribution.dtd\r\n", 66 | "\tThis file is the DTD that describes the valid tags for the \r\n", 67 | "risk factor annotation task\r\n", 68 | "\r\n", 69 | "---------------------------------------------------\r\n", 70 | "\r\n", 71 | "Overview of the sample files:\r\n", 72 | "\r\n", 73 | "The gold/, complete/, and PHI/ folders each contain 8 XML files. These\r\n", 74 | "files contain four discharge summaries each for two different patients. The file\r\n", 75 | "names follow a consistent pattern with the first set of digits identifying the\r\n", 76 | "patient and the last set of digits identifying the sequential record number\r\n", 77 | "\r\n", 78 | "ie: XXX-YY.xml \r\n", 79 | "where XXX is the patient number, and YY is the record number.\r\n", 80 | "\r\n", 81 | "Example: 320-03.xml\r\n", 82 | "This is the third (03) record for patient 320\r\n", 83 | "\r\n", 84 | "Each file has a root level xml node which will contain a\r\n", 85 | " node that holds the medical annotation text and a node containing\r\n", 86 | "annotations for the document text. 
The specific annotations contained in each file \r\n", 87 | "are described by the accompanying DTDs and annotation guidelines.\r\n", 88 | "\r\n", 89 | "---------------------------------------------------\r\n", 90 | "\r\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "cat readme.txt" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "START_CDATA = \"<TEXT><![CDATA[\"\n", 105 | "END_CDATA = \"]]></TEXT>\"\n", 106 | "\n", 107 | "TAGS = ['MEDICATION', 'OBESE', 'SMOKER', 'HYPERTENSION', 'PHI', 'FAMILY_HIST']\n", 108 | "\n", 109 | "def read_xml_file(xml_path, PHI_tag_type='ALL_CHILDREN', match_text=True):\n", 110 | "    with open(xml_path, mode='r') as f:\n", 111 | "        lines = f.readlines()\n", 112 | "    text, in_text = [], False\n", 113 | "    for i, l in enumerate(lines):\n", 114 | "        if START_CDATA in l:\n", 115 | "            text.append(list(l[l.find(START_CDATA) + len(START_CDATA):]))\n", 116 | "            in_text = True\n", 117 | "        elif END_CDATA in l:\n", 118 | "            text.append(list(l[:l.find(END_CDATA)]))\n", 119 | "            break\n", 120 | "        elif in_text:\n", 121 | "            if xml_path.endswith('180-03.xml') and '0808' in l and 'Effingham' in l:\n", 122 | "                print(\"Adjusting known error\")\n", 123 | "                l = l[:9] + ' ' * 4 + l[9:]\n", 124 | "#             elif xml_path.endswith('188-05.xml') and 'Johnson & Johnson' in l:\n", 125 | "#                 print(\"Adjusting known error\")\n", 126 | "#                 l = l.replace('&', 'and')\n", 127 | "            text.append(list(l))\n", 128 | "    \n", 129 | "    pos_transformer = {}\n", 130 | "    \n", 131 | "    linear_pos = 1\n", 132 | "    for line, sentence in enumerate(text):\n", 133 | "        for char_pos, char in enumerate(sentence):\n", 134 | "            pos_transformer[linear_pos] = (line, char_pos)\n", 135 | "            linear_pos += 1\n", 136 | "    \n", 137 | "    xml_parsed = ET.parse(xml_path)\n", 138 | "    tag_containers = xml_parsed.findall('TAGS')\n", 139 | "    assert len(tag_containers) == 1, \"Found multiple tag sets!\"\n", 140 | "    tag_container = tag_containers[0]\n", 141 | "    \n", 142 | "    PHI_tags = tag_container.getchildren() if PHI_tag_type == 'ALL_CHILDREN' else tag_container.findall('PHI')\n", 143 | "    PHI_labels = [['O'] * len(sentence) for sentence in text]\n", 144 | "    for PHI_tag in PHI_tags:\n", 145 | "        base_label = PHI_tag.attrib['TYPE']\n", 146 | "        start_pos, end_pos, PHI_text = PHI_tag.attrib['start'], PHI_tag.attrib['end'], PHI_tag.attrib['text']\n", 147 | "        start_pos, end_pos = int(start_pos)+1, int(end_pos)\n", 148 | "        PHI_text = ' '.join(PHI_text.split())\n", 149 | "#         if PHI_text == \"0808 O’neil’s Court\":\n", 150 | "#             print(\"Adjusting known error\")\n", 151 | "#             end_pos -= 4\n", 152 | "        if PHI_text == 'Johnson and Johnson' and xml_path.endswith('188-05.xml'):\n", 153 | "            print(\"Adjusting known error\")\n", 154 | "            PHI_text = 'Johnson & Johnson'\n", 155 | "        \n", 156 | "\n", 157 | "        (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]\n", 158 | "        \n", 159 | "        obs_text = []\n", 160 | "        for line in range(start_line, end_line+1):\n", 161 | "            t = text[line]\n", 162 | "            s = start_char if line == start_line else 0\n", 163 | "            e = end_char if line == end_line else len(t)\n", 164 | "            obs_text.append(''.join(t[s:e+1]).strip())\n", 165 | "        obs_text = ' '.join(obs_text)\n", 166 | "        obs_text = ' '.join(obs_text.split())\n", 167 | "        \n", 168 | "        if match_text: assert obs_text == PHI_text, (\n", 169 | "            ((\"Texts don't match! 
%s v %s\" % (PHI_text, obs_text)) + '\\n' + str((\n", 170 | " start_pos, end_pos, line, s, e, t, xml_path\n", 171 | " ))\n", 172 | " )\n", 173 | " \n", 174 | " PHI_labels[end_line][end_char] = 'I-%s' % base_label\n", 175 | " PHI_labels[start_line][start_char] = 'B-%s' % base_label\n", 176 | " \n", 177 | " for line in range(start_line, end_line+1):\n", 178 | " t = text[line]\n", 179 | " s = start_char+1 if line == start_line else 0\n", 180 | " e = end_char-1 if line == end_line else len(t)-1\n", 181 | " for i in range(s, e+1): PHI_labels[line][i] = 'I-%s' % base_label\n", 182 | "\n", 183 | " return text, PHI_labels\n", 184 | " \n", 185 | "def merge_into_words(text_by_char, all_labels_by_char):\n", 186 | " assert len(text_by_char) == len(all_labels_by_char), \"Incorrect # of sentences!\"\n", 187 | " \n", 188 | " N = len(text_by_char)\n", 189 | " \n", 190 | " text_by_word, all_labels_by_word = [], []\n", 191 | " \n", 192 | " for sentence_num in range(N):\n", 193 | " sentence_by_char = text_by_char[sentence_num]\n", 194 | " labels_by_char = all_labels_by_char[sentence_num]\n", 195 | " \n", 196 | " assert len(sentence_by_char) == len(labels_by_char), \"Incorrect # of chars in sentence!\"\n", 197 | " S = len(sentence_by_char)\n", 198 | " \n", 199 | " if labels_by_char == (['O'] * len(sentence_by_char)):\n", 200 | " sentence_by_word = ''.join(sentence_by_char).split()\n", 201 | " labels_by_word = ['O'] * len(sentence_by_word)\n", 202 | " else: \n", 203 | " sentence_by_word, labels_by_word = [], []\n", 204 | " text_chunks, labels_chunks = [], []\n", 205 | " s = 0\n", 206 | " for i in range(S):\n", 207 | " if i == S-1:\n", 208 | " text_chunks.append(sentence_by_char[s:])\n", 209 | " labels_chunks.append(labels_by_char[s:])\n", 210 | " elif labels_by_char[i] == 'O': continue\n", 211 | " else:\n", 212 | " if i > 0 and labels_by_char[i-1] == 'O':\n", 213 | " text_chunks.append(sentence_by_char[s:i])\n", 214 | " labels_chunks.append(labels_by_char[s:i])\n", 215 | " s = i\n", 216 | " if labels_by_char[i+1] == 'O' or labels_by_char[i+1][2:] != labels_by_char[i][2:]:\n", 217 | " text_chunks.append(sentence_by_char[s:i+1])\n", 218 | " labels_chunks.append(labels_by_char[s:i+1])\n", 219 | " s = i+1\n", 220 | " \n", 221 | " for text_chunk, labels_chunk in zip(text_chunks, labels_chunks):\n", 222 | " assert len(text_chunk) == len(labels_chunk), \"Bad Chunking (len)\"\n", 223 | " assert len(text_chunk) > 0, \"Bad chunking (len 0)\" + str(text_chunks) + str(labels_chunks)\n", 224 | " \n", 225 | " labels_set = set(labels_chunk)\n", 226 | " assert labels_set == set(['O']) or (len(labels_set) <= 3 and 'O' not in labels_set), (\n", 227 | " (\"Bad chunking (contents) %s\" % ', '.join(labels_set))+ str(text_chunks) + str(labels_chunks)\n", 228 | " )\n", 229 | " \n", 230 | " text_chunk_by_word = ''.join(text_chunk).split()\n", 231 | " W = len(text_chunk_by_word)\n", 232 | " if W == 0: \n", 233 | "# assert labels_set == set(['O']), \"0-word chunking and non-0 label!\" + str(\n", 234 | "# text_chunks) + str(labels_chunks\n", 235 | "# )\n", 236 | " continue\n", 237 | " \n", 238 | " if labels_chunk[0] == 'O': labels_chunk_by_word = ['O'] * W\n", 239 | " elif W == 1: labels_chunk_by_word = [labels_chunk[0]]\n", 240 | " elif W == 2: labels_chunk_by_word = [labels_chunk[0], labels_chunk[-1]]\n", 241 | " else: labels_chunk_by_word = [\n", 242 | " labels_chunk[0]\n", 243 | " ] + [labels_chunk[1]] * (W - 2) + [\n", 244 | " labels_chunk[-1]\n", 245 | " ]\n", 246 | " \n", 247 | " sentence_by_word.extend(text_chunk_by_word)\n", 
248 | " labels_by_word.extend(labels_chunk_by_word)\n", 249 | "\n", 250 | " assert len(sentence_by_word) == len(labels_by_word), \"Incorrect # of words in sentence!\" \n", 251 | " \n", 252 | " if len(sentence_by_word) == 0: continue\n", 253 | " \n", 254 | " text_by_word.append(sentence_by_word)\n", 255 | " all_labels_by_word.append(labels_by_word)\n", 256 | " return text_by_word, all_labels_by_word\n", 257 | "\n", 258 | "def reprocess_PHI_labels(folders, base_path='.', PHI_tag_type='PHI', match_text=True, dev_set_size=None):\n", 259 | " all_texts_by_patient, all_labels_by_patient = {}, {}\n", 260 | "\n", 261 | " for folder in folders:\n", 262 | " folder_dir = os.path.join(base_path, folder)\n", 263 | " xml_filenames = [x for x in os.listdir(folder_dir) if x.endswith('xml')]\n", 264 | " for xml_filename in xml_filenames:\n", 265 | " patient_num = int(xml_filename[:3])\n", 266 | " xml_filepath = os.path.join(folder_dir, xml_filename)\n", 267 | " \n", 268 | " text_by_char, labels_by_char = read_xml_file(\n", 269 | " xml_filepath,\n", 270 | " PHI_tag_type=PHI_tag_type,\n", 271 | " match_text=match_text\n", 272 | " )\n", 273 | " text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)\n", 274 | " \n", 275 | " if patient_num not in all_texts_by_patient:\n", 276 | " all_texts_by_patient[patient_num] = []\n", 277 | " all_labels_by_patient[patient_num] = []\n", 278 | " \n", 279 | " all_texts_by_patient[patient_num].extend(text_by_word)\n", 280 | " all_labels_by_patient[patient_num].extend(labels_by_word)\n", 281 | " \n", 282 | " patients = set(all_texts_by_patient.keys())\n", 283 | " \n", 284 | " if dev_set_size is None: train_patients, dev_patients = list(patients), []\n", 285 | " else:\n", 286 | " N_train = int(len(patients) * (1-dev_set_size))\n", 287 | " patients_random = np.random.permutation(list(patients))\n", 288 | " train_patients = list(patients_random[:N_train])\n", 289 | " dev_patients = list(patients_random[N_train:])\n", 290 | " \n", 291 | " train_texts, train_labels = [], []\n", 292 | " dev_texts, dev_labels = [], []\n", 293 | " \n", 294 | " for patient_num in train_patients:\n", 295 | " train_texts.extend(all_texts_by_patient[patient_num])\n", 296 | " train_labels.extend(all_labels_by_patient[patient_num])\n", 297 | "\n", 298 | " for patient_num in dev_patients:\n", 299 | " dev_texts.extend(all_texts_by_patient[patient_num])\n", 300 | " dev_labels.extend(all_labels_by_patient[patient_num])\n", 301 | "\n", 302 | "\n", 303 | " train_out_text_by_sentence = []\n", 304 | " for text, labels in zip(train_texts, train_labels):\n", 305 | " train_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n", 306 | " dev_out_text_by_sentence = []\n", 307 | " for text, labels in zip(dev_texts, dev_labels):\n", 308 | " dev_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n", 309 | "\n", 310 | " return '\\n\\n'.join(train_out_text_by_sentence), '\\n\\n'.join(dev_out_text_by_sentence)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 4, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "Adjusting known error\n", 323 | "Adjusting known error\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "final_train_text, final_dev_text = reprocess_PHI_labels(\n", 329 | " ['training-PHI-Gold-Set1/', 'training-PHI-Gold-Set2/'], PHI_tag_type='ALL_CHILDREN',\n", 330 | " dev_set_size=0.1, match_text=True\n", 331 | ")" 332 | ] 333 | }, 334 
| { 335 | "cell_type": "code", 336 | "execution_count": 5, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "test_text, _ = reprocess_PHI_labels(\n", 341 | " ['testing-PHI-Gold-fixed'], PHI_tag_type='ALL_CHILDREN', match_text=False, dev_set_size=None\n", 342 | ")" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 6, 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "Record O\n", 355 | "date: O\n", 356 | "2067-11-30 B-DATE\n", 357 | "\n", 358 | "CARDIOLOGY O\n", 359 | "\n", 360 | "OCHILTREE B-HOSPITAL\n", 361 | "GENERAL I-HOSPITAL\n", 362 | "HOSPITAL I-HOSPITAL\n", 363 | "\n", 364 | "Reason O\n", 365 | "for O\n", 366 | "visit: O\n", 367 | "\n", 368 | "Follow-up O\n", 369 | "appointment O\n", 370 | "for O\n", 371 | "cardiomyopathy O\n", 372 | "\n", 373 | "Interval O\n", 374 | "History: O\n", 375 | "\n", 376 | "Mr O\n", 377 | "Lara B-PATIENT\n", 378 | "reports O\n", 379 | "no O\n", 380 | "problems O\n", 381 | "since O\n", 382 | "the O\n", 383 | "time O\n", 384 | "of O\n", 385 | "his O\n", 386 | "last O\n", 387 | "visit. O\n", 388 | "He O\n", 389 | "insistists O\n", 390 | "that O\n", 391 | "he O\n", 392 | "has O\n", 393 | "absolutely O\n", 394 | "abstained O\n", 395 | "from O\n", 396 | "etoh O\n", 397 | "since O\n", 398 | "his O\n", 399 | "last O\n", 400 | "visit. O\n", 401 | "\n", 402 | "Denies O\n", 403 | "fevers, O\n", 404 | "chills. O\n", 405 | "denies O\n", 406 | "palpitations O\n", 407 | "or O\n", 408 | "syncope. O\n", 409 | "\n", 410 | "Under O\n", 411 | "a O\n", 412 | "fair O\n", 413 | "am\n" 414 | ] 415 | } 416 | ], 417 | "source": [ 418 | "print(final_train_text[:500])" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 7, 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "Record O\n", 431 | "date: O\n", 432 | "2082-02-28 B-DATE\n", 433 | "\n", 434 | "c.c. O\n", 435 | "Preop O\n", 436 | "PE O\n", 437 | "for O\n", 438 | "cataract O\n", 439 | "extraction O\n", 440 | "on O\n", 441 | "3/7 B-DATE\n", 442 | "\n", 443 | "S: O\n", 444 | "ROS: O\n", 445 | "General: O\n", 446 | "no O\n", 447 | "history O\n", 448 | "of O\n", 449 | "fatigue, O\n", 450 | "weight O\n", 451 | "change, O\n", 452 | "loss O\n", 453 | "of O\n", 454 | "appetite, O\n", 455 | "or O\n", 456 | "weakness. O\n", 457 | "\n", 458 | "HEENT: O\n", 459 | "no O\n", 460 | "history O\n", 461 | "of O\n", 462 | "head O\n", 463 | "injury, O\n", 464 | "glaucoma, O\n", 465 | "tinnitus, O\n", 466 | "vertigo, O\n", 467 | "motion O\n", 468 | "sickness, O\n", 469 | "URI, O\n", 470 | "hearing O\n", 471 | "loss. 
O\n", 472 | "\n", 473 | "Cardiovascular: O\n", 474 | "+ O\n", 475 | "hypertension O\n", 476 | "- O\n", 477 | "adequate O\n", 478 | "co\n" 479 | ] 480 | } 481 | ], 482 | "source": [ 483 | "print(final_dev_text[:400])" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 8, 489 | "metadata": {}, 490 | "outputs": [ 491 | { 492 | "name": "stdout", 493 | "output_type": "stream", 494 | "text": [ 495 | "Record O\n", 496 | "date: O\n", 497 | "2080-02-18 B-DATE\n", 498 | "\n", 499 | "SDU O\n", 500 | "JAR O\n", 501 | "Admission O\n", 502 | "Note O\n", 503 | "\n", 504 | "Name: O\n", 505 | "Yosef B-PATIENT\n", 506 | "Villegas I-PATIENT\n", 507 | "\n", 508 | "MR: O\n", 509 | "8249813 B-MEDICALRECORD\n", 510 | "\n", 511 | "DOA: O\n", 512 | "2/17/80 B-DATE\n", 513 | "\n", 514 | "PCP: O\n", 515 | "Gilbert B-DOCTOR\n", 516 | "Perez I-DOCTOR\n", 517 | "\n", 518 | "Attending: O\n", 519 | "YBARRA B-DOCTOR\n", 520 | "\n", 521 | "CODE: O\n", 522 | "FULL O\n", 523 | "\n", 524 | "HPI: O\n", 525 | "70 B-AGE\n", 526 | "yo O\n", 527 | "M O\n", 528 | "with O\n", 529 | "NIDDM O\n", 530 | "admitted O\n", 531 | "for O\n", 532 | "cath O\n", 533 | "after O\n", 534 | "positive O\n", 535 | "MIBI. O\n", 536 | "Pt O\n", 537 | "has O\n", 538 | "had O\n", 539 | "increasing O\n", 540 | "CP O\n", 541 | "and O\n", 542 | "SOB O\n", 543 | "on O\n", 544 | "exert\n" 545 | ] 546 | } 547 | ], 548 | "source": [ 549 | "print(test_text[:400])" 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": 10, 555 | "metadata": {}, 556 | "outputs": [], 557 | "source": [ 558 | "labels = {}\n", 559 | "for s in final_train_text, final_dev_text, test_text:\n", 560 | " for line in s.split('\\n'):\n", 561 | " if line == '': continue\n", 562 | " label = line.split()[-1]\n", 563 | " assert label == 'O' or label.startswith('B-') or label.startswith('I-'), \"label wrong! 
%s\" % label\n", 564 | " if label not in labels: labels[label] = 1\n", 565 | " else: labels[label] += 1" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 11, 571 | "metadata": {}, 572 | "outputs": [ 573 | { 574 | "data": { 575 | "text/plain": [ 576 | "{'O': 777860,\n", 577 | " 'B-DATE': 12459,\n", 578 | " 'B-HOSPITAL': 2312,\n", 579 | " 'I-HOSPITAL': 1835,\n", 580 | " 'B-PATIENT': 2195,\n", 581 | " 'B-AGE': 1997,\n", 582 | " 'B-DOCTOR': 4797,\n", 583 | " 'I-DOCTOR': 3482,\n", 584 | " 'I-DATE': 1379,\n", 585 | " 'B-COUNTRY': 183,\n", 586 | " 'B-PROFESSION': 413,\n", 587 | " 'I-PROFESSION': 346,\n", 588 | " 'B-ORGANIZATION': 206,\n", 589 | " 'I-ORGANIZATION': 173,\n", 590 | " 'B-STREET': 352,\n", 591 | " 'I-STREET': 717,\n", 592 | " 'B-CITY': 654,\n", 593 | " 'I-CITY': 171,\n", 594 | " 'B-STATE': 504,\n", 595 | " 'B-ZIP': 352,\n", 596 | " 'I-PATIENT': 1192,\n", 597 | " 'B-MEDICALRECORD': 1033,\n", 598 | " 'I-MEDICALRECORD': 47,\n", 599 | " 'B-IDNUM': 456,\n", 600 | " 'B-USERNAME': 356,\n", 601 | " 'B-PHONE': 524,\n", 602 | " 'I-COUNTRY': 21,\n", 603 | " 'I-PHONE': 100,\n", 604 | " 'I-AGE': 10,\n", 605 | " 'I-IDNUM': 30,\n", 606 | " 'B-FAX': 10,\n", 607 | " 'I-FAX': 2,\n", 608 | " 'B-DEVICE': 15,\n", 609 | " 'B-EMAIL': 5,\n", 610 | " 'B-LOCATION-OTHER': 17,\n", 611 | " 'I-LOCATION-OTHER': 15,\n", 612 | " 'I-STATE': 18,\n", 613 | " 'B-URL': 2,\n", 614 | " 'I-URL': 4,\n", 615 | " 'B-BIOID': 1,\n", 616 | " 'B-HEALTHPLAN': 1,\n", 617 | " 'I-HEALTHPLAN': 1,\n", 618 | " 'I-DEVICE': 2}" 619 | ] 620 | }, 621 | "execution_count": 11, 622 | "metadata": {}, 623 | "output_type": "execute_result" 624 | } 625 | ], 626 | "source": [ 627 | "labels" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 28, 633 | "metadata": {}, 634 | "outputs": [], 635 | "source": [ 636 | "with open('./processed/train.tsv', mode='w') as f:\n", 637 | " f.write(final_train_text)\n", 638 | "with open('./processed/dev.tsv', mode='w') as f:\n", 639 | " f.write(final_dev_text)\n", 640 | "with open('./processed/test.tsv', mode='w') as f:\n", 641 | " f.write(test_text)" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [] 650 | } 651 | ], 652 | "metadata": { 653 | "kernelspec": { 654 | "display_name": "Python 3", 655 | "language": "python", 656 | "name": "python3" 657 | }, 658 | "language_info": { 659 | "codemirror_mode": { 660 | "name": "ipython", 661 | "version": 3 662 | }, 663 | "file_extension": ".py", 664 | "mimetype": "text/x-python", 665 | "name": "python", 666 | "nbconvert_exporter": "python", 667 | "pygments_lexer": "ipython3", 668 | "version": "3.6.8" 669 | } 670 | }, 671 | "nbformat": 4, 672 | "nbformat_minor": 2 673 | } 674 | -------------------------------------------------------------------------------- /downstream_tasks/run_classifier.py: -------------------------------------------------------------------------------- 1 | # Code is adapted from the PyTorch pretrained BERT repo - See copyright & license below. 2 | 3 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import csv 22 | import logging 23 | import os 24 | import random 25 | import sys 26 | 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 31 | TensorDataset) 32 | from torch.utils.data.distributed import DistributedSampler 33 | from tqdm import tqdm, trange 34 | 35 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 36 | from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME 37 | from pytorch_pretrained_bert.tokenization import BertTokenizer 38 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 39 | 40 | #added 41 | import json 42 | from random import shuffle 43 | import math 44 | 45 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 46 | datefmt = '%m/%d/%Y %H:%M:%S', 47 | level = logging.INFO) 48 | logger = logging.getLogger(__name__) 49 | 50 | 51 | 52 | class InputFeatures(object): 53 | """A single set of features of data.""" 54 | 55 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 56 | self.input_ids = input_ids 57 | self.input_mask = input_mask 58 | self.segment_ids = segment_ids 59 | self.label_id = label_id 60 | 61 | 62 | class InputExample(object): 63 | """A single training/test example for simple sequence classification.""" 64 | 65 | def __init__(self, guid, text_a, text_b=None, label=None): 66 | """Constructs a InputExample. 67 | 68 | Args: 69 | guid: Unique id for the example. 70 | text_a: string. The untokenized text of the first sequence. For single 71 | sequence tasks, only this sequence must be specified. 72 | text_b: (Optional) string. The untokenized text of the second sequence. 73 | Only must be specified for sequence pair tasks. 74 | label: (Optional) string. The label of the example. This should be 75 | specified for train and dev examples, but not for test examples. 
76 | """ 77 | self.guid = guid 78 | self.text_a = text_a 79 | self.text_b = text_b 80 | self.label = label 81 | 82 | 83 | 84 | class DataProcessor(object): 85 | """Base class for data converters for sequence classification data sets.""" 86 | 87 | def get_train_examples(self, data_dir): 88 | """Gets a collection of `InputExample`s for the train set.""" 89 | raise NotImplementedError() 90 | 91 | def get_dev_examples(self, data_dir): 92 | """Gets a collection of `InputExample`s for the dev set.""" 93 | raise NotImplementedError() 94 | 95 | def get_test_examples(self, data_dir): 96 | """Gets a collection of `InputExample`s for the test set.""" 97 | raise NotImplementedError() 98 | 99 | def get_labels(self): 100 | """Gets the list of labels for this data set.""" 101 | raise NotImplementedError() 102 | 103 | @classmethod 104 | def _read_tsv(cls, input_file, quotechar=None): 105 | """Reads a tab separated value file.""" 106 | with open(input_file, "r") as f: 107 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 108 | lines = [] 109 | for line in reader: 110 | if sys.version_info[0] == 2: 111 | line = list(unicode(cell, 'utf-8') for cell in line) 112 | lines.append(line) 113 | return lines 114 | 115 | 116 | class MrpcProcessor(DataProcessor): 117 | """Processor for the MRPC data set (GLUE version).""" 118 | 119 | def get_train_examples(self, data_dir): 120 | """See base class.""" 121 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 122 | return self._create_examples( 123 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 124 | 125 | def get_dev_examples(self, data_dir): 126 | """See base class.""" 127 | return self._create_examples( 128 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 129 | 130 | def get_labels(self): 131 | """See base class.""" 132 | return ["0", "1"] 133 | 134 | def _create_examples(self, lines, set_type): 135 | """Creates examples for the training and dev sets.""" 136 | examples = [] 137 | for (i, line) in enumerate(lines): 138 | if i == 0: 139 | continue 140 | guid = "%s-%s" % (set_type, i) 141 | text_a = line[3] 142 | text_b = line[4] 143 | label = line[0] 144 | examples.append( 145 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 146 | return examples 147 | 148 | 149 | class MnliProcessor(DataProcessor): 150 | """Processor for the MultiNLI data set (GLUE version).""" 151 | 152 | def get_train_examples(self, data_dir): 153 | """See base class.""" 154 | return self._create_examples( 155 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 156 | 157 | def get_dev_examples(self, data_dir): 158 | """See base class.""" 159 | return self._create_examples( 160 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 161 | "dev_matched") 162 | 163 | def get_labels(self): 164 | """See base class.""" 165 | return ["contradiction", "entailment", "neutral"] 166 | 167 | def _create_examples(self, lines, set_type): 168 | """Creates examples for the training and dev sets.""" 169 | examples = [] 170 | for (i, line) in enumerate(lines): 171 | if i == 0: 172 | continue 173 | guid = "%s-%s" % (set_type, line[0]) 174 | text_a = line[8] 175 | text_b = line[9] 176 | label = line[-1] 177 | examples.append( 178 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 179 | return examples 180 | 181 | 182 | class ColaProcessor(DataProcessor): 183 | """Processor for the CoLA data set (GLUE version).""" 184 | 185 | def get_train_examples(self, data_dir): 186 | """See base class.""" 187 | 
return self._create_examples( 188 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 189 | 190 | def get_dev_examples(self, data_dir): 191 | """See base class.""" 192 | return self._create_examples( 193 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 194 | 195 | def get_labels(self): 196 | """See base class.""" 197 | return ["0", "1"] 198 | 199 | def _create_examples(self, lines, set_type): 200 | """Creates examples for the training and dev sets.""" 201 | examples = [] 202 | for (i, line) in enumerate(lines): 203 | guid = "%s-%s" % (set_type, i) 204 | text_a = line[3] 205 | label = line[1] 206 | examples.append( 207 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 208 | return examples 209 | 210 | # NEW 211 | 212 | 213 | class MedNLIProcessor(DataProcessor): 214 | def _chunks(self, l, n): 215 | """Yield successive n-sized chunks from l.""" 216 | for i in range(0, len(l), n): 217 | yield l[i:i + n] 218 | 219 | 220 | def get_train_examples(self, data_dir): 221 | """Gets a collection of `InputExample`s for the train set.""" 222 | file_path = os.path.join(data_dir, "mli_train_v1.jsonl") 223 | return self._create_examples(file_path) 224 | 225 | def get_dev_examples(self, data_dir): 226 | """Gets a collection of `InputExample`s for the dev set.""" 227 | file_path = os.path.join(data_dir, "mli_dev_v1.jsonl") 228 | return self._create_examples(file_path) 229 | 230 | def get_test_examples(self, data_dir): 231 | """Gets a collection of `InputExample`s for the test set.""" 232 | file_path = os.path.join(data_dir, "mli_test_v1.jsonl") 233 | return self._create_examples(file_path) 234 | 235 | 236 | 237 | 238 | def get_labels(self): 239 | """See base class.""" 240 | return ["contradiction", "entailment", "neutral"] 241 | 242 | def _create_examples(self, file_path): 243 | examples = [] 244 | with open(file_path, "r") as f: 245 | lines = f.readlines() 246 | for line in lines: 247 | example = json.loads(line) 248 | examples.append( 249 | InputExample(guid=example['pairID'], text_a=example['sentence1'], 250 | text_b=example['sentence2'], label=example['gold_label'])) 251 | 252 | return examples 253 | 254 | 255 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 256 | """Loads a data file into a list of `InputBatch`s.""" 257 | 258 | label_map = {label : i for i, label in enumerate(label_list)} 259 | 260 | features = [] 261 | max_len = 0 262 | for (ex_index, example) in enumerate(examples): 263 | tokens_a = tokenizer.tokenize(example.text_a) 264 | 265 | tokens_b = None 266 | if example.text_b: 267 | tokens_b = tokenizer.tokenize(example.text_b) 268 | seq_len = len(tokens_a) + len(tokens_b) 269 | 270 | # Modifies `tokens_a` and `tokens_b` in place so that the total 271 | # length is less than the specified length. 272 | # Account for [CLS], [SEP], [SEP] with "- 3" 273 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 274 | else: 275 | seq_len = len(tokens_a) 276 | # Account for [CLS] and [SEP] with "- 2" 277 | if len(tokens_a) > max_seq_length - 2: 278 | tokens_a = tokens_a[:(max_seq_length - 2)] 279 | 280 | if seq_len > max_len: 281 | max_len = seq_len 282 | # The convention in BERT is: 283 | # (a) For sequence pairs: 284 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 285 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 286 | # (b) For single sequences: 287 | # tokens: [CLS] the dog is hairy . 
[SEP] 288 | # type_ids: 0 0 0 0 0 0 0 289 | # 290 | # Where "type_ids" are used to indicate whether this is the first 291 | # sequence or the second sequence. The embedding vectors for `type=0` and 292 | # `type=1` were learned during pre-training and are added to the wordpiece 293 | # embedding vector (and position vector). This is not *strictly* necessary 294 | # since the [SEP] token unambigiously separates the sequences, but it makes 295 | # it easier for the model to learn the concept of sequences. 296 | # 297 | # For classification tasks, the first vector (corresponding to [CLS]) is 298 | # used as as the "sentence vector". Note that this only makes sense because 299 | # the entire model is fine-tuned. 300 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 301 | segment_ids = [0] * len(tokens) 302 | 303 | if tokens_b: 304 | tokens += tokens_b + ["[SEP]"] 305 | segment_ids += [1] * (len(tokens_b) + 1) 306 | 307 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 308 | 309 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 310 | # tokens are attended to. 311 | input_mask = [1] * len(input_ids) 312 | 313 | # Zero-pad up to the sequence length. 314 | padding = [0] * (max_seq_length - len(input_ids)) 315 | input_ids += padding 316 | input_mask += padding 317 | segment_ids += padding 318 | 319 | assert len(input_ids) == max_seq_length 320 | assert len(input_mask) == max_seq_length 321 | assert len(segment_ids) == max_seq_length 322 | 323 | label_id = label_map[example.label] 324 | if ex_index < 3: 325 | logger.info("*** Example ***") 326 | logger.info("guid: %s" % (example.guid)) 327 | logger.info("tokens: %s" % " ".join( 328 | [str(x) for x in tokens])) 329 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 330 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 331 | logger.info( 332 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 333 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 334 | 335 | features.append( 336 | InputFeatures(input_ids=input_ids, 337 | input_mask=input_mask, 338 | segment_ids=segment_ids, 339 | label_id=label_id)) 340 | 341 | print('Max Sequence Length: %d' %max_len) 342 | 343 | return features 344 | 345 | 346 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 347 | """Truncates a sequence pair in place to the maximum length.""" 348 | 349 | # This is a simple heuristic which will always truncate the longer sequence 350 | # one token at a time. This makes more sense than truncating an equal percent 351 | # of tokens from each, since if one sequence is very short then each token 352 | # that's truncated likely contains more information than a longer sequence. 353 | while True: 354 | total_length = len(tokens_a) + len(tokens_b) 355 | if total_length <= max_length: 356 | break 357 | if len(tokens_a) > len(tokens_b): 358 | tokens_a.pop() 359 | else: 360 | tokens_b.pop() 361 | 362 | def accuracy(out, labels): 363 | outputs = np.argmax(out, axis=1) 364 | return np.sum(outputs == labels) 365 | 366 | def setup_parser(): 367 | parser = argparse.ArgumentParser() 368 | 369 | ## Required parameters 370 | parser.add_argument("--data_dir", 371 | default=None, 372 | type=str, 373 | required=True, 374 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 375 | parser.add_argument("--bert_model", default=None, type=str, required=True, 376 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 377 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 378 | "bert-base-multilingual-cased, bert-base-chinese, biobert.") 379 | parser.add_argument("--task_name", 380 | default=None, 381 | type=str, 382 | required=True, 383 | help="The name of the task to train.") 384 | parser.add_argument("--output_dir", 385 | default=None, 386 | type=str, 387 | required=True, 388 | help="The output directory where the model predictions and checkpoints will be written.") 389 | 390 | ## Other parameters 391 | parser.add_argument("--cache_dir", 392 | default="", 393 | type=str, 394 | help="Where do you want to store the pre-trained models downloaded from s3") 395 | parser.add_argument("--max_seq_length", 396 | default=128, 397 | type=int, 398 | help="The maximum total input sequence length after WordPiece tokenization. \n" 399 | "Sequences longer than this will be truncated, and sequences shorter \n" 400 | "than this will be padded.") 401 | parser.add_argument("--do_train", 402 | action='store_true', 403 | help="Whether to run training.") 404 | parser.add_argument("--do_eval", 405 | action='store_true', 406 | help="Whether to run eval on the dev set.") 407 | parser.add_argument("--do_test", 408 | action='store_true', 409 | help="Whether to run eval on the test set.") 410 | parser.add_argument("--do_lower_case", 411 | action='store_true', 412 | help="Set this flag if you are using an uncased model.") 413 | parser.add_argument("--train_batch_size", 414 | default=32, 415 | type=int, 416 | help="Total batch size for training.") 417 | parser.add_argument("--eval_batch_size", 418 | default=8, 419 | type=int, 420 | help="Total batch size for eval.") 421 | parser.add_argument("--learning_rate", 422 | default=5e-5, 423 | type=float, 424 | help="The initial learning rate for Adam.") 425 | parser.add_argument("--num_train_epochs", 426 | default=3.0, 427 | type=float, 428 | help="Total number of training epochs to perform.") 429 | parser.add_argument("--warmup_proportion", 430 | default=0.1, 431 | type=float, 432 | help="Proportion of training to perform linear learning rate warmup for. " 433 | "E.g., 0.1 = 10%% of training.") 434 | parser.add_argument("--no_cuda", 435 | action='store_true', 436 | help="Whether not to use CUDA when available") 437 | parser.add_argument("--local_rank", 438 | type=int, 439 | default=-1, 440 | help="local_rank for distributed training on gpus") 441 | parser.add_argument('--seed', 442 | type=int, 443 | default=42, 444 | help="random seed for initialization") 445 | parser.add_argument('--gradient_accumulation_steps', 446 | type=int, 447 | default=1, 448 | help="Number of updates steps to accumulate before performing a backward/update pass.") 449 | parser.add_argument('--fp16', 450 | action='store_true', 451 | help="Whether to use 16-bit float precision instead of 32-bit") 452 | parser.add_argument('--loss_scale', 453 | type=float, default=0, 454 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 455 | "0 (default value): dynamic loss scaling.\n" 456 | "Positive power of 2: static loss scaling value.\n") 457 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 458 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 459 | parser.add_argument('--model_loc', type=str, default='', help="Specify the location of the bio or clinical bert model") 460 | return parser 461 | 462 | def main(): 463 | parser = setup_parser() 464 | args = parser.parse_args() 465 | 466 | # specifies the path where the biobert or clinical bert model is saved 467 | if args.bert_model == 'biobert' or args.bert_model == 'clinical_bert': 468 | args.bert_model = args.model_loc 469 | 470 | print(args.bert_model) 471 | 472 | if args.server_ip and args.server_port: 473 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 474 | import ptvsd 475 | print("Waiting for debugger attach") 476 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 477 | ptvsd.wait_for_attach() 478 | 479 | processors = { 480 | "cola": ColaProcessor, 481 | "mnli": MnliProcessor, 482 | "mrpc": MrpcProcessor, 483 | "mednli": MedNLIProcessor 484 | } 485 | 486 | num_labels_task = { 487 | "cola": 2, 488 | "mnli": 3, 489 | "mrpc": 2, 490 | "mednli": 3 491 | } 492 | 493 | if args.local_rank == -1 or args.no_cuda: 494 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 495 | n_gpu = torch.cuda.device_count() 496 | else: 497 | torch.cuda.set_device(args.local_rank) 498 | device = torch.device("cuda", args.local_rank) 499 | n_gpu = 1 500 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 501 | torch.distributed.init_process_group(backend='nccl') 502 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 503 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 504 | 505 | if args.gradient_accumulation_steps < 1: 506 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 507 | args.gradient_accumulation_steps)) 508 | 509 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 510 | 511 | random.seed(args.seed) 512 | np.random.seed(args.seed) 513 | torch.manual_seed(args.seed) 514 | if n_gpu > 0: 515 | torch.cuda.manual_seed_all(args.seed) 516 | 517 | if not args.do_train and not args.do_eval: 518 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 519 | 520 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 521 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 522 | if not os.path.exists(args.output_dir): 523 | os.makedirs(args.output_dir) 524 | 525 | task_name = args.task_name.lower() 526 | 527 | if task_name not in processors: 528 | raise ValueError("Task not found: %s" % (task_name)) 529 | 530 | processor = processors[task_name]() 531 | num_labels = num_labels_task[task_name] 532 | label_list = processor.get_labels() 533 | 534 | 535 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 536 | 537 | 538 | print('TRAIN') 539 | train = processor.get_train_examples(args.data_dir) 540 | print([(train[i].text_a,train[i].text_b, train[i].label) for i in range(3)]) 541 | print('DEV') 542 | dev = 
processor.get_dev_examples(args.data_dir) 543 | print([(dev[i].text_a,dev[i].text_b, dev[i].label) for i in range(3)]) 544 | print('TEST') 545 | test = processor.get_test_examples(args.data_dir) 546 | print([(test[i].text_a,test[i].text_b, test[i].label) for i in range(3)]) 547 | 548 | 549 | 550 | train_examples = None 551 | num_train_optimization_steps = None 552 | if args.do_train: 553 | train_examples = processor.get_train_examples(args.data_dir) 554 | num_train_optimization_steps = int( 555 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 556 | if args.local_rank != -1: 557 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 558 | 559 | # Prepare model 560 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)) 561 | model = BertForSequenceClassification.from_pretrained(args.bert_model, 562 | cache_dir=cache_dir, 563 | num_labels = num_labels) 564 | if args.fp16: 565 | model.half() 566 | model.to(device) 567 | if args.local_rank != -1: 568 | try: 569 | from apex.parallel import DistributedDataParallel as DDP 570 | except ImportError: 571 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 572 | 573 | model = DDP(model) 574 | elif n_gpu > 1: 575 | model = torch.nn.DataParallel(model) 576 | 577 | # Prepare optimizer 578 | param_optimizer = list(model.named_parameters()) 579 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 580 | optimizer_grouped_parameters = [ 581 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 582 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 583 | ] 584 | if args.fp16: 585 | try: 586 | from apex.optimizers import FP16_Optimizer 587 | from apex.optimizers import FusedAdam 588 | except ImportError: 589 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 590 | 591 | optimizer = FusedAdam(optimizer_grouped_parameters, 592 | lr=args.learning_rate, 593 | bias_correction=False, 594 | max_grad_norm=1.0) 595 | if args.loss_scale == 0: 596 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 597 | else: 598 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 599 | 600 | else: 601 | optimizer = BertAdam(optimizer_grouped_parameters, 602 | lr=args.learning_rate, 603 | warmup=args.warmup_proportion, 604 | t_total=num_train_optimization_steps) 605 | 606 | global_step = 0 607 | nb_tr_steps = 0 608 | tr_loss = 0 609 | if args.do_train: 610 | train_features = convert_examples_to_features( 611 | train_examples, label_list, args.max_seq_length, tokenizer) 612 | logger.info("***** Running training *****") 613 | logger.info(" Num examples = %d", len(train_examples)) 614 | logger.info(" Batch size = %d", args.train_batch_size) 615 | logger.info(" Num steps = %d", num_train_optimization_steps) 616 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 617 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 618 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 619 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 620 | train_data = TensorDataset(all_input_ids, 
all_input_mask, all_segment_ids, all_label_ids) 621 | if args.local_rank == -1: 622 | train_sampler = RandomSampler(train_data) 623 | else: 624 | train_sampler = DistributedSampler(train_data) 625 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 626 | 627 | model.train() 628 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 629 | tr_loss = 0 630 | nb_tr_examples, nb_tr_steps = 0, 0 631 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 632 | batch = tuple(t.to(device) for t in batch) 633 | input_ids, input_mask, segment_ids, label_ids = batch 634 | loss = model(input_ids, segment_ids, input_mask, label_ids) 635 | if n_gpu > 1: 636 | loss = loss.mean() # mean() to average on multi-gpu. 637 | if args.gradient_accumulation_steps > 1: 638 | loss = loss / args.gradient_accumulation_steps 639 | 640 | if args.fp16: 641 | optimizer.backward(loss) 642 | else: 643 | loss.backward() 644 | 645 | tr_loss += loss.item() 646 | nb_tr_examples += input_ids.size(0) 647 | nb_tr_steps += 1 648 | if (step + 1) % args.gradient_accumulation_steps == 0: 649 | if args.fp16: 650 | # modify learning rate with special warm up BERT uses 651 | # if args.fp16 is False, BertAdam is used that handles this automatically 652 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion) 653 | for param_group in optimizer.param_groups: 654 | param_group['lr'] = lr_this_step 655 | optimizer.step() 656 | optimizer.zero_grad() 657 | global_step += 1 658 | 659 | if args.do_train: 660 | # Save a trained model and the associated configuration 661 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 662 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 663 | torch.save(model_to_save.state_dict(), output_model_file) 664 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 665 | with open(output_config_file, 'w') as f: 666 | f.write(model_to_save.config.to_json_string()) 667 | 668 | # Load a trained model and config that you have fine-tuned 669 | config = BertConfig(output_config_file) 670 | model = BertForSequenceClassification(config, num_labels=num_labels) 671 | model.load_state_dict(torch.load(output_model_file)) 672 | else: 673 | model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) 674 | model.to(device) 675 | 676 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 677 | eval_examples = processor.get_dev_examples(args.data_dir) 678 | eval_features = convert_examples_to_features( 679 | eval_examples, label_list, args.max_seq_length, tokenizer) 680 | logger.info("***** Running evaluation *****") 681 | logger.info(" Num examples = %d", len(eval_examples)) 682 | logger.info(" Batch size = %d", args.eval_batch_size) 683 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 684 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 685 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 686 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 687 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 688 | # Run prediction for full data 689 | eval_sampler = SequentialSampler(eval_data) 690 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, 
batch_size=args.eval_batch_size) 691 | 692 | model.eval() 693 | eval_loss, eval_accuracy = 0, 0 694 | nb_eval_steps, nb_eval_examples = 0, 0 695 | 696 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 697 | input_ids = input_ids.to(device) 698 | input_mask = input_mask.to(device) 699 | segment_ids = segment_ids.to(device) 700 | label_ids = label_ids.to(device) 701 | 702 | with torch.no_grad(): 703 | tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) 704 | logits = model(input_ids, segment_ids, input_mask) 705 | 706 | logits = logits.detach().cpu().numpy() 707 | label_ids = label_ids.to('cpu').numpy() 708 | tmp_eval_accuracy = accuracy(logits, label_ids) 709 | 710 | eval_loss += tmp_eval_loss.mean().item() 711 | eval_accuracy += tmp_eval_accuracy 712 | 713 | nb_eval_examples += input_ids.size(0) 714 | nb_eval_steps += 1 715 | 716 | eval_loss = eval_loss / nb_eval_steps 717 | eval_accuracy = eval_accuracy / nb_eval_examples 718 | loss = tr_loss/nb_tr_steps if args.do_train else None 719 | result = {'eval_loss': eval_loss, 720 | 'eval_accuracy': eval_accuracy, 721 | 'global_step': global_step, 722 | 'loss': loss} 723 | 724 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 725 | with open(output_eval_file, "w") as writer: 726 | logger.info("***** Eval results *****") 727 | for key in sorted(result.keys()): 728 | logger.info(" %s = %s", key, str(result[key])) 729 | writer.write("%s = %s\n" % (key, str(result[key]))) 730 | 731 | if args.do_test and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 732 | test_examples = processor.get_test_examples(args.data_dir) 733 | test_features = convert_examples_to_features( 734 | test_examples, label_list, args.max_seq_length, tokenizer) 735 | logger.info("***** Running testing *****") 736 | logger.info(" Num examples = %d", len(test_examples)) 737 | logger.info(" Batch size = %d", args.eval_batch_size) 738 | all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long) 739 | all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long) 740 | all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long) 741 | all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long) 742 | test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 743 | # Run prediction for full data 744 | test_sampler = SequentialSampler(test_data) 745 | test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size) 746 | 747 | model.eval() 748 | test_loss, test_accuracy = 0, 0 749 | nb_test_steps, nb_test_examples = 0, 0 750 | 751 | for input_ids, input_mask, segment_ids, label_ids in tqdm(test_dataloader, desc="Testing"): 752 | input_ids = input_ids.to(device) 753 | input_mask = input_mask.to(device) 754 | segment_ids = segment_ids.to(device) 755 | label_ids = label_ids.to(device) 756 | 757 | with torch.no_grad(): 758 | tmp_test_loss = model(input_ids, segment_ids, input_mask, label_ids) 759 | logits = model(input_ids, segment_ids, input_mask) 760 | 761 | logits = logits.detach().cpu().numpy() 762 | label_ids = label_ids.to('cpu').numpy() 763 | tmp_test_accuracy = accuracy(logits, label_ids) 764 | 765 | test_loss += tmp_test_loss.mean().item() 766 | test_accuracy += tmp_test_accuracy 767 | 768 | nb_test_examples += input_ids.size(0) 769 | nb_test_steps += 1 770 | 771 | test_loss = test_loss / nb_test_steps 772 | test_accuracy = 
test_accuracy / nb_test_examples 773 | loss = tr_loss/nb_tr_steps if args.do_train else None 774 | result = {'test_loss': test_loss, 775 | 'test_accuracy': test_accuracy, 776 | 'global_step': global_step, 777 | 'loss': loss} 778 | 779 | output_test_file = os.path.join(args.output_dir, "test_results.txt") 780 | with open(output_test_file, "w") as writer: 781 | logger.info("***** Test results *****") 782 | for key in sorted(result.keys()): 783 | logger.info(" %s = %s", key, str(result[key])) 784 | writer.write("%s = %s\n" % (key, str(result[key]))) 785 | 786 | if __name__ == "__main__": 787 | main() 788 | -------------------------------------------------------------------------------- /downstream_tasks/run_ner.py: -------------------------------------------------------------------------------- 1 | #! usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | This code is adapted from the kyzhouhzau/BERT-NER repo with several modifications. 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import sys 11 | 12 | import collections 13 | import os 14 | import math 15 | import modeling 16 | import optimization 17 | import tokenization 18 | import tensorflow as tf 19 | from tensorflow.python.ops import math_ops 20 | import tf_metrics 21 | import pickle 22 | import numpy as np 23 | import itertools 24 | import json 25 | from random import shuffle 26 | import random 27 | 28 | flags = tf.flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_string( 33 | "task_name", None, "The name of the task to train." 34 | ) 35 | 36 | flags.DEFINE_string( 37 | "data_dir", None, 38 | "The input datadir.", 39 | ) 40 | 41 | flags.DEFINE_string( 42 | "output_dir", None, 43 | "The output directory where the model checkpoints will be written." 44 | ) 45 | 46 | flags.DEFINE_string( 47 | "bert_config_file", None, 48 | "The config json file corresponding to the pre-trained BERT model." 49 | ) 50 | 51 | flags.DEFINE_string("vocab_file", None, 52 | "The vocabulary file that the BERT model was trained on.") 53 | 54 | flags.DEFINE_string( 55 | "init_checkpoint", None, 56 | "Initial checkpoint (usually from a pre-trained BERT model)." 57 | ) 58 | 59 | flags.DEFINE_bool( 60 | "do_lower_case", False, 61 | "Whether to lower case the input text." 62 | ) 63 | 64 | flags.DEFINE_integer( 65 | "max_seq_length", 128, 66 | "The maximum total input sequence length after WordPiece tokenization." 67 | ) 68 | 69 | flags.DEFINE_bool( 70 | "do_train", True, 71 | "Whether to run training." 72 | ) 73 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 74 | 75 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 76 | 77 | flags.DEFINE_bool("do_predict", True,"Whether to run the model in inference mode on the test set.") 78 | 79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 80 | 81 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 82 | 83 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 84 | 85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 86 | 87 | flags.DEFINE_float("num_train_epochs", 10.0, "Total number of training epochs to perform.") 88 | 89 | flags.DEFINE_float( 90 | "warmup_proportion", 0.1, 91 | "Proportion of training to perform linear learning rate warmup for. 
" 92 | "E.g., 0.1 = 10% of training.") 93 | 94 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 95 | "How often to save the model checkpoint.") 96 | 97 | flags.DEFINE_integer("keep_checkpoint_max", 1, 98 | "How many model checkpoints to keep.") 99 | 100 | flags.DEFINE_integer("iterations_per_loop", 1000, 101 | "How many steps to make in each estimator call.") 102 | 103 | flags.DEFINE_integer("cross_val_sz", 10, "Number of cross validation folds") 104 | 105 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 106 | 107 | flags.DEFINE_integer( 108 | "num_tpu_cores", 8, 109 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 110 | 111 | class InputExample(object): 112 | """A single training/test example for simple token classification.""" 113 | 114 | def __init__(self, guid, tokens, text=None, labels=None): 115 | """Constructs a InputExample. 116 | 117 | Args: 118 | guid: Unique id for the example. 119 | text: (Optional) string. The untokenized text of the sequence. 120 | tokens: list of strings. The tokenized sentence. Each token should have a 121 | corresponding label for train and dev samples. 122 | label: (Optional) list of strings. The labels of the example. This should be 123 | specified for train and dev examples, but not for test examples. 124 | """ 125 | self.guid = guid 126 | self.text = text 127 | self.tokens = tokens 128 | self.labels = labels 129 | 130 | def __str__(self): 131 | return self.__repr__() 132 | 133 | def __repr__(self): 134 | l = [ 135 | "id: {}".format(self.guid), 136 | "tokens: {}".format(" ".join(self.tokens)), 137 | ] 138 | if self.text is not None: 139 | l.append("text: {}".format(self.text)) 140 | 141 | if self.labels is not None: 142 | l.append("labels: {}".format(" ".join(self.labels))) 143 | 144 | return ", ".join(l) 145 | 146 | 147 | class InputFeatures(object): 148 | """A single set of features of data.""" 149 | 150 | def __init__(self, input_ids, input_mask, segment_ids, label_ids,): 151 | self.input_ids = input_ids 152 | self.input_mask = input_mask 153 | self.segment_ids = segment_ids 154 | self.label_ids = label_ids 155 | #self.label_mask = label_mask 156 | 157 | class DataProcessor(object): 158 | """Base class for data converters for sequence classification data sets.""" 159 | 160 | def get_train_examples(self, data_dir): 161 | """Gets a collection of `InputExample`s for the train set.""" 162 | raise NotImplementedError() 163 | 164 | def get_dev_examples(self, data_dir): 165 | """Gets a collection of `InputExample`s for the dev set.""" 166 | raise NotImplementedError() 167 | 168 | def get_labels(self): 169 | """Gets the list of labels for this data set.""" 170 | raise NotImplementedError() 171 | 172 | @classmethod 173 | def _read_data(cls, input_file): 174 | """Reads in data where each line has the word and its corresponding 175 | label separated by whitespace. Each sentence is separated by a blank 176 | line. E.g.: 177 | 178 | Identification O 179 | of O 180 | APC2 O 181 | , O 182 | a O 183 | homologue O 184 | of O 185 | the O 186 | adenomatous B-Disease 187 | polyposis I-Disease 188 | coli I-Disease 189 | tumour I-Disease 190 | suppressor O 191 | . O 192 | 193 | The O 194 | adenomatous B-Disease 195 | polyposis I-Disease 196 | ... 197 | """ 198 | with open(input_file) as f: 199 | lines = [] 200 | words = [] 201 | labels = [] 202 | for line in f: 203 | line = line.strip() 204 | if len(line) == 0: #i.e. 
we're in between sentences 205 | assert len(words) == len(labels) 206 | if len(words) == 0: 207 | continue 208 | lines.append([words, labels]) 209 | words = [] 210 | labels = [] 211 | continue 212 | 213 | word = line.split()[0] 214 | label = line.split()[-1] 215 | words.append(word) 216 | labels.append(label) 217 | 218 | #TODO: see if there's an off by one error here 219 | return lines 220 | 221 | @classmethod 222 | def _create_example(self, lines, set_type): 223 | examples = [] 224 | for (i, line) in enumerate(lines): 225 | guid = "%s-%s" % (set_type, i) 226 | words,labels = line 227 | words = [tokenization.convert_to_unicode(w) for w in words] 228 | labels = [tokenization.convert_to_unicode(l) for l in labels] 229 | examples.append(InputExample(guid=guid, tokens=words, labels=labels)) 230 | return examples 231 | 232 | @classmethod 233 | def _chunks(self, l, n): 234 | """Yield successive n-sized chunks from l.""" 235 | for i in range(0, len(l), n): 236 | yield l[i:i + n] 237 | 238 | def write_cv_to_file(self, evaluation, test, n): 239 | with open(os.path.join(FLAGS.data_dir, str(n) + '_eval'),'w') as w: 240 | for example in evaluation: 241 | for t, l in zip(example.tokens, example.labels): 242 | w.write("%s %s\n" %(t, l)) 243 | w.write("\n") 244 | 245 | 246 | with open(os.path.join(FLAGS.data_dir, str(n) + '_test'),'w') as test_w: 247 | for test_example in test: 248 | for t, l in zip(test_example.tokens, test_example.labels): 249 | 250 | test_w.write("%s %s\n" %(t, l)) 251 | test_w.write("\n") 252 | 253 | 254 | 255 | def get_cv_examples(self, splits, n, cv_sz=10): 256 | # note, when n=9 (10th split), this recovers the original train, dev, test split 257 | 258 | dev = splits[(n-1)%cv_sz] #4 #0 #3 #1 259 | test = splits[n] #0 #1 #4 #2 260 | # print('train ind: %d-%d' %((n+1)%cv_sz, (n-1)%cv_sz)) 261 | # print('dev ind: %d' %((n-1)%cv_sz)) 262 | # print('test ind`: %d' %n) 263 | if (n+1)%cv_sz > (n-1)%cv_sz: 264 | train = splits[:(n-1)%cv_sz] + splits[(n+1)%cv_sz:] 265 | else: 266 | train = splits[(n+1)%cv_sz:(n-1)%cv_sz] #1-3 #2-4 #0-2 #3-0s 267 | train = list(itertools.chain.from_iterable(train)) 268 | print("Train size: %d, dev size: %d, test size: %d, total: %d" %(len(train), len(dev), len(test), (len(train)+len(dev)+len(test)))) 269 | self.write_cv_to_file(dev, test, n) 270 | return(train, dev, test) 271 | 272 | def create_cv_examples(self, data_dir, cv_sz=10): 273 | train_examples = self.get_train_examples(data_dir) 274 | dev_examples = self.get_dev_examples(data_dir) 275 | test_examples = self.get_test_examples(data_dir) 276 | print('num train examples: %d, num eval examples: %d, num test examples: %d' %(len(train_examples), len(dev_examples), len(test_examples))) 277 | print('Total dataset size: %d' %(len(train_examples) + len(dev_examples) + len(test_examples))) 278 | random.seed(42) 279 | train_dev = train_examples + dev_examples 280 | random.shuffle(train_dev) 281 | split_sz = math.ceil(len(train_dev)/(cv_sz-1)) 282 | print('Split size: %d' %split_sz) 283 | splits = list(self._chunks(train_dev, split_sz)) 284 | print('Num splits: %d' %(len(splits) + 1)) 285 | splits = splits + [test_examples] 286 | print('len splits: ', [len(s) for s in splits]) 287 | return splits 288 | 289 | 290 | 291 | class NerProcessor(DataProcessor): 292 | def get_train_examples(self, data_dir): 293 | return self._create_example( 294 | self._read_data(os.path.join(data_dir, "train_dev.tsv")), "train" 295 | ) 296 | 297 | def get_dev_examples(self, data_dir): 298 | return self._create_example( 299 | 
self._read_data(os.path.join(data_dir, "devel.tsv")), "dev" 300 | ) 301 | 302 | def get_test_examples(self,data_dir): 303 | test_examples = self._create_example( 304 | self._read_data(os.path.join(data_dir, "test.tsv")), "test") 305 | #print(test_examples) 306 | return test_examples 307 | 308 | 309 | def get_labels(self): 310 | return ["B", "I", "O", "X", "[CLS]", "[SEP]"] 311 | 312 | 313 | class NCBIDiseaseProcessor(DataProcessor): 314 | #https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/ 315 | def get_train_examples(self, data_dir): 316 | return self._create_example( 317 | self._read_data(os.path.join(data_dir, "train.tsv")), "train" 318 | ) 319 | 320 | def get_dev_examples(self, data_dir): 321 | return self._create_example( 322 | self._read_data(os.path.join(data_dir, "devel.tsv")), "dev" 323 | ) 324 | 325 | def get_test_examples(self,data_dir): 326 | return self._create_example( 327 | self._read_data(os.path.join(data_dir, "test.tsv")), "test") 328 | 329 | 330 | def get_labels(self): 331 | return ["B", "I", "O", "X", "[CLS]", "[SEP]"] 332 | 333 | 334 | class i2b22010Processor(DataProcessor): 335 | 336 | def get_train_examples(self, data_dir): 337 | return self._create_example( 338 | self._read_data(os.path.join(data_dir, "train.tsv")), "train" 339 | ) 340 | 341 | def get_dev_examples(self, data_dir): 342 | return self._create_example( 343 | self._read_data(os.path.join(data_dir, "dev.tsv")), "dev" 344 | ) 345 | 346 | def get_test_examples(self,data_dir): 347 | print('Path: ', os.path.join(data_dir, "test.tsv")) 348 | test_examples = self._create_example( 349 | self._read_data(os.path.join(data_dir, "test.tsv")), "test") 350 | print(test_examples[-5:]) 351 | return test_examples 352 | 353 | def get_labels(self): 354 | return ["B-problem", "I-problem", "B-treatment", "I-treatment", 'B-test', 'I-test', 'O', "X", "[CLS]", "[SEP]"] 355 | 356 | class i2b22006Processor(DataProcessor): 357 | 358 | 359 | def get_train_examples(self, data_dir): 360 | return self._create_example( 361 | self._read_data(os.path.join(data_dir, "train.conll")), "train" 362 | ) 363 | 364 | def get_dev_examples(self, data_dir): 365 | return self._create_example( 366 | self._read_data(os.path.join(data_dir, "dev.conll")), "dev" 367 | ) 368 | 369 | def get_test_examples(self,data_dir): 370 | return self._create_example( 371 | self._read_data(os.path.join(data_dir, "test.conll")), "test" 372 | ) 373 | 374 | def get_labels(self): 375 | return ["B-ID", "I-ID", "B-HOSPITAL", "I-HOSPITAL", 'B-PATIENT', 'I-PATIENT', 'B-PHONE', 'I-PHONE', 376 | 'B-DATE', 'I-DATE', 'B-DOCTOR', 'I-DOCTOR', 'B-LOCATION', 'I-LOCATION', 'B-AGE', 'I-AGE', 377 | 'O', "X", "[CLS]", "[SEP]"] 378 | 379 | class i2b22012Processor(DataProcessor): 380 | 381 | def get_train_examples(self, data_dir): 382 | return self._create_example( 383 | self._read_data(os.path.join(data_dir, "train.tsv")), "train" 384 | ) 385 | 386 | def get_dev_examples(self, data_dir): 387 | return self._create_example( 388 | self._read_data(os.path.join(data_dir, "dev.tsv")), "dev" 389 | ) 390 | 391 | def get_test_examples(self,data_dir): 392 | return self._create_example( 393 | self._read_data(os.path.join(data_dir, "test.tsv")), "test" 394 | ) 395 | 396 | def get_labels(self): 397 | return ['B-OCCURRENCE','I-OCCURRENCE','B-EVIDENTIAL','I-EVIDENTIAL','B-TREATMENT','I-TREATMENT','B-CLINICAL_DEPT', 398 | 'I-CLINICAL_DEPT','B-PROBLEM','I-PROBLEM','B-TEST','I-TEST','O', "X", "[CLS]", "[SEP]"] 399 | 400 | class i2b22014Processor(DataProcessor): 401 | 402 | def get_train_examples(self, 
data_dir): 403 | return self._create_example( 404 | self._read_data(os.path.join(data_dir, "train.tsv")), "train" 405 | ) 406 | 407 | def get_dev_examples(self, data_dir): 408 | return self._create_example( 409 | self._read_data(os.path.join(data_dir, "dev.tsv")), "dev" 410 | ) 411 | 412 | def get_test_examples(self,data_dir): 413 | return self._create_example( 414 | self._read_data(os.path.join(data_dir, "test.tsv")), "test" 415 | ) 416 | 417 | def get_labels(self): 418 | return ["B-IDNUM", "I-IDNUM", "B-HOSPITAL", "I-HOSPITAL", 'B-PATIENT', 'I-PATIENT', 'B-PHONE', 'I-PHONE', 419 | 'B-DATE', 'I-DATE', 'B-DOCTOR', 'I-DOCTOR', 'B-LOCATION-OTHER', 'I-LOCATION-OTHER', 'B-AGE', 'I-AGE', 'B-BIOID', 'I-BIOID', 420 | 'B-STATE', 'I-STATE','B-ZIP', 'I-ZIP', 'B-HEALTHPLAN', 'I-HEALTHPLAN', 'B-ORGANIZATION', 'I-ORGANIZATION', 421 | 'B-MEDICALRECORD', 'I-MEDICALRECORD', 'B-CITY', 'I-CITY', 'B-STREET', 'I-STREET', 'B-COUNTRY', 'I-COUNTRY', 422 | 'B-URL', 'I-URL', 423 | 'B-USERNAME', 'I-USERNAME', 'B-PROFESSION', 'I-PROFESSION', 'B-FAX', 'I-FAX', 'B-EMAIL', 'I-EMAIL', 'B-DEVICE', 'I-DEVICE', 424 | 'O', "X", "[CLS]", "[SEP]"] 425 | 426 | def write_tokens(tokens,tok_to_orig_map, mode, cv_iter): 427 | #print('MODE: %s' %mode) 428 | if mode == "test" or mode == "eval": 429 | path = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_"+mode+".txt") 430 | wf = open(path,'a') 431 | for token in tokens: 432 | if token!="**NULL**": 433 | wf.write(token+'\n') 434 | wf.write('\n') 435 | wf.close() 436 | with open(os.path.join(FLAGS.output_dir, str(cv_iter) + "_tok_to_orig_map_"+mode+".txt"),'a') as w: 437 | w.write("-1\n") #correspond to [CLS] 438 | for ind in tok_to_orig_map: 439 | w.write(str(ind)+'\n') 440 | w.write("-1\n") #correspond to [SEP] 441 | w.write('\n') 442 | 443 | 444 | 445 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer,cv_iter, mode): 446 | label_map = {} 447 | for (i, label) in enumerate(label_list,1): 448 | label_map[label] = i 449 | with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'),'wb') as w: 450 | pickle.dump(label_map,w) 451 | orig_tokens = example.tokens 452 | orig_labels = example.labels 453 | tokens = [] 454 | labels = [] 455 | tok_to_orig_map = [] 456 | 457 | 458 | for i, word in enumerate(orig_tokens): 459 | token = tokenizer.tokenize(word) 460 | tokens.extend(token) 461 | orig_label = orig_labels[i] 462 | for m in range(len(token)): 463 | tok_to_orig_map.append(i) 464 | if m == 0: 465 | labels.append(orig_label) 466 | else: 467 | labels.append("X") 468 | 469 | 470 | if len(tokens) >= max_seq_length - 1: 471 | tokens = tokens[0:(max_seq_length - 2)] 472 | labels = labels[0:(max_seq_length - 2)] 473 | ntokens = [] 474 | segment_ids = [] 475 | label_ids = [] 476 | ntokens.append("[CLS]") 477 | segment_ids.append(0) 478 | label_ids.append(label_map["[CLS]"]) 479 | for i, token in enumerate(tokens): 480 | ntokens.append(token) 481 | segment_ids.append(0) 482 | label_ids.append(label_map[labels[i]]) 483 | ntokens.append("[SEP]") 484 | segment_ids.append(0) 485 | label_ids.append(label_map["[SEP]"]) 486 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 487 | input_mask = [1] * len(input_ids) 488 | while len(input_ids) < max_seq_length: 489 | input_ids.append(0) 490 | input_mask.append(0) 491 | segment_ids.append(0) 492 | label_ids.append(0) 493 | ntokens.append("**NULL**") 494 | 495 | assert len(input_ids) == max_seq_length 496 | assert len(input_mask) == max_seq_length 497 | assert len(segment_ids) == max_seq_length 498 | assert 
len(label_ids) == max_seq_length 499 | 500 | if ex_index < 3: 501 | tf.logging.info("*** Example ***") 502 | tf.logging.info("guid: %s" % (example.guid)) 503 | tf.logging.info("tokens: %s" % " ".join( 504 | [tokenization.printable_text(x) for x in tokens])) 505 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 506 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 507 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 508 | tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids])) 509 | 510 | feature = InputFeatures( 511 | input_ids=input_ids, 512 | input_mask=input_mask, 513 | segment_ids=segment_ids, 514 | label_ids=label_ids, 515 | ) 516 | write_tokens(ntokens,tok_to_orig_map, mode, cv_iter) 517 | return feature 518 | 519 | 520 | def filed_based_convert_examples_to_features( 521 | examples, label_list, max_seq_length, tokenizer, output_file,cv_iter, mode=None 522 | ): 523 | writer = tf.python_io.TFRecordWriter(output_file) 524 | for (ex_index, example) in enumerate(examples): 525 | if ex_index % 5000 == 0: 526 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 527 | feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, cv_iter, mode ) 528 | 529 | def create_int_feature(values): 530 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 531 | return f 532 | 533 | features = collections.OrderedDict() 534 | features["input_ids"] = create_int_feature(feature.input_ids) 535 | features["input_mask"] = create_int_feature(feature.input_mask) 536 | features["segment_ids"] = create_int_feature(feature.segment_ids) 537 | features["label_ids"] = create_int_feature(feature.label_ids) 538 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 539 | writer.write(tf_example.SerializeToString()) 540 | 541 | def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): 542 | name_to_features = { 543 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 544 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 545 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 546 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 547 | 548 | } 549 | 550 | def _decode_record(record, name_to_features): 551 | example = tf.parse_single_example(record, name_to_features) 552 | for name in list(example.keys()): 553 | t = example[name] 554 | if t.dtype == tf.int64: 555 | t = tf.to_int32(t) 556 | example[name] = t 557 | return example 558 | 559 | def input_fn(params): 560 | batch_size = params["batch_size"] 561 | d = tf.data.TFRecordDataset(input_file) 562 | if is_training: 563 | d = d.repeat() 564 | d = d.shuffle(buffer_size=100) 565 | d = d.apply(tf.contrib.data.map_and_batch( 566 | lambda record: _decode_record(record, name_to_features), 567 | batch_size=batch_size, 568 | drop_remainder=drop_remainder 569 | )) 570 | return d 571 | return input_fn 572 | 573 | 574 | def create_model(bert_config, is_training, input_ids, input_mask, 575 | segment_ids, labels, num_labels, use_one_hot_embeddings): 576 | model = modeling.BertModel( 577 | config=bert_config, 578 | is_training=is_training, 579 | input_ids=input_ids, 580 | input_mask=input_mask, 581 | token_type_ids=segment_ids, 582 | use_one_hot_embeddings=use_one_hot_embeddings 583 | ) 584 | 585 | output_layer = model.get_sequence_output() 586 | 587 | hidden_size = output_layer.shape[-1].value 588 | 589 | output_weight = 
tf.get_variable( 590 | "output_weights", [num_labels, hidden_size], 591 | initializer=tf.truncated_normal_initializer(stddev=0.02) 592 | ) 593 | output_bias = tf.get_variable( 594 | "output_bias", [num_labels], initializer=tf.zeros_initializer() 595 | ) 596 | with tf.variable_scope("loss"): 597 | if is_training: 598 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 599 | output_layer = tf.reshape(output_layer, [-1, hidden_size]) 600 | logits = tf.matmul(output_layer, output_weight, transpose_b=True) 601 | logits = tf.nn.bias_add(logits, output_bias) 602 | logits = tf.reshape(logits, [-1, FLAGS.max_seq_length, num_labels]) 603 | 604 | ########################################################################## 605 | log_probs = tf.nn.log_softmax(logits, axis=-1) 606 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 607 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 608 | loss = tf.reduce_sum(per_example_loss) 609 | probabilities = tf.nn.softmax(logits, axis=-1) 610 | predict = tf.argmax(probabilities,axis=-1) 611 | return (loss, per_example_loss, logits, log_probs, predict) 612 | ########################################################################## 613 | 614 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 615 | num_train_steps, num_warmup_steps, use_tpu, 616 | use_one_hot_embeddings): 617 | def model_fn(features, labels, mode, params): 618 | tf.logging.info("*** Features ***") 619 | for name in sorted(features.keys()): 620 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 621 | input_ids = features["input_ids"] 622 | input_mask = features["input_mask"] 623 | segment_ids = features["segment_ids"] 624 | label_ids = features["label_ids"] 625 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 626 | 627 | (total_loss, per_example_loss, logits, log_probs, predicts) = create_model( 628 | bert_config, is_training, input_ids, input_mask,segment_ids, label_ids, 629 | num_labels, use_one_hot_embeddings) 630 | tvars = tf.trainable_variables() 631 | scaffold_fn = None 632 | if init_checkpoint: 633 | print('INIT_CHECKPOINT') 634 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint) 635 | if use_tpu: 636 | def tpu_scaffold(): 637 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 638 | return tf.train.Scaffold() 639 | scaffold_fn = tpu_scaffold 640 | else: 641 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 642 | tf.logging.info("**** Trainable Variables ****") 643 | 644 | for var in tvars: 645 | init_string = "" 646 | if var.name in initialized_variable_names: 647 | init_string = ", *INIT_FROM_CKPT*" 648 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 649 | init_string) 650 | output_spec = None 651 | if mode == tf.estimator.ModeKeys.TRAIN: 652 | train_op = optimization.create_optimizer( 653 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 654 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 655 | mode=mode, 656 | loss=total_loss, 657 | train_op=train_op, 658 | scaffold_fn=scaffold_fn) 659 | elif mode == tf.estimator.ModeKeys.EVAL: 660 | def metric_fn(per_example_loss, label_ids, logits): 661 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 662 | precision = tf_metrics.precision(label_ids,predictions,num_labels,[1,2],average="macro") 663 | recall = tf_metrics.recall(label_ids,predictions,num_labels,[1,2],average="macro") 664 | f = 
tf_metrics.f1(label_ids,predictions,num_labels,[1,2],average="macro") 665 | # 666 | return { 667 | "eval_precision":precision, 668 | "eval_recall":recall, 669 | "eval_f": f, 670 | } 671 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) 672 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 673 | mode=mode, 674 | loss=total_loss, 675 | eval_metrics=eval_metrics, 676 | scaffold_fn=scaffold_fn) 677 | 678 | elif mode == tf.estimator.ModeKeys.PREDICT: 679 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 680 | mode=mode, 681 | predictions={"prediction": predicts, "log_probs": log_probs}, 682 | scaffold_fn=scaffold_fn 683 | ) 684 | return output_spec 685 | return model_fn 686 | 687 | def read_tok_file(token_path): 688 | tokens = list() 689 | with open(token_path, 'r') as reader: 690 | for line in reader: 691 | tok = line.strip() 692 | if tok == '[CLS]': 693 | tmp_toks = [tok] 694 | elif tok == '[SEP]': 695 | tmp_toks.append(tok) 696 | tokens.append(tmp_toks) 697 | elif tok == '': 698 | continue 699 | else: 700 | tmp_toks.append(tok) 701 | return tokens 702 | 703 | def main(_): 704 | tf.logging.set_verbosity(tf.logging.INFO) 705 | 706 | processors = { 707 | "ncbi": NCBIDiseaseProcessor, 708 | "i2b2_2010": i2b22010Processor, 709 | "i2b2_2006": i2b22006Processor, 710 | "i2b2_2014": i2b22014Processor, 711 | "i2b2_2012": i2b22012Processor 712 | } 713 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 714 | raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` must be True.") 715 | 716 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 717 | 718 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 719 | raise ValueError( 720 | "Cannot use sequence length %d because the BERT model " 721 | "was only trained up to sequence length %d" % 722 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 723 | 724 | task_name = FLAGS.task_name.lower() 725 | if task_name not in processors: 726 | raise ValueError("Task not found: %s" % (task_name)) 727 | processor = processors[task_name]() 728 | 729 | label_list = processor.get_labels() 730 | 731 | 732 | splits = processor.create_cv_examples(FLAGS.data_dir, cv_sz=FLAGS.cross_val_sz) 733 | 734 | for cv_iter in range(FLAGS.cross_val_sz): 735 | #for cv_iter in [9]: 736 | # if you only want to use the true train, val, test split, then use the last CV split. 737 | # We ran out of time to do cross validation so we only used the original train/val/test split. 
738 | 739 | 740 | tok_eval = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_eval.txt") 741 | tok_test = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_test.txt") 742 | 743 | if os.path.exists(tok_eval): 744 | os.remove(tok_eval) 745 | if os.path.exists(tok_test): 746 | os.remove(tok_test) 747 | tokenizer = tokenization.FullTokenizer( 748 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 749 | tpu_cluster_resolver = None 750 | if FLAGS.use_tpu and FLAGS.tpu_name: 751 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 752 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 753 | 754 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 755 | 756 | run_config = tf.contrib.tpu.RunConfig( 757 | cluster=tpu_cluster_resolver, 758 | master=FLAGS.master, 759 | model_dir=FLAGS.output_dir, 760 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 761 | keep_checkpoint_max=FLAGS.keep_checkpoint_max, 762 | tpu_config=tf.contrib.tpu.TPUConfig( 763 | iterations_per_loop=FLAGS.iterations_per_loop, 764 | num_shards=FLAGS.num_tpu_cores, 765 | per_host_input_for_training=is_per_host)) 766 | 767 | train_examples = None 768 | num_train_steps = None 769 | num_warmup_steps = None 770 | 771 | if FLAGS.do_train: 772 | train_examples, eval_examples, test_examples = processor.get_cv_examples(splits, cv_iter, cv_sz=FLAGS.cross_val_sz) 773 | print('train sz: %d, val size: %d, test size: %d' %(len(train_examples), len(eval_examples), len(test_examples))) 774 | num_train_steps = int( 775 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 776 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 777 | 778 | model_fn = model_fn_builder( 779 | bert_config=bert_config, 780 | num_labels=len(label_list)+1, 781 | init_checkpoint=FLAGS.init_checkpoint, 782 | learning_rate=FLAGS.learning_rate, 783 | num_train_steps=num_train_steps, 784 | num_warmup_steps=num_warmup_steps, 785 | use_tpu=FLAGS.use_tpu, 786 | use_one_hot_embeddings=FLAGS.use_tpu) 787 | 788 | estimator = tf.contrib.tpu.TPUEstimator( 789 | use_tpu=FLAGS.use_tpu, 790 | model_fn=model_fn, 791 | config=run_config, 792 | train_batch_size=FLAGS.train_batch_size, 793 | eval_batch_size=FLAGS.eval_batch_size, 794 | predict_batch_size=FLAGS.predict_batch_size) 795 | 796 | if FLAGS.do_train: 797 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 798 | filed_based_convert_examples_to_features( 799 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file, cv_iter) 800 | tf.logging.info("***** Running training *****") 801 | tf.logging.info(" Num examples = %d", len(train_examples)) 802 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 803 | tf.logging.info(" Num steps = %d", num_train_steps) 804 | train_input_fn = file_based_input_fn_builder( 805 | input_file=train_file, 806 | seq_length=FLAGS.max_seq_length, 807 | is_training=True, 808 | drop_remainder=True) 809 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 810 | 811 | if FLAGS.do_eval: 812 | _, eval_examples, _ = processor.get_cv_examples(splits, cv_iter, cv_sz=FLAGS.cross_val_sz) 813 | 814 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 815 | filed_based_convert_examples_to_features( 816 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file, cv_iter, mode="eval", ) 817 | 818 | tf.logging.info("***** Running evaluation *****") 819 | tf.logging.info(" Num examples = %d", len(eval_examples)) 820 | tf.logging.info(" Batch 
size = %d", FLAGS.eval_batch_size) 821 | eval_steps = None 822 | if FLAGS.use_tpu: 823 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 824 | eval_drop_remainder = True if FLAGS.use_tpu else False 825 | eval_input_fn = file_based_input_fn_builder( 826 | input_file=eval_file, 827 | seq_length=FLAGS.max_seq_length, 828 | is_training=False, 829 | drop_remainder=eval_drop_remainder) 830 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 831 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 832 | with open(output_eval_file, "w") as writer: 833 | tf.logging.info("***** Eval results *****") 834 | for key in sorted(result.keys()): 835 | tf.logging.info(" %s = %s", key, str(result[key])) 836 | writer.write("%s = %s\n" % (key, str(result[key]))) 837 | 838 | #added 839 | eval_token_path = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_eval.txt") 840 | eval_tokens = read_tok_file(eval_token_path) 841 | 842 | eval_result = estimator.predict(input_fn=eval_input_fn) 843 | output_predict_file = os.path.join(FLAGS.output_dir, str(cv_iter) + "_label_eval.txt") 844 | 845 | with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'),'rb') as rf: 846 | label2id = pickle.load(rf) 847 | id2label = {value:key for key,value in label2id.items()} 848 | 849 | with open(output_predict_file,'w') as p_writer: 850 | for pidx, prediction in enumerate(eval_result): 851 | slen = len(eval_tokens[pidx]) 852 | output_line = "\n".join(id2label[id] if id!=0 else id2label[3] for id in prediction['prediction'][:slen]) + "\n" #change to O tag 853 | p_writer.write(output_line) 854 | p_writer.write('\n') 855 | 856 | 857 | 858 | if FLAGS.do_predict: 859 | 860 | _,_, predict_examples = processor.get_cv_examples(splits, cv_iter, cv_sz=FLAGS.cross_val_sz) 861 | 862 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 863 | filed_based_convert_examples_to_features(predict_examples, label_list, 864 | FLAGS.max_seq_length, tokenizer, 865 | predict_file,cv_iter, mode="test") 866 | 867 | tf.logging.info("***** Running prediction*****") 868 | tf.logging.info(" Num examples = %d", len(predict_examples)) 869 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 870 | if FLAGS.use_tpu: 871 | # Warning: According to tpu_estimator.py Prediction on TPU is an 872 | # experimental feature and hence not supported here 873 | raise ValueError("Prediction in TPU not supported") 874 | predict_drop_remainder = True if FLAGS.use_tpu else False 875 | predict_input_fn = file_based_input_fn_builder( 876 | input_file=predict_file, 877 | seq_length=FLAGS.max_seq_length, 878 | is_training=False, 879 | drop_remainder=predict_drop_remainder) 880 | prf = estimator.evaluate(input_fn=predict_input_fn, steps=None) 881 | tf.logging.info("***** token-level Test evaluation results *****") 882 | for key in sorted(prf.keys()): 883 | tf.logging.info(" %s = %s", key, str(prf[key])) 884 | 885 | test_token_path = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_test.txt") 886 | test_tokens = read_tok_file(test_token_path) 887 | 888 | 889 | result = estimator.predict(input_fn=predict_input_fn) 890 | output_predict_file = os.path.join(FLAGS.output_dir, str(cv_iter) + "_label_test.txt") 891 | 892 | with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'),'rb') as rf: 893 | label2id = pickle.load(rf) 894 | id2label = {value:key for key,value in label2id.items()} 895 | 896 | with open(output_predict_file,'w') as p_writer: 897 | for pidx, prediction in enumerate(result): 898 | slen = 
len(test_tokens[pidx]) 899 | output_line = "\n".join(id2label[id] if id!=0 else id2label[3] for id in prediction['prediction'][:slen]) + "\n" #change to O tag 900 | p_writer.write(output_line) 901 | p_writer.write('\n') 902 | 903 | 904 | if __name__ == "__main__": 905 | flags.mark_flag_as_required("data_dir") 906 | flags.mark_flag_as_required("task_name") 907 | flags.mark_flag_as_required("vocab_file") 908 | flags.mark_flag_as_required("bert_config_file") 909 | flags.mark_flag_as_required("output_dir") 910 | tf.app.run() 911 | --------------------------------------------------------------------------------
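How run_ner.py is typically launched (a minimal sketch, assuming placeholder paths and illustrative hyperparameter values rather than the authors' actual settings): task_name selects one of the processors registered in main() (ncbi, i2b2_2010, i2b2_2006, i2b2_2012, i2b2_2014), and data_dir must contain the train/dev/test files that the chosen processor reads.

python run_ner.py \
  --task_name=i2b2_2010 \
  --data_dir=/PATH/TO/NER/DATA \
  --output_dir=/PATH/TO/OUTPUT/DIR \
  --bert_config_file=/PATH/TO/BERT/MODEL/bert_config.json \
  --vocab_file=/PATH/TO/BERT/MODEL/vocab.txt \
  --init_checkpoint=/PATH/TO/BERT/MODEL/bert_model.ckpt \
  --do_lower_case=False \
  --max_seq_length=128 \
  --do_train=True \
  --do_eval=True \
  --do_predict=True \
  --train_batch_size=32 \
  --learning_rate=5e-5 \
  --num_train_epochs=10.0 \
  --cross_val_sz=10

Note that the script loops over cross_val_sz folds, writing per-fold <cv_iter>_token_*.txt and <cv_iter>_label_*.txt files to output_dir; per the comment in main(), the final fold reproduces the original train/dev/test split.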