├── evaluation ├── .Rhistory ├── accuracy_files │ └── figure-gfm │ │ ├── f1-1.png │ │ ├── acc-1.png │ │ ├── mcc-1.png │ │ ├── prc-1.png │ │ ├── roc-1.png │ │ ├── loss-log-1.png │ │ ├── prc long-1.png │ │ ├── plot pooled-1.png │ │ └── radarplots-1.png ├── table1.csv ├── irr.csv ├── auc-ct.csv ├── table5.csv ├── table2.csv ├── auc-for-ct.csv ├── table4.csv ├── table3.csv └── accuracy.Rmd ├── annotator ├── data │ ├── log.txt │ └── file_dir.csv ├── requirements.txt ├── screenshot_application.png ├── README.md ├── simple-annotator.py └── manual-english.md ├── rule-based-classification ├── README.md ├── negative_congestion.csv ├── positive_congestion.csv ├── rule-based-algorithm.Rmd └── rule-based-algorithm.md ├── models ├── README.md └── convert-model-to-pytorch.sh ├── wordpiece-vocabularies ├── README.md ├── handclean-wordpiece-vocabulary.R └── insert-custom-vocabs.R ├── pretraining ├── run_datageneration.sh ├── run_pretraining.sh ├── run_sentencizing.py ├── notebooks │ ├── 03_create-pretraining-data.ipynb │ ├── 02_bert-custom-vocabulary.ipynb │ ├── 04_run-pretraining.ipynb │ └── 01_sentencizing.ipynb ├── bert-vocab-builder │ ├── subword_builder.py │ └── tokenizer.py ├── optimization.py ├── tokenization.py └── create_pretraining_data.py ├── text-extraction ├── clean-report-texts.R └── extract-reports.R ├── README.md └── finetuning ├── 01_binary-classification.ipynb ├── 02_multilabel-classification.ipynb └── loss-logs.txt /evaluation/.Rhistory: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /annotator/data/log.txt: -------------------------------------------------------------------------------- 1 | Filename, annotator, confidence 2 | -------------------------------------------------------------------------------- /annotator/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.2 2 | pandas>=0.25.2 3 | 4 | -------------------------------------------------------------------------------- /annotator/screenshot_application.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/annotator/screenshot_application.png -------------------------------------------------------------------------------- /annotator/data/file_dir.csv: -------------------------------------------------------------------------------- 1 | file_name,annoation1,annoation2,annoation3 2 | sample_id_1,,, 3 | sample_id_2,,, 4 | sample_id_3,,, 5 | sample_id_4,,, -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/f1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/f1-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/acc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/acc-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/mcc-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/mcc-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/prc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/prc-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/roc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/roc-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/loss-log-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/loss-log-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/prc long-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/prc long-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/plot pooled-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/plot pooled-1.png -------------------------------------------------------------------------------- /evaluation/accuracy_files/figure-gfm/radarplots-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kbressem/bert-for-radiology/HEAD/evaluation/accuracy_files/figure-gfm/radarplots-1.png -------------------------------------------------------------------------------- /rule-based-classification/README.md: -------------------------------------------------------------------------------- 1 | # Rule Based Classification 2 | 3 | The development of a rule-based algorithm was not pursued any further soon after decision to employ a BERT model was made. 4 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Models created from scripts in this repository will be stored here, if the code is executed on a local machine. As the size of the models varies between 400 MB and 2 GB, our created models were not uploaded to GitHub. 
They can be downloaded soon at: sharepoint.charite.de 2 | -------------------------------------------------------------------------------- /evaluation/table1.csv: -------------------------------------------------------------------------------- 1 | finding,kappa,p_kappa,agreement 2 | Congestion,0.77,0,91.8 3 | Opacity,0.67,0,83.2 4 | Effusion,0.71,0,85.8 5 | Pneumothorax,0.83,0,97.2 6 | Thoracic drain,0.9,0,96.2 7 | Venous Catheter,0.93,0,96.6 8 | Gastric Tube,0.94,0,97.8 9 | Tracheal tube/canula,0.94,0,97.3 10 | Misplaced Medical Device,0.71,0,97.5 11 | pooled,0.82,0,93.71 12 | -------------------------------------------------------------------------------- /evaluation/irr.csv: -------------------------------------------------------------------------------- 1 | finding,kappa,p_kappa,agreement 2 | Stauung,0.77,0,91.79 3 | Verschattung,0.67,0,83.24 4 | Erguss,0.71,0,85.81 5 | Pneumothorax,0.83,0,97.19 6 | Thoraxdrainage,0.9,0,96.25 7 | ZVK,0.93,0,96.6 8 | Magensonde,0.94,0,97.77 9 | Tubus,0.94,0,97.3 10 | Materialfehllage,0.71,0,97.54 11 | Confidence,0.04,0.0249252551135366,41.28 12 | text,1,0,100 13 | -------------------------------------------------------------------------------- /evaluation/auc-ct.csv: -------------------------------------------------------------------------------- 1 | , ,fsBERT, ,gerBERT, ,mBERT, ,radBERT, 2 | train size, finding, AUC,AUPRC,AUC,AUPRC,AUC,AUPRC,AUC,AUPRC 3 | 4703,Congestion,0.81,0.47,0.76,0.56,0.87,0.65,0.9,0.77 4 | 4703,Opacity,0.6,0.74,0.74,0.93,0.75,0.93,0.81,0.96 5 | 4703,Effusion,0.76,0.75,0.94,0.96,0.95,0.97,0.9,0.93 6 | 4703,Pneumothorax,0.84,0.51,0.78,0.69,0.9,0.76,0.93,0.75 7 | 4703,Thoracic drain,0.91,0.65,0.91,0.76,0.89,0.79,0.93,0.88 8 | 4703,Venous Catheter,0.89,0.84,0.87,0.77,0.86,0.74,0.89,0.86 9 | 4703,Gastric Tube,0.82,0.75,0.74,0.6,0.76,0.61,0.84,0.71 10 | 4703,Tracheal Tube/Canula,0.9,0.87,0.87,0.8,0.89,0.82,0.91,0.9 11 | 4703,Misplaced Medical Device,0.46,0,0.97,0.62,0.96,0.5,0.77,0.4 12 | -------------------------------------------------------------------------------- /evaluation/table5.csv: -------------------------------------------------------------------------------- 1 | X2,fsBERT,X4,gerBERT,X6,mBERT,X8,radBERT,X10 2 | finding,AUC,AUPRC,AUC,AUPRC,AUC,AUPRC,AUC,AUPRC 3 | Congestion,0.81,0.47,0.76,0.56,0.87,0.65,0.9,0.77 4 | Opacity,0.6,0.74,0.74,0.93,0.75,0.93,0.81,0.96 5 | Effusion,0.76,0.75,0.94,0.96,0.95,0.97,0.9,0.93 6 | Pneumothorax,0.84,0.51,0.78,0.69,0.9,0.76,0.93,0.75 7 | Thoracic drain,0.91,0.65,0.91,0.76,0.89,0.79,0.93,0.88 8 | Venous Catheter,0.89,0.84,0.87,0.77,0.86,0.74,0.89,0.86 9 | Gastric Tube,0.82,0.75,0.74,0.6,0.76,0.61,0.84,0.71 10 | Tracheal Tube/Canula,0.9,0.87,0.87,0.8,0.89,0.82,0.91,0.9 11 | Misplaced Medical Device,0.46,0,0.97,0.62,0.96,0.5,0.77,0.4 12 | Pooled,0.78,0.62,0.84,0.74,0.87,0.75,0.88,0.8 13 | -------------------------------------------------------------------------------- /wordpiece-vocabularies/README.md: -------------------------------------------------------------------------------- 1 | # Explanation of files 2 | 3 | _vocab-bert-mincount-1000.txt_ 4 | WordPiece vocabulary. Not further processed. Flag for min_count was set to 1000. 5 | 6 | _vocab-bert-mincount-1000-handcleaned.txt_ 7 | WordPiece vocabulary. Special characters have been removed form tokens (e.g. '(Serie' --> 'Serie'). The same token may appear multiple times. 8 | 9 | _vocab-bert-mincount-5000.txt_ 10 | WordPiece vocabulary. Not further processed. Flag for min_count was set to 5000. 
11 | 12 | _vocab-bert-handcleaned-30000.txt_ 13 | Further processed _vocab-bert-mincount-1000-handcleaned.txt_, double tokes have been replaced, `[unusedX]` tokens were inserted to fill the vocabulary up to 30000. 14 | -------------------------------------------------------------------------------- /wordpiece-vocabularies/handclean-wordpiece-vocabulary.R: -------------------------------------------------------------------------------- 1 | # handclean WordPiece Vocabulary 2 | library(tidyverse) 3 | vocab <- read_csv('vocab-bert-1000-handcleaned.txt', col_names = 'token') 4 | # Numbers (like dates), ':', ',', '(' and ')' were removed from the WordPiece vocabulary, leaving duplicated words. 5 | # " and ' were also removed 6 | # e.g. ['(Serie', 'Serie'] now became ['Serie' 'Serie'] 7 | tokens <- vocab %>% select(token) %>% unique() %>% unlist() 8 | 9 | # padding to 30 000 tokes inserting [unusedX] tokens 10 | for (i in 1:(30000-length(tokens))) tokens = c(tokens, paste('[unused', i, ']', sep = '')) 11 | write_csv(data.frame(tokens), 'vocab-bert-handcleaned-30000.txt') 12 | 13 | -------------------------------------------------------------------------------- /evaluation/table2.csv: -------------------------------------------------------------------------------- 1 | finding,prevalence,finding,prevalence,finding,prevalence 2 | all data,,training data,,test data, 3 | Verschattung,60% (n = 3101),Verschattung,60% (n = 2836),Verschattung,53% (n = 265) 4 | Erguss,47% (n = 2461),Erguss,48% (n = 2243),Erguss,44% (n = 218) 5 | Stauung,28% (n = 1446),Stauung,28% (n = 1330),Stauung,23% (n = 116) 6 | Pneumothorax,8% (n = 429),Pneumothorax,8% (n = 390),Pneumothorax,8% (n = 39) 7 | ZVK,58% (n = 3020),ZVK,58% (n = 2732),ZVK,58% (n = 288) 8 | Tubus,41% (n = 2110),Tubus,40% (n = 1902),Tubus,42% (n = 208) 9 | Magensonde,25% (n = 1299),Magensonde,25% (n = 1169),Magensonde,26% (n = 130) 10 | Thoraxdrainage,21% (n = 1077),Thoraxdrainage,21% (n = 969),Thoraxdrainage,22% (n = 108) 11 | Materialfehllage,4% (n = 188),Materialfehllage,3% (n = 164),Materialfehllage,5% (n = 24) 12 | -------------------------------------------------------------------------------- /models/convert-model-to-pytorch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # clone the transformers repository of huggingface (e.g. to documents) 4 | 5 | python ../../transformers/convert_bert_original_tf_checkpoint_to_pytorch.py \ 6 | --tf_checkpoint_path='../bert-for-radiology/models/tf-bert-base-german-radiology-cased/pretraining_output/model.ckpt-100000' \ 7 | --bert_config_file='../bert-for-radiology/models/tf-bert-base-german-radiology-cased/rogerbert_config.json' \ 8 | --pytorch_dump_path='../bert-for-radiology/models/pt-bert-base-german-radiology-cased/pytorch-model' 9 | 10 | # the -00000-of-00001 in the name of the modelfile, model.ckpt-XXX.data should not be wirtten down when defining the tf_checkpoint_path. Otherwise it will throw the following error: 11 | # Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator? 
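# To sanity-check the converted weights afterwards, something along these lines can be used
# (a sketch, not part of this script: it assumes config.json and the dumped pytorch_model.bin
# end up together in the pt-bert-base-german-radiology-cased directory, which may require renaming/copying):
# python -c "from transformers import BertModel; BertModel.from_pretrained('../bert-for-radiology/models/pt-bert-base-german-radiology-cased')"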
12 | 13 | -------------------------------------------------------------------------------- /pretraining/run_datageneration.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | printf "\033c" 3 | 4 | eval "$(conda shell.bash hook)" 5 | conda activate bert-vocab 6 | 7 | 8 | # define variables 9 | i=0 10 | datadir='../data/small-splits/' 11 | outdir128='../tmp/tf_examples.tfrecord-seq128-' 12 | outdir512='../tmp/tf_examples.tfrecord-seq512-' 13 | vocabdir='../wordpiece-vocabularies/vocab-bert-handcleaned-30000.txt' 14 | splits=$(ls $datadir) 15 | 16 | # run the loop 17 | for FILE in $splits 18 | do 19 | i=$((i+1)) 20 | 21 | python create_pretraining_data.py \ 22 | --input_file="$datadir$FILE" \ 23 | --output_file="$outdir128${i}" \ 24 | --vocab_file="$vocabdir" \ 25 | --do_lower_case=False \ 26 | --max_seq_length=128 \ 27 | --max_predictions_per_seq=20 \ 28 | --masked_lm_prob=0.15 \ 29 | --random_seed=12345 \ 30 | --dupe_factor=5 31 | 32 | python create_pretraining_data.py \ 33 | --input_file="$datadir$FILE" \ 34 | --output_file="$outdir512${i}" \ 35 | --vocab_file="$vocabdir" \ 36 | --do_lower_case=False \ 37 | --max_seq_length=512 \ 38 | --max_predictions_per_seq=20 \ 39 | --masked_lm_prob=0.15 \ 40 | --random_seed=12345 \ 41 | --dupe_factor=5 42 | printf "\033c" 43 | 44 | done 45 | 46 | -------------------------------------------------------------------------------- /wordpiece-vocabularies/insert-custom-vocabs.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | 3 | # this script can be used to insert new vocabs in an existing WordPiece vocabulary, in place of "unused" token 4 | # It is currently not used for models in this repository, as a new WordPiece vocabulary was generated. 
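# Example usage (a sketch, not run for this repository): put one new token per line into
# 'vocabs-to-insert.txt' (e.g. domain words such as 'Pneumothorax' or 'Dystelektase'),
# point EXISTING_VOCABS at the released vocab.txt and source this script; each new token
# then replaces one [unused] slot, so the vocabulary size stays at 30000.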
5 | 6 | EXISTING_VOCABS = "vocab.txt" # path to the WordPiece vocab file to be changed 7 | VOCABS_TO_INSERT = "vocabs-to-insert.txt" # csv file or txt file with the vocabs to be inserted (one word each line) 8 | NEW_VOCAB = "my-vocab.txt" # path to write out the changes WordPiece voabulary 9 | 10 | vocabulary <- read_file(EXISTING_VOCABS) %>% str_split("\n") %>% unlist() %>% {.[1:30000]} 11 | 12 | add_vocab <- function(new_vocab, vocabulary) { 13 | 14 | if (!str_detect(paste(vocabulary, collapse = ""), new_vocab)) { 15 | 16 | unused <- str_detect(vocabulary, "unused") 17 | unused[2] <- FALSE # this likely indicates how many vocabs are not used 18 | index_unused <- (1:length(vocabulary))[unused] 19 | vocabulary[index_unused[1]] <- new_vocab 20 | vocabulary[index_unused[-1]] <- paste("[unused", 1:(length(index_unused)-1), "]", sep = "") 21 | vocabulary[2] <- paste("[unused", length(index_unused), "]", sep = "") 22 | 23 | } 24 | 25 | return(vocabulary) 26 | } 27 | 28 | new_vocabs <- read_file(VOCABS_TO_INSERT) %>% str_split("\n") %>% unlist() %>% {.[-length(.)]} 29 | for (i in new_vocabs) vocabulary <- add_vocab(i, vocabulary) 30 | 31 | write_file(paste(vocabulary, collapse = "\n"), NEW_VOCAB) 32 | -------------------------------------------------------------------------------- /evaluation/auc-for-ct.csv: -------------------------------------------------------------------------------- 1 | train_size,Model,Finding,AUPRC,AUC 2 | 4703,radbert,Congestion,0.77,0.9 3 | 4703,radbert,Opacity,0.96,0.81 4 | 4703,radbert,Effusion,0.93,0.9 5 | 4703,radbert,Pneumothorax,0.75,0.93 6 | 4703,radbert,Thoracic drain,0.88,0.93 7 | 4703,radbert,Venous Catheter,0.86,0.89 8 | 4703,radbert,Gastric Tube,0.71,0.84 9 | 4703,radbert,Tracheal Tube/Canula,0.9,0.91 10 | 4703,radbert,Misplaced Medical Device,0.4,0.77 11 | 4703,fsbert,Congestion,0.47,0.81 12 | 4703,fsbert,Opacity,0.74,0.6 13 | 4703,fsbert,Effusion,0.75,0.76 14 | 4703,fsbert,Pneumothorax,0.51,0.84 15 | 4703,fsbert,Thoracic drain,0.65,0.91 16 | 4703,fsbert,Venous Catheter,0.84,0.89 17 | 4703,fsbert,Gastric Tube,0.75,0.82 18 | 4703,fsbert,Tracheal Tube/Canula,0.87,0.9 19 | 4703,fsbert,Misplaced Medical Device,0,0.46 20 | 4703,gerbert,Congestion,0.56,0.76 21 | 4703,gerbert,Opacity,0.93,0.74 22 | 4703,gerbert,Effusion,0.96,0.94 23 | 4703,gerbert,Pneumothorax,0.69,0.78 24 | 4703,gerbert,Thoracic drain,0.76,0.91 25 | 4703,gerbert,Venous Catheter,0.77,0.87 26 | 4703,gerbert,Gastric Tube,0.6,0.74 27 | 4703,gerbert,Tracheal Tube/Canula,0.8,0.87 28 | 4703,gerbert,Misplaced Medical Device,0.62,0.97 29 | 4703,multibert,Congestion,0.65,0.87 30 | 4703,multibert,Opacity,0.93,0.75 31 | 4703,multibert,Effusion,0.97,0.95 32 | 4703,multibert,Pneumothorax,0.76,0.9 33 | 4703,multibert,Thoracic drain,0.79,0.89 34 | 4703,multibert,Venous Catheter,0.74,0.86 35 | 4703,multibert,Gastric Tube,0.61,0.76 36 | 4703,multibert,Tracheal Tube/Canula,0.82,0.89 37 | 4703,multibert,Misplaced Medical Device,0.5,0.96 38 | -------------------------------------------------------------------------------- /pretraining/run_pretraining.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ~/Documents/bert-for-radiology/pretraining 4 | 5 | # There have been some issues with the linebreaks. If they should persist, delete the '\' and write everything into one line, separating the commands by only one space. 
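# For example, the first call below written as a single line (same flags, no '\') would read:
# python run_pretraining.py --input_file=../tmp/tf_examples.tfrecord-seq128-* --output_dir=../tmp/pretraining_output --do_train=True --do_eval=True --bert_config_file=../models/bert-base-german-cased/bert_config.json --init_checkpoint=../models/bert-base-german-cased/bert_model.ckpt --train_batch_size=32 --max_seq_length=128 --max_predictions_per_seq=20 --num_train_steps=90000 --num_warmup_steps=9000 --learning_rate=2e-5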
6 | 7 | cd ~/Documents/bert-for-radiology/pretraining 8 | eval "$(conda shell.bash hook)" 9 | conda deactivate 10 | conda activate bert-pretraining 11 | 12 | # remove init_checkpoint to train from scratch 13 | 14 | python run_pretraining.py \ 15 | --input_file=../tmp/tf_examples.tfrecord-seq128-* \ 16 | --output_dir=../tmp/pretraining_output \ 17 | --do_train=True \ 18 | --do_eval=True \ 19 | --bert_config_file=../models/bert-base-german-cased/bert_config.json \ 20 | --init_checkpoint=../models/bert-base-german-cased/bert_model.ckpt \ 21 | --train_batch_size=32 \ 22 | --max_seq_length=128 \ 23 | --max_predictions_per_seq=20 \ 24 | --num_train_steps=90000 \ 25 | --num_warmup_steps=9000 \ 26 | --learning_rate=2e-5 27 | 28 | # to run additional 10000 steps, the overall num of train_steps needs to be set to 100000 (90000 + 10000 = 100000) 29 | 30 | python run_pretraining.py \ 31 | --input_file=../tmp/tf_examples.tfrecord-seq512-* \ 32 | --output_dir=../tmp/pretraining_output \ 33 | --do_train=True \ 34 | --do_eval=True \ 35 | --bert_config_file=../models/bert-base-german-cased/bert_config.json \ 36 | --init_checkpoint=../tmp/pretraining_output/model.ckpt-90000 \ 37 | --train_batch_size=6 \ 38 | --max_seq_length=512 \ 39 | --max_predictions_per_seq=20 \ 40 | --num_train_steps=100000 \ 41 | --num_warmup_steps=91000 \ 42 | --learning_rate=2e-5 43 | 44 | -------------------------------------------------------------------------------- /evaluation/table4.csv: -------------------------------------------------------------------------------- 1 | , ,rBERT, ,gerBERT, ,mBERT, ,sBERT, 2 | train size, finding, AUC,AUPRC,AUC,AUPRC,AUC,AUPRC,AUC,AUPRC 3 | 200,Congestion,0.86,0.48,0.82,0.26,0.86,0.47,0.85,0.41 4 | 200,Opacity,0.79,0.69,0.78,0.65,0.81,0.7,0.78,0.67 5 | 200,Effusion,0.9,0.82,0.86,0.77,0.88,0.78,0.86,0.75 6 | 200,Pneumothorax,0.94,0.27,0.94,0.21,0.95,0.4,0.95,0.37 7 | 200,Thoracic drain,0.97,0.88,0.88,0.48,0.93,0.73,0.87,0.42 8 | 200,Venous Catheter,0.97,0.97,0.96,0.95,0.96,0.96,0.93,0.92 9 | 200,Gastric Tube,0.98,0.95,0.97,0.89,0.98,0.93,0.97,0.92 10 | 200,Tracheal Tube/Canula,0.9,0.8,0.83,0.64,0.96,0.91,0.88,0.78 11 | 200,Misplaced Medical Device,0.95,0.04,0.95,0.01,0.95,0.04,0.95,0 12 | 1000,Congestion,0.97,0.88,0.96,0.86,0.97,0.9,0.89,0.6 13 | 1000,Opacity,0.94,0.91,0.92,0.87,0.94,0.92,0.87,0.82 14 | 1000,Effusion,0.97,0.95,0.95,0.91,0.97,0.94,0.95,0.92 15 | 1000,Pneumothorax,0.98,0.72,0.97,0.68,0.99,0.83,0.97,0.62 16 | 1000,Thoracic drain,0.98,0.94,0.99,0.97,0.98,0.92,0.97,0.88 17 | 1000,Venous Catheter,0.98,0.97,0.98,0.97,0.98,0.98,0.98,0.98 18 | 1000,Gastric Tube,0.99,0.96,0.99,0.97,0.99,0.98,0.99,0.95 19 | 1000,Tracheal Tube/Canula,0.99,0.98,0.99,0.97,0.99,0.97,0.99,0.97 20 | 1000,Misplaced Medical Device,0.97,0.46,0.96,0.25,0.98,0.54,0.95,0.01 21 | 4000,Congestion,0.98,0.92,0.98,0.92,0.98,0.92,0.97,0.87 22 | 4000,Opacity,0.97,0.94,0.96,0.94,0.95,0.93,0.95,0.93 23 | 4000,Effusion,0.97,0.94,0.97,0.94,0.98,0.95,0.97,0.95 24 | 4000,Pneumothorax,0.99,0.91,0.99,0.91,0.99,0.88,0.99,0.91 25 | 4000,Thoracic drain,0.99,0.95,0.99,0.95,0.99,0.95,0.99,0.94 26 | 4000,Venous Catheter,0.99,0.99,0.99,0.98,0.99,0.98,0.99,0.98 27 | 4000,Gastric Tube,0.99,0.98,1,0.98,0.99,0.97,0.99,0.97 28 | 4000,Tracheal Tube/Canula,0.99,0.98,0.99,0.98,0.99,0.98,0.99,0.98 29 | 4000,Misplaced Medical Device,0.99,0.73,0.99,0.76,0.99,0.78,0.98,0.64 30 | -------------------------------------------------------------------------------- /pretraining/run_sentencizing.py: 
-------------------------------------------------------------------------------- 1 | import spacy 2 | from spacy.lang.de import German 3 | import pandas as pd 4 | import time 5 | 6 | nlp = German() 7 | nlp.add_pipe(nlp.create_pipe('sentencizer')) 8 | 9 | texts = pd.read_csv('../data/cleaned-text-dump.csv', low_memory=False) 10 | 11 | def sentencizer(raw_text, nlp): 12 | doc = nlp(raw_text) 13 | sentences = [sent.string.strip() for sent in doc.sents] 14 | return(sentences) 15 | 16 | def fix_wrong_splits(sentences): 17 | i=0 18 | 19 | while i < (len(sentences)-2): 20 | if sentences[i].endswith(('Z.n.','V.a.','v.a.', 'Vd.a.' 'i.v', ' re.', 21 | ' li.', 'und 4.', 'bds.', 'Bds.', 'Pat.', 22 | 'i.p.', 'i.P.', 'b.w.', 'i.e.L.', ' pect.', 23 | 'Ggfs.', 'ggf.', 'Ggf.', 'z.B.', 'a.e.' 24 | 'I.', 'II.', 'III.', 'IV.', 'V.', 'VI.', 'VII.', 25 | 'VIII.', 'IX.', 'X.', 'XI.', 'XII.')): 26 | sentences[i:i+2] = [' '.join(sentences[i:i+2])] 27 | 28 | elif len(sentences[i]) < 10: 29 | sentences[i:i+2] = [' '.join(sentences[i:i+2])] 30 | 31 | i+=1 32 | return(sentences) 33 | 34 | loggingstep = [] 35 | for i in range(1000): 36 | loggingstep.append(i*10000) 37 | 38 | 39 | tic = time.clock() 40 | for i in range(len(texts)): 41 | text = texts.TEXT[i] 42 | sentences = sentencizer(text, nlp) 43 | sentences = fix_wrong_splits(sentences) 44 | with open('../data/report-dump.txt', 'a+') as file: 45 | for sent in sentences: 46 | file.write(sent + '\n') 47 | file.write('\n') 48 | if i in loggingstep: 49 | toc = time.clock() 50 | print('dumped the ' + str(i) + "th report. " + str(toc - tic) + "seconds passed.") 51 | toc = time.clock() 52 | -------------------------------------------------------------------------------- /text-extraction/clean-report-texts.R: -------------------------------------------------------------------------------- 1 | # load packages 2 | library(magrittr) 3 | library(tidyverse) 4 | library(tictoc) 5 | 6 | # load rds dump 7 | path_to_rds_dump <- "/path/to/REPORT_TEXT_DUMP.RDS" 8 | TEXT_DUMP <- read_rds(path_to_rds_dump) %>% drop_na() # 33 sec 9 | 10 | # identify most frequent strings 11 | freqfunc <- function(x, n) unlist(x) %>% table() %>% sort() %>% tail(n) %>% names() 12 | 13 | x <- unlist(TEXT_DUMP$TEXT) 14 | x <- table(x) 15 | x <- sort(x) 16 | str_to_remove_full_match <- names(tail(x, 100)) 17 | str_to_remove_partial_match <- c("Konstanzprüfung", 18 | "RK Import und digitale Archivierung von Fremdaufnahmen im PACS ohne Befunderstellung", 19 | "Demonstration ohne Befunderstellung", 20 | "Befundung erfolgt über eine externe Datenbank.", 21 | "Patient nicht erschienen am", 22 | "Qualitätssicherung", 23 | "Teleradiologische Bildübertragung ohne Befunderstellung", 24 | "Tumorkonferenzbetreuung", 25 | "von Station abgesagt", 26 | "Demonstration ohne Befunderstellung", 27 | "Import und digitale Archivierung von Fremdaufnahmen", 28 | "Patient nicht erschienen am" 29 | ) 30 | 31 | 32 | #remove most frequent strings by full match 33 | #I use a loop, as it allows for better tracking 34 | for (i in str_to_remove_full_match) { 35 | print(paste("removing: ", i)) 36 | TEXT_DUMP %<>% filter(TEXT != i) 37 | } 38 | 39 | # remove strings by partial match 40 | # this is very unelegant tbh. It also takes very long but it works. 
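# A vectorized alternative would be to collapse all patterns into one regular expression and
# filter in a single pass (a sketch, not used here; patterns containing regex metacharacters
# such as '.' would need escaping first):
#   pattern <- paste(str_to_remove_partial_match, collapse = "|")
#   TEXT_DUMP <- TEXT_DUMP %>% filter(!str_detect(TEXT, pattern))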
41 | for (i in str_to_remove_partial_match) { 42 | print(paste("removing: ", i)) 43 | print(nrow(TEXT_DUMP)) 44 | remove = lapply(TEXT_DUMP$TEXT, str_detect, i) %>% unlist 45 | if (!is.null(remove)) TEXT_DUMP <- TEXT_DUMP[!remove, ] 46 | } 47 | 48 | write_rds(TEXT_DUMP, "data/cleaned-text-dump.rds") 49 | write_csv(TEXT_DUMP, "data/cleaned-text-dump.csv") 50 | 51 | -------------------------------------------------------------------------------- /evaluation/table3.csv: -------------------------------------------------------------------------------- 1 | , ,rBERT, , ,gerBERT, , ,mBERT, , ,sBERT, , 2 | train size, finding, f1,j_stat,mcc,f1,j_stat,mcc,f1,j_stat,mcc,f1,j_stat,mcc 3 | 200,Congestion,0.28,0.32,0.21,0.06,0.13,0.05,0.4,0.41,0.31,0.37,0.35,0.27 4 | 200,Opacity,0.71,0.34,0.32,0.7,0.26,0.22,0.73,0.37,0.33,0.65,0.18,0.17 5 | 200,Effusion,0.72,0.57,0.55,0.72,0.47,0.47,0.75,0.55,0.55,0.65,0.42,0.41 6 | 200,Pneumothorax,0.05,0.92,0.15,0,0,0,0.14,0.68,0.23,0.1,0.59,0.17 7 | 200,Thoracic drain,0.77,0.78,0.72,0.29,0.46,0.27,0.58,0.68,0.54,0.23,0.27,0.17 8 | 200,Venous Catheter,0.92,0.83,0.82,0.93,0.87,0.84,0.92,0.82,0.8,0.86,0.66,0.65 9 | 200,Gastric Tube,0.9,0.89,0.87,0.81,0.8,0.75,0.89,0.87,0.85,0.81,0.82,0.76 10 | 200,Tracheal tube/canula,0.77,0.58,0.59,0.68,0.45,0.45,0.89,0.8,0.81,0.74,0.55,0.55 11 | 200,Misplaced Medical Device,0,0,0,0,0,0,0,0,0,0,0,0 12 | 1000,Congestion,0.85,0.81,0.81,0.86,0.82,0.82,0.86,0.82,0.82,0.53,0.48,0.42 13 | 1000,Opacity,0.88,0.76,0.72,0.87,0.71,0.7,0.88,0.74,0.73,0.79,0.54,0.51 14 | 1000,Effusion,0.94,0.89,0.89,0.86,0.75,0.75,0.92,0.84,0.85,0.88,0.79,0.79 15 | 1000,Pneumothorax,0.66,0.75,0.64,0.56,0.84,0.58,0.79,0.75,0.77,0.41,0.68,0.43 16 | 1000,Thoracic drain,0.88,0.83,0.85,0.89,0.85,0.86,0.9,0.84,0.87,0.83,0.81,0.78 17 | 1000,Venous Catheter,0.95,0.9,0.89,0.97,0.93,0.93,0.96,0.92,0.91,0.95,0.89,0.88 18 | 1000,Gastric Tube,0.96,0.94,0.95,0.97,0.95,0.95,0.96,0.93,0.95,0.94,0.9,0.91 19 | 1000,Tracheal tube/canula,0.97,0.95,0.95,0.97,0.94,0.94,0.96,0.92,0.93,0.95,0.9,0.91 20 | 1000,Misplaced Medical Device,0.45,0.97,0.53,0.15,0.62,0.22,0.4,0.96,0.49,0,0,0 21 | 4000,Congestion,0.9,0.85,0.87,0.9,0.85,0.87,0.89,0.83,0.86,0.86,0.81,0.82 22 | 4000,Opacity,0.92,0.85,0.84,0.92,0.84,0.83,0.91,0.81,0.8,0.89,0.78,0.77 23 | 4000,Effusion,0.92,0.85,0.85,0.93,0.87,0.88,0.94,0.88,0.88,0.92,0.86,0.86 24 | 4000,Pneumothorax,0.89,0.84,0.88,0.85,0.81,0.84,0.86,0.84,0.85,0.86,0.79,0.85 25 | 4000,Thoracic drain,0.95,0.91,0.93,0.95,0.92,0.94,0.95,0.91,0.93,0.91,0.88,0.89 26 | 4000,Venous Catheter,0.98,0.96,0.95,0.98,0.95,0.95,0.98,0.95,0.95,0.97,0.93,0.93 27 | 4000,Gastric Tube,0.97,0.95,0.96,0.98,0.96,0.97,0.97,0.95,0.96,0.96,0.94,0.95 28 | 4000,Tracheal tube/canula,0.97,0.95,0.96,0.98,0.96,0.97,0.98,0.96,0.96,0.98,0.95,0.96 29 | 4000,Misplaced Medical Device,0.59,0.82,0.61,0.76,0.87,0.76,0.76,0.87,0.76,0.33,0.79,0.4 30 | -------------------------------------------------------------------------------- /text-extraction/extract-reports.R: -------------------------------------------------------------------------------- 1 | library(lubridate) 2 | library(tidyverse) 3 | library(magrittr) 4 | 5 | ################################# 6 | # # 7 | # run time ~ 3 days # 8 | # # 9 | ################################# 10 | 11 | # set path, import table ---- 12 | REPORTS <- read_rds(paste("~/REPORT_TEXT_DUMP.RDS", sep = "")) 13 | # a very large table with two collumns: 14 | # FILE_ADDRESS_TEXT_REPORT --> Name of the plain text file with path 15 | # TEXT --> NA for default, report Texts will be 
placed here 16 | 17 | 18 | # define variables for monitoring ---- 19 | STEPS = 10000 20 | N_TO_EXTRACT = sum(is.na(REPORTS$TEXT)) 21 | SAVING_STEPS = (1:round(N_TO_EXTRACT/STEPS))*STEPS + nrow(REPORTS)-N_TO_EXTRACT 22 | # MESSAGE_STEPS = (1:round(N_TO_EXTRACT/(STEPS*10)))*(STEPS*10) + nrow(REPORTS)-N_TO_EXTRACT # uncomment if frequent status reports are desired 23 | START_TIME <- as.numeric(Sys.time()) 24 | START = Sys.time() 25 | EST_TIME = 0 26 | TRYS = 0 27 | LAST_SAVE = "NA" 28 | start_iter = sum(!is.na(REPORTS$TEXT)) 29 | 30 | # start the loop ---- 31 | # loops might be slower and a lot more code than a solution using purrr::map (one line of code) but it allows for better logging. 32 | 33 | for (i in 1:nrow(REPORTS)) { 34 | 35 | if (is.na(REPORTS$TEXT[i])) { 36 | 37 | file_adress <- REPORTS$FILE_ADDRESS_TEXT_REPORT[i] 38 | 39 | file <- try(read_file(file_adress)) 40 | 41 | if (class(file) == "try-error") {TRYS = TRYS + 1} else { 42 | file %<>% str_replace_all("\n", " ") %>% str_replace_all("\r", " ") %>% str_squish() 43 | REPORTS$TEXT[i] = file 44 | } 45 | 46 | if (i %in% SAVING_STEPS) { 47 | 48 | LAST_SAVE <- paste("Iteration:", i, "at:", Sys.time()) 49 | 50 | print("Saving") 51 | 52 | write_csv(REPORTS, "~data/REPORT_TEXT_DUMP.csv") 53 | write_rds(REPORTS, "~data/REPORT_TEXT_DUMP.RDS") 54 | 55 | print("Done") 56 | } 57 | 58 | # if (i %in% MESSAGE_STEPS) message_me(status_report, parse_mode = 'Markdown') # message_me = custom wrapper for telegram-bot 59 | 60 | if (i %% 100 == 0 ) { 61 | 62 | PASSED_TIME <- round(as.numeric(Sys.time()) - START_TIME) 63 | iterations_passed <- i - start_iter 64 | EST_TIME <- (PASSED_TIME/iterations_passed) * sum(is.na(REPORTS$TEXT)) 65 | 66 | status_report <- paste("*Status Update* \n\n", 67 | "Current time: ", Sys.time(),"\n", 68 | "Start time: ", START, "\n", 69 | "Passed time: ", seconds_to_period(PASSED_TIME),"\n", 70 | "Try-erros: ", TRYS, "\n", 71 | "Reports extracted: ", sum(!is.na(REPORTS$TEXT)), "\n", 72 | "Reports to extract: ", sum(is.na(REPORTS$TEXT)), "\n", 73 | "Last save: ", LAST_SAVE, "\n", 74 | "Est. remaining time: ", seconds_to_period(EST_TIME)) 75 | cat("\014") # flushes console 76 | cat(status_report) 77 | write_file(status_report, "data/status_report.md") 78 | 79 | } 80 | 81 | } 82 | } 83 | 84 | -------------------------------------------------------------------------------- /pretraining/notebooks/03_create-pretraining-data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# Create Pretraining Data\n", 15 | "Before starting the pretraining: \n", 16 | "\n", 17 | "1. Convert the text-reports to file-format, see [here](https://github.com/kbressem/bert-for-radiology/blob/master/pretraining/sentencizing.ipynb). \n", 18 | "2. Create the custom vocabulary file, see [here](https://github.com/kbressem/bert-for-radiology/blob/master/pretraining/bert-vocab-builder.ipynb). \n", 19 | "3. Create pretraining data with a [script](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) supplied by Google. \n", 20 | "\n", 21 | "The `create_pretraining_data.py` was copied to this directory, so it can be executed locally." 
22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Initalizing the enviroment\n", 29 | "\n", 30 | "The BERT vocabulary Anaconda enviroment from the [bert-custom-vocabulary.ipyn](https://github.com/kbressem/bert-for-radiology/blob/master/pretraining/bert-vocab-builder.ipynb) will be used. \n", 31 | "\n", 32 | "The 3.6 GB grew to over 100 GB in memory, because of which the data needed to be split using the `split` bash command. To split each 1,000,000 lines, the following code was used:\n", 33 | "\n", 34 | "```bash\n", 35 | " split -l 1000000 report-dump.txt report-data-chunk-\n", 36 | "```\n", 37 | "\n", 38 | "The output are multiple 1,000,000-line files named: `report-data-chunk-aa`, `report-data-chunk-ab`, ... , `report-data-chunk-zz` \n", 39 | "In our case, it resulted in 55 files. \n", 40 | "\n", 41 | "In order to create the pretraining data from the first split, one must execute the Google script with the following command. \n", 42 | "\n", 43 | "```bash\n", 44 | "python create_pretraining_data.py \\\n", 45 | " --input_file='../data/small-splits/report-data-chunk-aa' \\\n", 46 | " --output_file='../tmp/tf_examples.tfrecord' \\\n", 47 | " --vocab_file='../vocab.txt' \\\n", 48 | " --do_lower_case=False \\\n", 49 | " --max_seq_length=128 \\\n", 50 | " --max_predictions_per_seq=20 \\\n", 51 | " --masked_lm_prob=0.15 \\\n", 52 | " --random_seed=12345 \\\n", 53 | " --dupe_factor=5\n", 54 | " ```\n", 55 | " \n", 56 | "To advoid copy-pasting the code multiple times for all data splits, we used a bash file which runs a loop over all data files stored in `../data/small-splits/` and stores the output in `../tmp`. \n", 57 | "The script also contains a logger, telling which file is currently processed and how many more remain." 58 | ] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "Python 3", 64 | "language": "python", 65 | "name": "python3" 66 | }, 67 | "language_info": { 68 | "codemirror_mode": { 69 | "name": "ipython", 70 | "version": 3 71 | }, 72 | "file_extension": ".py", 73 | "mimetype": "text/x-python", 74 | "name": "python", 75 | "nbconvert_exporter": "python", 76 | "pygments_lexer": "ipython3", 77 | "version": "3.7.4" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 4 82 | } 83 | -------------------------------------------------------------------------------- /pretraining/bert-vocab-builder/subword_builder.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | # Copyright 2018 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Program to build a SubwordTextEncoder. 17 | 18 | The flags --min_count and --corpus_max_lines will affect the size of the 19 | vocabulary. Try changing these flags until you get a vocabulary 20 | of the size you want. 
21 | 22 | Example usage: 23 | 24 | python data_generators/text_encoder_build_subword.py \ 25 | --corpus_filepattern=$DATA_DIR/my_problem-train-* \ 26 | --corpus_max_lines=12345 \ 27 | --output_filename=$DATA_DIR/my_problem.subword_text_encoder \ 28 | --logtostderr 29 | 30 | """ 31 | from __future__ import absolute_import 32 | from __future__ import division 33 | import text_encoder 34 | import tokenizer 35 | 36 | import tensorflow as tf 37 | 38 | tf.flags.DEFINE_string('output_filename', '/tmp/my.subword_text_encoder', 39 | 'where to store the SubwordTextEncoder') 40 | tf.flags.DEFINE_string('corpus_filepattern', '', 41 | 'Corpus of one or more text files') 42 | tf.flags.DEFINE_string('vocab_filepattern', '', 'One or more vocabulary files ' 43 | '(one word per line as "word,count")') 44 | tf.flags.DEFINE_integer('min_count', 5, 'Minimum subtoken count in corpus') 45 | tf.flags.DEFINE_integer('corpus_max_lines', None, 46 | 'How many lines of corpus to read') 47 | tf.flags.DEFINE_integer('num_iterations', 5, 'Number of iterations') 48 | tf.flags.DEFINE_bool('split_on_newlines', True, 'Break corpus into lines.') 49 | tf.flags.DEFINE_string('additional_chars', "", 'Set special characters to be included in vocab. ex : "~", "/".') 50 | tf.flags.DEFINE_integer('max_subtoken_length', None, 'Max subtoken length') 51 | FLAGS = tf.flags.FLAGS 52 | 53 | 54 | def main(unused_argv): 55 | if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern: 56 | raise ValueError( 57 | 'Must only provide one of --corpus_filepattern or --vocab_filepattern') 58 | 59 | elif FLAGS.corpus_filepattern: 60 | token_counts = tokenizer.corpus_token_counts( 61 | FLAGS.corpus_filepattern, 62 | FLAGS.corpus_max_lines, 63 | split_on_newlines=FLAGS.split_on_newlines, additional_chars=FLAGS.additional_chars) 64 | 65 | elif FLAGS.vocab_filepattern: 66 | token_counts = tokenizer.vocab_token_counts(FLAGS.vocab_filepattern, 67 | FLAGS.corpus_max_lines) 68 | 69 | else: 70 | raise ValueError( 71 | 'Must provide one of --corpus_filepattern or --vocab_filepattern') 72 | 73 | encoder = text_encoder.SubwordTextEncoder() 74 | encoder.build_from_token_counts(token_counts, FLAGS.min_count, 75 | FLAGS.num_iterations, max_subtoken_length=FLAGS.max_subtoken_length) 76 | encoder.store_to_file(FLAGS.output_filename, add_single_quotes=False) 77 | # encoder.store_to_file_with_counts(FLAGS.output_filename + "_counts") 78 | 79 | 80 | if __name__ == '__main__': 81 | tf.app.run() 82 | -------------------------------------------------------------------------------- /annotator/README.md: -------------------------------------------------------------------------------- 1 | # A Simple Tool for the Classification of Radiological Text Data 2 | 3 | This terminal program should help to quickly and accurately classify diagnostic texts, e.g. for thoracic images. At the moment, up to nine different findings can be classified. The names of the findings (e.g. effusion, pneumothorax) are directly imported from the csv-file containing the file-names. 4 | 5 | ## Requirements 6 | Python 3.6 7 | Numpy 1.16.4 8 | Pandas 0.24.2 9 | Curses (pre-installed in Python on Ubuntu). 10 | Unix operating system (Windows not yet supported). 11 | 12 | ## Setting up the programm 13 | Clone or download this repository. Paste a csv-file (structure of the file described below) into the `data` folder of the project. Rename the csv-file to `file_dir.csv`. Open the main folder in a new terminal window. In Ubuntu, this can be done by a right click on the folder and by selecting `open in terminal`. 
Type `python3 simple-annotator.py` into your terminal. The terminal should be maximized in order to achieve an optimal display of the user interface and report texts. 14 | [BUG] Changing the size of the terminal while the annotator is running currently causes the program to crash. 15 | 16 | ## How it works 17 | 18 | The basis is a csv file, e.g. with the following structure: 19 | 20 | 21 | filename | Congestion | Pneumonia | Effusion | Pneumothorax | Foreign Materials 22 | ------------------------------------------------------------------------------------- 23 | 13764005.txt| 1 | 1 | 0 | 0 | 0 24 | 13374201.txt| 1 | 0 | 0 | 0 | 1 25 | 13740269.txt| -1 | -1 | -1 | -1 | -1 26 | 12345678.txt| 0 | 0 | 1 | 1 | 1 27 | 10420000.txt| | | | | 28 | 29 | 1: finding is present 30 | 0: finding is not present 31 | -1: unevaluable (always whole row) 32 | : no annotation (NA) 33 | 34 | Each annotation is immediately stored in the csv-file. 35 | 36 | ## Using the program 37 | After the start, you will be asked to enter your name. This allows you to follow the comments of different people. Please note, however, that currently no more than one annotation per file is possible, and previous annotations from another user (or the same user) will be overwritten. 38 | ![Screenshot of the Application](screenshot_application.png) 39 | 40 | ### Annotation 41 | The annotations are displayed here. Magenta text means that there is no annotation yet (values are NA). After pressing a key for annotation (e.g. 1 for congestion), the corresponding line is highlighted (black writing on white background). White writing on a black background corresponds to the absence of the findings. 42 | 43 | There is also the possibility to rate one's confidence in the annotation between "low", "medium" and "high". The confidence rating is storend in a file called `data/log.txt`, next to the name of the annotator and the file name. The file is created at program start, if it does not already exist. 44 | 45 | ### Controls 46 | Explanation of the controls. The letter in front of the closed bracket corresponds to the key to be pressed. 47 | 48 | ### Options 49 | Further options. You can choose whether you want to display only files that have already been commented or only files that have not yet been commented. There is also a possibility to turn on an AI helper for annotation, which, however, is not yet implemented. 50 | 51 | ### Progress bar 52 | Some measurements of progress are given above the text reports. The flashing hash (`#`) indicates your current position in the loaded stack of text reports. A small "x" indicates that the report is annotated but the condifence rating is missing, a capital "X" indicates a full annotation including a confidence rating. 53 | 54 | ### Report text 55 | Text reports of radiological findings are displayed here. You can switch between the previous or next text with the arrow keys or `n` or `v`. 
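
## Working with the annotations outside the tool
Because every annotation is written straight back to `data/file_dir.csv` using the encoding above (1 = present, 0 = absent, -1 = unevaluable, empty = not yet annotated), annotation progress can be checked with a few lines of pandas. This is only a sketch and assumes the default file location:

```python
import pandas as pd

df = pd.read_csv("data/file_dir.csv")
labels = df.columns[1:]                     # first column holds the file names

annotated = df[labels].notna().all(axis=1)  # fully annotated reports
unevaluable = (df[labels] == -1).all(axis=1)

print(f"annotated: {annotated.sum()} / {len(df)}")
print(f"unevaluable: {unevaluable.sum()}")
print(df.loc[annotated & ~unevaluable, labels].mean())  # label prevalence
```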
56 | -------------------------------------------------------------------------------- /rule-based-classification/negative_congestion.csv: -------------------------------------------------------------------------------- 1 | keine stauungszeichen 2 | keine zeichen einer akuten kardialen dekompensation 3 | keine progrediente volumenbelastung 4 | keine relevante zentrale stauung 5 | keine zeichen akuter zentraler stauung 6 | keine höhergradigen zentralen stauungszeichen 7 | keine pulmonale stauung 8 | keine stauung 9 | lgz zentral gering vermehrt 10 | keine höhergradigen flüssigkeitseinlagerungen 11 | keine höhergradigen stauungszeichen 12 | keine anzeichen einer dekompensation 13 | keine akuten zentralen stauungszeichen 14 | keine relevante zentrale stauung 15 | keine höhergradige lungenstauung 16 | keine wesentliche zentrale volumenbelastung 17 | zentrale stauungszeichen vorhanden 18 | keine höhergradige zentrale stauung 19 | lungengefäßzeichnung gestaut 20 | lungengefäßzeichnung unauffällig 21 | stauung bei überlagerung aktuell nicht sicher beurteilbar 22 | keine höhergradige zentrale volumenbelastung 23 | keine höhergradigen pulmonalvenösen stauungszeichen 24 | kein nachweis konfluierender pneumonische infiltrat oder höhergradige stauungszeichen 25 | keine pulmonalvenösen stauungszeichen 26 | kräftige lungengefäßzeichnung ohne höhergradige flüssigkeitseinlagerungen 27 | keine zentrale stauung 28 | ohne wesentliche zentrale stauung 29 | ohne wesentliche dekompensationszeichen 30 | keine akuten stauungszeichen 31 | lgz zentral mäßig vermehrt 32 | ohne dekompensationszeichen 33 | keine kardiopulmonalen stauungszeichen 34 | ohne akute dekompensation 35 | keine wesentlichen stauungszeichen 36 | ohne höhergradige stauung 37 | keine höhergradige stauung 38 | pulmonale stauung nicht abgrenzbar 39 | keine wesentliche zentrale stauung 40 | keine höhergradigen zeichen einer zentralen volumenbelastung 41 | keine stauung 42 | keine pulmonalvenöse stauung 43 | keine pulmonalvenöse stauung 44 | keine pulmonale stauung 45 | ohne zeichen akuter flüssigkeitseinlagerungen 46 | in erster linie orthostatisch bedingt bei aufnahme im liegen 47 | keine höhergradige pulmonalvenöse stauung 48 | keine gravierende pulmonale stauung 49 | keine höhergradige kardiopulmonale stauung 50 | ohne zentrale stauung 51 | herz nicht akut gestaut 52 | keine wesentliche pulmonale stauung 53 | zentral betonte pulmonalvenöse stauung 54 | nicht höhergradig 55 | keine höhergradige zentrale volumenbelastung 56 | keine zeichen einer pulmonalvenösen stauung 57 | keine kardiopulmonalen stauungszeichen 58 | keine wesentliche zentrale volumenbelastung 59 | keine relevante pulmonale stauung 60 | keine akute dekompensation 61 | keine pulmonalen stauungszeichen 62 | keine kardiopulmonale stauung 63 | keine höhergradige zentrale stauung 64 | keine wesentliche stauung 65 | allenfalls diskrete stauung 66 | keine relevanten stauungszeichen 67 | keine akuten dekompensationszeichen 68 | keine höhergradige pulmonale stauung 69 | keine höhergradige volumenbelastung 70 | keine zeichen einer kardialen dekompensation 71 | kräftige lungengefäßzeichnung 72 | unauffällige zentrale und periphere lungengefäßzeichnung 73 | keine zeichen höhergradiger akuter zentraler stauung 74 | regredienz der zentralen stauung 75 | keine akute volumenbelastung 76 | keine rel. 
zentralvenöse stauung 77 | kein anhalt für kardiale dekompensation 78 | keine volumenbelastung 79 | minimale stauungszeichen 80 | keine zeichen einer relevanten zentralen volumenbelastung 81 | keine wesentliche flüssigkeitseinlagerung 82 | keine flüssigkeitseinlagerungen 83 | keine zeichen der zentralen volumenbelastung 84 | unauffällige lungengefäßzeichnung 85 | keine dekompensation 86 | nicht gestaut 87 | volumenstatus nicht sicher beurteilbar 88 | kein nachweis signifikanter zentraler stauungszeichen 89 | akzentuierte zentrale lgz 90 | cave stauung 91 | keine gravierenden pulmonalen stauungszeichen 92 | keine relevante stauung 93 | keine kardialen dekompensationszeichen 94 | keine c-p stauung 95 | kein hinweis auf eine kardiale dekompensation 96 | keine akuten kardiopulmonalen stauungszeichen 97 | keine manifeste kardiale dekompensation 98 | keine flüssigkeitseinlagerung 99 | keine zeichen der höhergradigen pulmonalen stauung 100 | keine rel. stauung 101 | keine kardiale dekompensation 102 | keine kardiopulmonalen stauung 103 | lungengefäßzeichnung zentral vermehrt 104 | kein nachweis eines signifikanten zentralen stauung 105 | keien stauung 106 | keine gravierende stauung 107 | volumenbelastung aktuell rückläufig 108 | keine relevante kardiale stauung 109 | keine wesentliche volumenbelastung 110 | keine relevanten flüssigkeitseinlagerungen 111 | fraglich zentrale flüssigkeitseinlagerungen 112 | keine höhergradigen flüssigkeitseinlagerung 113 | keine höhergradige cardiale volumenbelastung 114 | keine höhergradige pulmonalve-nöse stauung 115 | ohne anhalt für akute dekompensation 116 | kein höhergradige stauung 117 | keine akuten kardialen dekompensation 118 | kein nachweis einer signifikanten zentralen stauung 119 | ohne wesentliche stauung 120 | keine wesentliche dekompensation 121 | keine zeichen einer höhergradigen pulmonalvenösen stauung 122 | keine zeichen der kardialen dekompensation 123 | allenfalls geringgradige flüssigkeitseinlagerungen 124 | keine zentrale pulmonal-venöse stauung 125 | keine zeichen höhergradiger akuter stauung 126 | moderate volumenbelastung im liegen 127 | ohne stauung 128 | ohne zeichen der akuten kardialen dekompensation 129 | keine signifikante stauung 130 | keine zeichen der zentralvenösen stauung 131 | -------------------------------------------------------------------------------- /pretraining/notebooks/02_bert-custom-vocabulary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Building a custom vocabulary for BERT\n", 8 | "https://github.com/kwonmha/bert-vocab-builder\n", 9 | "\n", 10 | "In order to achieve good results on domain-specific datasets, BERT has to be pre-trained to enable a better understanding. Having a look at the vocabulary from the German BERT-Base model by deepset.ai, there are 'only' 30000 vocabulary words (of which 3001 are unused), while some of the most frequent vocabulary from medical texts is absent. For example:\n", 11 | "\n", 12 | "| German Word | English Translation |\n", 13 | "|-------------|----------------------|\n", 14 | "|Pneumothorax | pneumothorax |\n", 15 | "|Erguss | effusion |\n", 16 | "|Infiltrat | infiltrate |\n", 17 | "|Dystelektase | dystelectasis |\n", 18 | "| ... | ... 
|\n", 19 | "\n", 20 | "\n", 21 | "Google's research does not provide tools to create a custom vocabulary, however [this](https://github.com/kwonmha/bert-vocab-builder) Github repository of [kwonmha](https://github.com/kwonmha) does. In order to use the scripts, they have been downloaded into the folder `bert-vocab-builder`. \n", 22 | "\n", 23 | "The vocabulary can be build via the following bash-command: \n", 24 | "\n", 25 | "```bash\n", 26 | "python subword_builder.py \\\n", 27 | "--corpus_filepattern {corpus_for_vocab} \\\n", 28 | "--output_filename {name_of_vocab}\n", 29 | "--min_count {minimum_subtoken_counts}\n", 30 | "```\n", 31 | "\n", 32 | "To define a reasonable mininum subtoken count, we proceeded as follows: \n", 33 | "In a [previous notebook](https://github.com/kbressem/bert-for-radiology/blob/master/pretraining/sentencizing.ipynb), the word frequency was counted in all text-reports and then put into a .json file. This shows the frequency of specific words, enabling the definition of a reasonable threshold. " 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Initializing the enviroment\n", 41 | "\n", 42 | "```bash\n", 43 | "conda create --name=bert-vocab tensorflow\n", 44 | "conda activate bert-vocab\n", 45 | "conda install ipykernel spacy\n", 46 | "ipython kernel install --user --name=bert-vocab\n", 47 | "```" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## Importing the .json file\n", 55 | "Since we work with very sensible data, neither the original text nor the .json file can be uploaded, as a small risk remains that a patient name could be mentioned somewhere in a report text. " 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 1, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "import json\n", 65 | "from collections import OrderedDict\n", 66 | "\n", 67 | "with open('../data/word-count-report-dump.json') as json_file:\n", 68 | " wordcount = json.load(json_file)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 6, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def sortSecond(val): \n", 78 | " return val[1] \n", 79 | "\n", 80 | "wordcount['__individual count__'].sort(key = sortSecond, reverse = True)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 7, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "GREATER_THAN = 1000\n", 90 | "\n", 91 | "wordcount_greater = []\n", 92 | "for i in wordcount['__individual count__']:\n", 93 | " if i[1] > GREATER_THAN:\n", 94 | " wordcount_greater.append(i)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 8, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/plain": [ 105 | "23783" 106 | ] 107 | }, 108 | "execution_count": 8, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "wordcount_greater.sort(key = sortSecond, reverse = False)\n", 115 | "len(wordcount_greater)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "I would suggest to set `--min_count` to 5000. " 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Generation of a custom vocabulary\n", 130 | "\n", 131 | "Installing spaCy and tensorflow automaticall downgrades tensorflow to version 1.13.1. 
Although the code to create the custom vocabulary is based on tensorflow 1.11, it currently works: \n", 132 | "\n", 133 | "```bash\n", 134 | "python subword_builder.py \\\n", 135 | " --corpus_filepattern '../../data/report-dump.raw' \\\n", 136 | " --output_filename '../../pretraining/vocab-bert.txt' \\\n", 137 | " --min_count 5000\n", 138 | "```" 139 | ] 140 | } 141 | ], 142 | "metadata": { 143 | "kernelspec": { 144 | "display_name": "bert-vocab", 145 | "language": "python", 146 | "name": "bert-vocab" 147 | }, 148 | "language_info": { 149 | "codemirror_mode": { 150 | "name": "ipython", 151 | "version": 3 152 | }, 153 | "file_extension": ".py", 154 | "mimetype": "text/x-python", 155 | "name": "python", 156 | "nbconvert_exporter": "python", 157 | "pygments_lexer": "ipython3", 158 | "version": "3.7.5" 159 | } 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 4 163 | } 164 | -------------------------------------------------------------------------------- /annotator/simple-annotator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | from os import chdir 4 | import os.path 5 | 6 | import pandas as pd 7 | import curses 8 | import numpy as np 9 | from helpers import is_annotated, Finding, RightSideMenu, MasterTable, fish, ProgressBar, User, Confidence 10 | 11 | current_annotator = User() 12 | 13 | 14 | class SimpleAnnotaor(object): 15 | def __init__(self): 16 | self.foo = "bar" 17 | 18 | def check_untypical_length(self, text): 19 | if len(text) < 250 or len(text) > 2000: 20 | return True 21 | else: 22 | return False 23 | 24 | 25 | simple_annotator = SimpleAnnotaor() 26 | 27 | 28 | def scroll_reports(key, row, n_row): 29 | if key == ord('n') or key == curses.KEY_RIGHT: 30 | if row < (n_row - 1): 31 | row += 1 32 | else: 33 | if row > 0: 34 | row -= 1 35 | return row 36 | 37 | 38 | def finding_update(master_table, FindingCreator=Finding): 39 | findings = [] 40 | for col in range(1, master_table.n_col): 41 | findings.append(FindingCreator(master_table.csv_file, col, master_table.row)) 42 | 43 | return findings 44 | 45 | 46 | def main(stdscr): 47 | curses.resize_term(100, 300) # prevents crash from resizing if lower right panel border is not active terminal 48 | # size) 49 | 50 | curses.curs_set(0) 51 | stdscr.idcok(False) # might reduce flashing 52 | stdscr.idlok(False) # might reduce flashing 53 | 54 | # set color pairs 55 | curses.init_pair(1, curses.COLOR_RED, curses.COLOR_CYAN) # for non-evaluable findings 56 | curses.init_pair(2, curses.COLOR_MAGENTA, curses.COLOR_BLACK) # for NaN values 57 | curses.init_pair(3, curses.COLOR_WHITE, curses.COLOR_RED) # low confidence 58 | curses.init_pair(4, curses.COLOR_WHITE, curses.COLOR_YELLOW) # medium confidence 59 | curses.init_pair(5, curses.COLOR_WHITE, curses.COLOR_GREEN) # high confidence 60 | 61 | 62 | # create new windows for menu, report, progressbar etc. 
63 | menu_window = curses.newpad(200, 50) 64 | text_report_window = curses.newpad(400, 400) 65 | 66 | # init variables 67 | show_annotated = 0 68 | key_range = [] 69 | 70 | # init classes 71 | # csv file with annotations, need to find a better name 72 | master_table = MasterTable('data/file_dir.csv') 73 | master_table.read_file() 74 | 75 | # print report text 76 | master_table.print_report(text_report_window) 77 | 78 | # Confidence ratings 79 | confidence = Confidence() 80 | 81 | # Progress Bar on top 82 | progress_bar = ProgressBar(master_table.n_row, current_annotator.name, master_table.csv_file) 83 | progress_bar.update(master_table.csv_file.iloc[:, 1].count(), master_table.row) 84 | 85 | # Right side Menu 86 | right_side_menu = RightSideMenu() 87 | right_side_menu.show_menu(menu_window, confidence.level) 88 | findings = finding_update(master_table) 89 | 90 | if len(findings) > 9: # currently no more than 9 findings are supported (key 1 to 9) 91 | findings = findings[0:9] 92 | master_table.n_col = 10 93 | 94 | for item in findings: 95 | key_range.append(item.key) 96 | 97 | while 1: 98 | for item in findings: 99 | item.print_label(menu_window, ") ") 100 | 101 | key = stdscr.getch() 102 | 103 | if key in key_range: 104 | for finding in findings: 105 | if finding.key == key: 106 | if finding.value != -1: 107 | finding.toggle() 108 | elif key == ord('a'): 109 | for finding in findings: 110 | finding.value = 1 111 | elif key == ord('s'): 112 | for finding in findings: 113 | finding.value = 0 114 | elif key == ord('d'): 115 | for finding in findings: 116 | finding.value = np.NaN 117 | elif key == ord('x'): 118 | for finding in findings: 119 | finding.toggle_uninterpretable() 120 | 121 | elif key == ord('n') or key == curses.KEY_RIGHT or key == ord('v') or key == curses.KEY_LEFT: 122 | while 1: 123 | master_table.row = scroll_reports(key, master_table.row, master_table.n_row) 124 | findings = finding_update(master_table) 125 | confidence.load_existing_level(master_table.csv_file, master_table.row) 126 | if show_annotated == 1 and is_annotated(findings): 127 | break 128 | elif show_annotated == 2 and not is_annotated(findings): 129 | break 130 | elif show_annotated == 0: 131 | break 132 | elif master_table.row in [0, master_table.n_row-1]: 133 | break 134 | 135 | elif key == ord('A'): 136 | right_side_menu.show_only_annotated_toggle() 137 | if show_annotated in [0, 1]: 138 | show_annotated += 1 139 | else: 140 | show_annotated = 0 141 | elif key == ord('S'): 142 | right_side_menu.toggle_ai() 143 | elif key == curses.KEY_UP and not np.isnan(findings[1].value): 144 | confidence.toggle("up") 145 | elif key == curses.KEY_DOWN and not np.isnan(findings[1].value): 146 | confidence.toggle("down") 147 | elif key == ord('q'): 148 | curses.endwin() 149 | quit() 150 | elif key == ord('f'): 151 | fish() 152 | elif key == curses.KEY_MOUSE or key == curses.KEY_MOVE: 153 | curses.resize_term(100, 300) 154 | curses.flushinp() # prevents endless flicker after mouse scrolling 155 | else: 156 | curses.flushinp() 157 | curses.resize_term(100, 300) 158 | 159 | if is_annotated(findings): 160 | for finding in findings: 161 | finding.NaN_to_0() 162 | else: 163 | confidence.toggle("delete") 164 | 165 | 166 | master_table.write_findings(findings) 167 | master_table.read_file() 168 | master_table.print_report(text_report_window) 169 | progress_bar.update(master_table.csv_file.iloc[:, 1].count(), master_table.row) 170 | current_annotator.update_log(master_table.file_name, is_annotated(findings), confidence.level) 171 | 
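        # redraw the right-side menu so the current confidence rating is visible
        # before the next keypress is read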
right_side_menu.show_menu(menu_window, confidence.level) 172 | 173 | 174 | curses.wrapper(main) 175 | -------------------------------------------------------------------------------- /rule-based-classification/positive_congestion.csv: -------------------------------------------------------------------------------- 1 | stauungszeichen regredient 2 | geringe zentral betonte volumenbelastungzeichen 3 | moderate zentrale stauung 4 | geringe pulmonalvenöse stauungszeichen 5 | progrediente ubiquitäre zeichnungsvermehrung 6 | mit regredienten stauungszeichen 7 | progrediente zentrale stauung 8 | geringe pulmonale stauung 9 | geringgradige lungenstauung 10 | geringe zentralvenöse stauung 11 | pulmonalvenöse stauung 12 | lungengefäßzeichnung gestaut 13 | zunehmende kardiopulmonale volumenbelastung 14 | bekannte zentrale stauungszeichen 15 | bild eines lungenödems 16 | lungengefäßzeichnung im verlauf regredient gestaut 17 | geringgradigen pulmonalvenösen stauung 18 | regrediente zentrale volumsbelastung 19 | zentrale volumenbelastung 20 | leichte zentrale stauungszeichen 21 | zentral betonte pulmonale stauung 22 | diskrete stauungszeichen 23 | flüssigkeitseinlagerungen im lungenkern 24 | infiltraten und stauungskomponente 25 | zentrale stauungskomponente 26 | mäßige stauungzeichen 27 | zeichen der zentralen stauung 28 | zunehmend gestaut 29 | unveränderte volumenbelastung 30 | geringe stauung 31 | diskrete pulmonale stauung 32 | pulmonalvenen gestaut 33 | akute stauung mit zentralen flüssigkeitseinlagerungen 34 | zunehmendes volumen 35 | pulmonale stauungszeichen. 36 | geringe pulmonalvenöse stauung 37 | noch mäßige zentral betonte pulmonale stauung 38 | im kurzfristigen verlauf regrediente pulmonalvenöse stauung 39 | geringe zentrale lungenstauung 40 | zunehmend zentrale volumenbelastung 41 | konstant diskrete kardiopulmonale volumenbelastung 42 | zunehmende pulmonalvenöse stauungszeichen 43 | mäßiges volumen 44 | zunehmende stauungszeichen 45 | lungengefäßzeichnung vereinbar mit kardiopulmonaler stauung 46 | mäßige zentrale volumenbelastung 47 | unehmende kardiopulmonale stauung 48 | eine zusätzliche stauungskomponente kann bildmorphologisch nicht sicher ausgeschlossen werden 49 | regrediente volumenbelastung 50 | geringe zentrale stauung 51 | begleitende stauungskomponente 52 | regrediente zentrale stauung 53 | geringe zentralvenöse stauung 54 | geringe zentralvenöse stauung 55 | zentrale stauungszeichen 56 | mäßigen zentralen stauungszeichen 57 | zunehmende zentrale volumenbelastung 58 | stauung leicht progredient 59 | interstitielle stauungskomponente 60 | dezente volumenbelastung 61 | regrediente stauung 62 | zentrale stauung 63 | bei geringgradiger lungenstauung 64 | capillary leak 65 | lgz zentral gering betont 66 | geringe stauung 67 | mäßige stauung 68 | deutlicher zentraler stauung 69 | mäßiggradigen stauungskomponente 70 | geringe zentrale volumenbelastung 71 | zentral betonte 72 | pulmonalvenöse stauung 73 | zentral betonte 74 | pulmonalvenöse stauung 75 | deutliche zentrale stauung 76 | gering zeichen der pulmonalvenösen stauung 77 | ödem 78 | a.e. 
flüssigkeitseinlagerungen 79 | zunehmend zentrale stauung 80 | geringer zentraler stauungsaspekt 81 | zunehmende volumenbelastung 82 | zeichen der zentral betonten pulmonalvenösen stauung 83 | zentrale stauung etwas progredient 84 | gefäßzeichnung gering vermehrt 85 | im sinne geringer stauung 86 | stauungszeichen geringen ausmaßes 87 | keine zentralen stauungszeichen 88 | zunehmende pulmonale stauung 89 | zunehmende grobfleckige infiltrate beidseits 90 | geringgradig regredienten zeichen der volumenbelastung 91 | a.e. stauungsbedingt 92 | mäßige zentrale stauung 93 | zunehmende zentrale stauung 94 | stauungszeichen im kleinen kreislauf 95 | stauungszeichen zunehmend 96 | geringgradigen stauung 97 | zentrale stauung 98 | zunehmende zeichen der volumenbelastung 99 | deutliche zeichen der pulmonalvenösen stauung 100 | lgz mäßig vermehrt mit beginnender flüssigkeitseinlagerungen 101 | mäßigen pulmonalen stauung 102 | zentrale pulmonale stauung 103 | vermehrte zentrale lungengefäßzeichnung 104 | akute stauung 105 | progrediente volumenbelastung 106 | zeichen der chronischen stauung 107 | deutliche flüssigkeitseinlagerungen 108 | zentral betonte stauungszeichen 109 | kardiopulmonale stauung 110 | betonte stauungszeichen 111 | weiterhin volumen 112 | gering pulmonaler stauung 113 | stauungszeichen eher regredient 114 | stauungszeichen eher regredient 115 | erster linie rahmen einer stauung 116 | deutliche pulmonale stauung 117 | zunehmende stauung 118 | zentralen volumenbelastung 119 | interstitielle flüssigkeitseinlagerng 120 | mäßiggradige flüssigkeitseinlagerungen 121 | stauungszeichen 122 | zentralen pulmonalen stauung 123 | zeichen der pulmonalvenösen stauung 124 | zentral betonte stauung 125 | pulmonale stauung 126 | zentralbetonte stauung 127 | zeichen einer pulmonalvenösen stauung 128 | pulmonalvenöser stauung 129 | stauungspneumonie 130 | zunehmender volumenbelastung 131 | progrediente zentralvenöse stauung 132 | geringe volumenbelastung 133 | kombination aus stauung 134 | keine progrediente stauung 135 | zunehmenden stauung 136 | progredienter pulmonalen stauung 137 | diskreter zentraler stauung 138 | chronischen kardiopulmonalen stauung 139 | geringgradige stauung 140 | die lungengefäße sind mäßigen grades gestaut 141 | stauungsinfiltrate 142 | mäßige volumenbelastung 143 | zeichen der zentralen und peripheren pulmonalen stauung 144 | in erster linie stauungsbedingt 145 | progredient gestaut 146 | progredienten stauung 147 | regredienz der pulmonalvenösen stauung 148 | i.1.l stauung. 149 | im sinne einer kardiopulmonalen stauung 150 | zeichen der gering- bis mäßiggradigen stauung 151 | deutliche zeichen der pulmonalen stauung 152 | in erster linie volumenbelastung 153 | zunahme der stauung 154 | mit zentraler und peripherer zentralvenöser stauung 155 | regredient gestaut 156 | mäßiger kardiopulmonaler stauung 157 | geringe flüssigkeitseinlagerungen 158 | progredienter stauung 159 | abnehmende volumenbelastung 160 | zeichen zentraler stauung 161 | zeichen pulmonalvenösen stauung 162 | flüssigkeitseinlagerungen bds. 163 | konstante stauungskomponente 164 | regrediente zentralvenöse stauung 165 | regrediente flüssigkeitseinlagerungen 166 | a.e. stauung 167 | zentralbetonte pulmonale volumenbelastung 168 | lungengefäßzeichnung mäßiggradig gestaut 169 | geringe lungenstauung 170 | progrediente kardiale dekompensationszeichen 171 | zentraler flüssigkeitseinlagerung 172 | gering gestaut 173 | lungengefäße zentral gering gestaut 174 | keine wesentl. 
stauung
175 | peripheren volumenbelastung
176 | volumenbelastung eher regredient
177 | volumenbelastung
178 | mäßiggradige pulmonal - venöse stauung
179 | interstitielle flüssigkeitseinlagerungen
180 | keine zeichen einer höhergradigen kardialen dekompensation
181 | bild vereinbar mit kardialer dekompensation
182 | periphere flüssigkeitseinlagerungen
183 | Stauung zentral
184 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Classification of Radiological Text Reports using BERT
2 |
3 | ## Data Preparation
4 | ### Extraction of Free Text Reports
5 | Single plain text files were stored on a network drive. File names and paths were extracted beforehand using R's `list.files()` function and stored in one very large table. As our workstation has >120 GB RAM, keeping such large files in memory is no problem. On computers with less memory, some workarounds might be needed.
6 |
7 | used script: [text-extraction/extract-reports.R](text-extraction/extract-reports.R)
8 |
9 | ### Clean Text Dump
10 | About one million reports were not usable, since they only document DICOM imports, meetings, consistency tests or the like. These were removed by full and partial string matching. In this way, it was possible to remove a large part of the inappropriate diagnostic texts and to reduce the number of text reports from 4,790,000 to 3,841,543.
11 |
12 | used scripts:
13 | [text-extraction/clean-report-texts.R](text-extraction/clean-report-texts.R)
14 |
15 | ### Converting the Texts to Document Format
16 | For the generation of a custom vocabulary and of training data for BERT, the source files need to be in a specific document format, which is:
17 |
18 | > "The input is a plain text file, with one sentence per line. (It is important that these be actual sentences for the "next sentence prediction" task). Documents are delimited by empty lines."
19 |
20 | As all text files were stored as CSV, they had to be converted to document format. Each row contained one document, so pasting an empty line between documents was straightforward; having only one sentence per line, however, requires the documents to be split into sentences, which is more complicated.
21 | To sentencize the reports, the German NLP module of `spaCy` was used. As this did not work perfectly and also split at most of the radiology-specific abbreviations, another function was written to fix those wrong splits.
22 |
23 | A [notebook](pretraining/sentencizing.ipynb) on the process and a [python-script](pretraining/run_sentencizing.py) to run the code from the bash can be found in the folder pretraining.
24 |
25 | ### Create WordPiece vocabulary
26 | Google Research does not provide scripts to create a new WordPiece vocabulary. They refer to other open-source options such as:
27 |
28 | - [Google's SentencePiece library](https://github.com/google/sentencepiece)
29 | - [tensor2tensor's WordPiece generation script](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder_build_subword.py)
30 | - [Rico Sennrich's Byte Pair Encoding library](https://github.com/rsennrich/subword-nmt)
31 |
32 | But, as they mention, these are not compatible with their `tokenization.py` library. Therefore we used the modified library by [kwonmha](https://github.com/kwonmha/bert-vocab-builder) to build a custom vocabulary.
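For orientation, the vocabulary builder is invoked from the command line roughly as follows (a sketch based on the accompanying notebook; the corpus and output paths need to be adapted and `--min_count` tuned to the corpus size):

```bash
# Build a custom WordPiece vocabulary from the raw report dump (example paths).
python subword_builder.py \
  --corpus_filepattern '../../data/report-dump.raw' \
  --output_filename '../../pretraining/vocab-bert.txt' \
  --min_count 5000
```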
33 |
34 | A [notebook](pretraining/bert-custom-vocabulary.ipynb) explaining the steps necessary to create a custom WordPiece vocabulary can be found in the folder pretraining.
35 |
36 | ### Create Pretraining Data
37 | The `create_pretraining_data.py` script from Google was used. Due to memory limitations, the text dump had to be split into smaller parts. The [Notebook](pretraining/notebooks/03_create-pretraining-data.ipynb) gives more details on the procedure of data preparation.
38 |
39 | ### Run Pretraining
40 | Two models were pretrained using the BERT-Base configuration. One was pretrained from scratch, one using a German BERT model as the initial checkpoint. The [Notebook](pretraining/notebooks/04_run-pretraining.ipynb) explains pretraining in more detail.
41 |
42 | ## Finetuning of four different BERT models
43 | The German BERT model from deepset.ai, the multilingual BERT model from Google, and our two pretrained BERT models were all fine-tuned on varying amounts of annotated text reports of chest radiographs.
44 | The steps of the fine-tuning process are explained in detail in the respective [notebooks](finetuning).
45 |
46 | ## Results
47 | Our BERT models achieve state-of-the-art performance compared to the existing literature.
48 |
49 | | F1-scores | RAD-BERT | RAD-BERT train size=1000 | [Olatunji et al. 2019](https://arxiv.org/pdf/1905.02283.pdf) | [Reeson et al. 2018](https://www.ncbi.nlm.nih.gov/pubmed/29802131) | [Friedlin et al. 2006](https://www.ncbi.nlm.nih.gov/pubmed/17238345) | [Wang et al. 2017](https://arxiv.org/abs/1705.02315) | MetaMap\$ 2017 | [Asatryan et al. 2011](https://www.ncbi.nlm.nih.gov/pubmed/21093355) | [Elkin et al. 2008](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2656026/) |
50 | |---|---|---|---|---|---|---|---|---|---|
51 | | Globally abnormal | 0.96 | 0.96 | 0.83 | na | na | 0.93+ | 0.91+ | na | na |
52 | | Cardiomegaly | na | na | 0.23 | na | 0.97 | 0.88 | 0.9 | na | na |
53 | | Congestion | 0.9 | 0.86 | na | na | 0.98 | 0.83§ | 0.77§ | na | na |
54 | | Effusion | 0.92 | 0.92 | na | na | 0.98 | 0.87 | 0.81 | na | na |
55 | | Opacity/Consolidation | 0.92 | 0.88 | 0.63 | na | na | 0.91/0.80/0.77\# | 0.95/0.39/0.71\# | 0.24-0.57\* | 0.82 |
56 | | Pneumothorax | 0.89 | 0.79 | na | 0.92 | na | 0.86 | 0.46 | na | na |
57 | | Venous catheter | 0.98 | 0.96 | na | 0.97 | na | na | na | na | na |
58 | | Thoracic drain | 0.95 | 0.9 | na | 0.95 | na | na | na | na | na |
59 | | Medical devices | 0.99 | 0.99 | 0.29 | na | na | na | na | na | na |
60 | | Best F1-score | 0.99 | 0.99 | 0.83 | 0.95 | 0.98 | 0.93 | 0.95 | 0.57 | 0.82 |
61 | | Worst F1-score | 0.58 | 0.4 | 0.23 | 0.92 | 0.97 | 0.52 | 0.39 | 0.24 | 0.82 |
62 |
63 | \+ detection of normal radiographs
64 | \# Consolidation/opacity was not reported, but atelectasis, infiltration and pneumonia were
65 | \$ As reported by Wang et al.34
66 | \§ Congestion not reported, but edema.
67 | \* Performed only pneumonia detection, as a clinical diagnosis, and used varying thresholds of pneumonia prevalence.
68 | _na_ not available/not reported by study
69 |
70 | ## Citation
71 | If you find the code helpful, please cite our published manuscript:
72 | ```
73 | @article{10.1093/bioinformatics/btaa668,
74 | author = {Bressem, Keno K and Adams, Lisa C and Gaudin, Robert A and Tröltzsch, Daniel and Hamm, Bernd and Makowski, Marcus R and Schüle, Chan-Yong and Vahldiek, Janis L and Niehues, Stefan M},
75 | title = "{Highly accurate classification of chest radiographic reports using a deep learning natural language model pretrained on 3.8 million text reports}",
76 | journal = {Bioinformatics},
77 | year = {2020},
78 | month = {07},
79 | issn = {1367-4803},
80 | doi = {10.1093/bioinformatics/btaa668},
81 | url = {https://doi.org/10.1093/bioinformatics/btaa668},
82 | note = {btaa668},
83 | eprint = {https://academic.oup.com/bioinformatics/article-pdf/doi/10.1093/bioinformatics/btaa668/33526133/btaa668.pdf},
84 | }
85 | ```
86 |
87 |
--------------------------------------------------------------------------------
/pretraining/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 |   """Creates an optimizer training op."""
27 |   global_step = tf.train.get_or_create_global_step()
28 |
29 |   learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 |   # Implements linear decay of the learning rate.
32 |   learning_rate = tf.train.polynomial_decay(
33 |       learning_rate,
34 |       global_step,
35 |       num_train_steps,
36 |       end_learning_rate=0.0,
37 |       power=1.0,
38 |       cycle=False)
39 |
40 |   # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 |   # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 |   if num_warmup_steps:
43 |     global_steps_int = tf.cast(global_step, tf.int32)
44 |     warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 |     global_steps_float = tf.cast(global_steps_int, tf.float32)
47 |     warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 |     warmup_percent_done = global_steps_float / warmup_steps_float
50 |     warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 |     is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 |     learning_rate = (
54 |         (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 |   # It is recommended that you use this optimizer for fine tuning, since this
57 |   # is how the model was trained (note that the Adam m/v variables are NOT
58 |   # loaded from init_checkpoint.)
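  # At this point `learning_rate` already contains the linear warmup and the
  # linear (polynomial, power=1.0) decay computed above; the optimizer below
  # adds decoupled weight decay on top of the standard Adam update.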
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
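      # Apply the decoupled weight decay only to parameters that are not listed
      # in `exclude_from_weight_decay` (LayerNorm weights and biases are skipped).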
146 |       if self._do_use_weight_decay(param_name):
147 |         update += self.weight_decay_rate * param
148 |
149 |       update_with_lr = self.learning_rate * update
150 |
151 |       next_param = param - update_with_lr
152 |
153 |       assignments.extend(
154 |           [param.assign(next_param),
155 |            m.assign(next_m),
156 |            v.assign(next_v)])
157 |     return tf.group(*assignments, name=name)
158 |
159 |   def _do_use_weight_decay(self, param_name):
160 |     """Whether to use L2 weight decay for `param_name`."""
161 |     if not self.weight_decay_rate:
162 |       return False
163 |     if self.exclude_from_weight_decay:
164 |       for r in self.exclude_from_weight_decay:
165 |         if re.search(r, param_name) is not None:
166 |           return False
167 |     return True
168 |
169 |   def _get_variable_name(self, param_name):
170 |     """Get the variable name from the tensor name."""
171 |     m = re.match("^(.*):\\d+$", param_name)
172 |     if m is not None:
173 |       param_name = m.group(1)
174 |     return param_name
175 |
--------------------------------------------------------------------------------
/annotator/manual-english.md:
--------------------------------------------------------------------------------
1 | # Operating Instructions for Text Annotator
2 |
3 | The annotator is a Python script and cannot be executed without Python. Therefore, some preparatory steps are necessary.
4 |
5 | ## 1. Installing Python3
6 | ### Linux (Ubuntu)
7 | On Linux, Python3 should be installed from the start. If this is not the case, open a terminal with [Ctrl] + [Alt] + [T] and install Python3 with the following command.
8 |
9 | ```bash
10 | sudo apt-get install python3
11 | ```
12 |
13 | ### Linux (Ubuntu) subsystem for Windows
14 |
15 | Python3 is not preinstalled in the Linux subsystem for Windows. The command is the same as in normal Ubuntu.
16 |
17 | ### Mac
18 | On Mac, it is recommended to install Python3 with __homebrew__:
19 |
20 | #### Installing Homebrew
21 | Open Spotlight with [Cmd] + [Space] and enter 'Terminal'. [Homebrew](https://brew.sh/) can be installed with the following command, which is copied into the terminal:
22 |
23 | ```bash
24 | /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
25 | # (keep the URL on a single line)
26 | ```
27 | If necessary, the Homebrew Python directory must be added to the system path (PATH):
28 |
29 | ```bash
30 | export PATH="/usr/local/opt/python/libexec/bin:$PATH"
31 | ```
32 | #### Installing Python with Homebrew
33 | Again via the terminal:
34 |
35 | ```bash
36 | brew install python
37 | ```
38 |
39 | ## 2. Setting up a virtual environment
40 |
41 | If you plan to run other Python scripts in the future, it is advisable to work with virtual Python environments. Here venv or Anaconda are suitable. Anaconda has the advantage that it is both a package manager and a manager for the virtual environments, while with venv you would additionally need pip as the package manager.
42 |
43 | ### Installing Anaconda on Linux
44 | Several steps are necessary here, so we would like to refer to the following __[blog post (click)](https://www.digitalocean.com/community/tutorials/how-to-install-anaconda-on-ubuntu-18-04-quickstart)__.
45 |
46 | ### Installing Anaconda on Mac
47 | There are also very good external __[instructions (click)](https://www.datacamp.com/community/tutorials/installing-anaconda-mac-os-x)__
48 |
49 | ### Setting up a virtual environment
50 | Enter the following code into the terminal:
51 |
52 | ```bash
53 | conda create --name=simple-annotator numpy pandas
54 | conda activate simple-annotator
55 | ```
56 | These commands create a new virtual environment called `simple-annotator` and simultaneously install the latest versions of the numpy and pandas packages. These are required by the annotator.
57 |
58 | Alternatively, if you don't want to use virtual environments (e.g. on the Linux subsystem) you can use pip only:
59 |
60 |
61 | ```bash
62 | pip install --user numpy pandas
63 | ```
64 | The `--user` flag is useful to prevent system programs from being updated, because e.g. in Ubuntu important system packages depend on Python. The rule is: __Never use pip with sudo!__
65 |
66 |
67 | ## 3. Starting the Annotator
68 | Unpack the package in any directory, e.g. in the _Documents_ folder.
69 | Now open the terminal, if it is not already open, and activate the previously created environment. Then navigate to the folder. Assuming the folder has been unpacked to _simple-annotator_: _/path/to/simple-annotator_
70 |
71 | ```bash
72 | conda activate simple-annotator
73 | cd /path/to/simple-annotator
74 | python3 simple-annotator.py
75 | ```
76 | Please maximize the terminal immediately.
77 |
78 | ## 4. Cache
79 | Every keystroke should trigger an automatic save, but in rare cases the program may crash and data may be lost. Therefore, it is advisable to make regular backup copies of the file `file_dir.csv` in the folder `/data`.
80 |
81 | ## 5. Operating the Annotator
82 | The user interface is written in English.
83 | Annotations are made via the keys [1] to [9]. Additionally, there is the possibility to mark findings as unevaluable ([x]) and to select all ([a]) or deselect all ([s]) annotations. In addition, a rating should be given for each annotation indicating your level of certainty (low, medium, high), which may be selected with the up/down arrow keys. [d] deletes the entire selection.
84 | There is also the possibility to jump only between annotated or only between non-annotated reports ([A]).
85 |
86 | ## 6. Guide to the Annotation of Thoracic Images
87 | Some reports for chest X-rays might contain unclear wordings. To still achieve a high interrater agreement we would suggest the following procedure:
88 |
89 | ### Unevaluable
90 | - If foreign material is described as idem (not changed).
91 | - If less than three findings are mentioned (e.g. only "No pneumothorax after central venous catheter"). If the report instead read: "No congestion, no infiltrates, no effusion", we would suggest annotating it as absent pneumothorax.
92 | - Severe dictation errors which blur the meaning of the text report.
93 |
94 | ### Congestion
95 | Congestion would be indicated by:
96 | - (Low) central congestion signs
97 | - Pulmonary edema/pulmonary fluid retention
98 |
99 | Congestion should __NOT__ be considered when there are:
100 | - No signs of higher degrees of congestion
101 | - No congestion when lying down.
102 | - Centrally accentuated fluid retentions without higher degree of congestion.
103 |
104 | ### Attenuation
105 | After annotation of the first 500 texts, we found that it was often not possible to distinguish between infiltrates and dystelectasis.
106 | 107 | Attenuation is present in case of: 108 | - Infiltrates/ dystelectases/ atelectases/ reduced ventilation etc., even if these are described as discrete. 109 | 110 | There is __NO__ attenuation in case of: 111 | - "Masked infiltrates possible in the basal parts of the lung", an exception can be made if this appears to be reasonable in the context: "Question wording: High CRP, RGs, productive cough. Findings/assessment: No confluent infiltrates when lying down. Reduced ventilation in the basal parts of the lung, here masked infiltrates may be present. " 112 | - Lung nodule/tumor 113 | - Foreign material (e.g. metal splinters) 114 | - Pleural shadow 115 | 116 | ### Pleural effusion 117 | Pleural effusion can be: 118 | - Blunting of the costophrenic/cardiophrenic angle, fluid within the horizontal or oblique fissures, seropneumothorax. 119 | 120 | There is __NO__ effusion in case of: 121 | - Homogeneous reduction of transparency (unless evaluated as effusion) 122 | 123 | ### Pneumothorax 124 | Pneumothorax may be indicated by: 125 | - pneu, pneumothorax, seropneumothorax etc. 126 | 127 | There is __NO__ presence of pneumothorax in case of: 128 | - 'No evidence of pneumothorax in a lying position' or 'as far as assessable in a lying position no pneumothorax or similar' 129 | - Deep costophrenic angle, if this is mentioned without evaluation. 130 | 131 | ### Malposition 132 | #### Central venous catheter (CVC) 133 | - CVC, Shaldon, Dialysis catheter, PICC etc. In short, everything that goes into a vein and about which medication could be given. 134 | 135 | #### Tube 136 | - Tube, tracheal cannula, tracheostoma etc. 137 | 138 | #### Thoracic drain 139 | - Thoracic drain, chest tube etc. 140 | 141 | #### Stomach tube 142 | - Stomach tube, tiger tube, feeding tube etc. But not: PEG 143 | 144 | #### Malposition 145 | - CVC: buckling, turned over or other explicit malposition 146 | - Tube: ending < 2 cm above the carina 147 | - Chest tube: kinking or other explicit malposition. Atypical projections do not correspond to malposition. 148 | - Gastric tube: loops outside the stomach, end with projection on the oesophagus etc. 149 | -------------------------------------------------------------------------------- /finetuning/01_binary-classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Train BERT for the Preselection of Reports\t \n", 8 | "\n", 9 | "For fine-tuning, the 'simpletransformers' library is used, as it only requires a few lines of code for the training. The library can be downloaded from Github via: (https://github.com/ThilinaRajapakse/simpletransformers). 
\n", 10 | "\n", 11 | "## Creating the enviroment\n", 12 | "\n", 13 | "\n", 14 | "```bash\n", 15 | "conda create --name=finetuning \n", 16 | "conda install tensorflow-gpu pytorch scikit-learn\n", 17 | "\n", 18 | "cd transformers \n", 19 | "pip install .\n", 20 | "\n", 21 | "git clone https://github.com/ThilinaRajapakse/simpletransformers\n", 22 | "cd simpletransformers\n", 23 | "pip install .\n", 24 | "\n", 25 | "git clone https://github.com/NVIDIA/apex\n", 26 | "cd apex\n", 27 | "pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n", 28 | "\n", 29 | "conda install ipykernel pandas\n", 30 | "ipython kernel install --user --name=finetuning\n", 31 | "\n", 32 | "pip install python-box ipywidgets\n", 33 | "jupyter nbextension enable --py widgetsnbextension\n", 34 | "```\n", 35 | "'nvidia-apex' raises an error about incompatible CUDA versions. The function to check for errors is commented out.\n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import pandas as pd\n", 45 | "from simpletransformers.classification import ClassificationModel" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "Path to folder containing data file. " 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "DATADIR = '../data/'" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "!ls $DATADIR" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "Load the train dataset." 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "data = pd.read_csv(DATADIR + 'train-evaluable.csv', header=0)\n", 87 | "data.shape" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "data.sample(10)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# Create a MultiLabelClassificationModel\n", 106 | "args={'output_dir': 'outputs/',\n", 107 | " 'cache_dir': 'cache_dir/',\n", 108 | " 'fp16': False,\n", 109 | " 'fp16_opt_level': 'O1',\n", 110 | " 'max_seq_length': 512, \n", 111 | " 'train_batch_size': 8,\n", 112 | " 'gradient_accumulation_steps': 10,\n", 113 | " 'eval_batch_size': 12,\n", 114 | " 'num_train_epochs': 10, \n", 115 | " 'weight_decay': 0,\n", 116 | " 'learning_rate': 4e-5,\n", 117 | " 'adam_epsilon': 1e-8,\n", 118 | " 'warmup_ratio': 0.06,\n", 119 | " 'warmup_steps': 0,\n", 120 | " 'max_grad_norm': 1.0,\n", 121 | " 'logging_steps': 50,\n", 122 | " 'save_steps': 2000, \n", 123 | " 'evaluate_during_training': True,\n", 124 | " 'overwrite_output_dir': True,\n", 125 | " 'reprocess_input_data': True,\n", 126 | " 'n_gpu': 2,\n", 127 | " 'use_multiprocessing': True,\n", 128 | " 'silent': False,\n", 129 | " 'threshold': 0.5,\n", 130 | " 'wandb_project': 'bert-for-radiology',\n", 131 | " \n", 132 | " # for long texts \n", 133 | " 'sliding_window': True,\n", 134 | " 'tie_value': 1}\n", 135 | "\n", 136 | "model_names= ['../models/pt-radiobert-base-german-cased/', 'bert-base-german-cased', '../models/pt-radiobert-from-scratch/', 'bert-base-multilingual-cased']" 137 | ] 138 | }, 139 | { 
140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "Training the models." 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "test = pd.read_csv(DATADIR + 'test-evaluable.csv', header=0)\n", 153 | "args[\"output_dir\"] = 'outputs/final/radbert-binary/'\n", 154 | "model = ClassificationModel('bert', '../models/pt-radiobert-base-german-cased/', args=args)\n", 155 | "model.train_model(data, eval_df = test)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "args[\"output_dir\"] = 'outputs/final/radbert-binary/'\n", 165 | "model = ClassificationModel('bert', '../models/pt-radiobert-base-german-cased/', args=args)\n", 166 | "model.train_model(data)\n", 167 | "\n", 168 | "args[\"output_dir\"] = 'outputs/final/fsbert-binary/'\n", 169 | "model = ClassificationModel('bert', '../models/pt-radiobert-from-scratch/', args=args)\n", 170 | "model.train_model(data)\n", 171 | "\n", 172 | "args[\"output_dir\"] = 'outputs/final/gerbert-binary/'\n", 173 | "model = ClassificationModel('bert', 'bert-base-german-cased', args=args)\n", 174 | "model.train_model(data)\n", 175 | "\n", 176 | "args[\"output_dir\"] = 'outputs/final/multibert-binary/'\n", 177 | "model = ClassificationModel('bert', 'bert-base-multilingual-cased', args=args)\n", 178 | "model.train_model(data)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "test = pd.read_csv(DATADIR + 'test-evaluable.csv', header=0)\n", 188 | "\n", 189 | "with open('results-binary.csv', 'w+') as f:\n", 190 | " f.write('model,' + ','.join(map(str, range(1,501))) + ',\\n')\n", 191 | "\n", 192 | "model_dirs = ['outputs/final/radbert-binary/', 'outputs/final/fsbert-binary/', 'outputs/final/gerbert-binary/', 'outputs/final/multibert-binary/']\n", 193 | "\n", 194 | "for model_dir in model_dirs:\n", 195 | " model = ClassificationModel('bert', model_dir, args=args)\n", 196 | " pred, raw = model.predict(test.text) \n", 197 | " \n", 198 | " for rep in ['outputs', 'final', '/']:\n", 199 | " model_dir=model_dir.replace(rep, '')\n", 200 | " \n", 201 | " with open('results-binary.csv', 'a') as f:\n", 202 | " f.write(model_dir + ',' + ','.join(map(str, raw)).replace('\\n', '') +'\\n') " 203 | ] 204 | } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "fast-bert", 209 | "language": "python", 210 | "name": "fast-bert" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.7.5" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 4 227 | } 228 | -------------------------------------------------------------------------------- /pretraining/notebooks/04_run-pretraining.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Run Pretraining\n", 8 | "The data is now in the desired format for pre-trainig of BERT. 
It will not be pre-trained from scratch, as the already existing German model, open-sourced by [deepset.ai](https://deepset.ai/german-bert), does not provide checkpoints. " 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "```bash\n", 16 | "conda create --name=bert-pretraining tensorflow-gpu=1.14\n", 17 | "conda activate bert-pretraining\n", 18 | "```" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Pre-training tips and caveats - from [Google](https://github.com/google-research/bert#pre-training-tips-and-caveats)\n", 26 | "\n", 27 | "> If using your own vocabulary, make sure to change vocab_size in bert_config.json. If you use a larger vocabulary without changing this, you will likely get NaNs when training on GPU or TPU due to unchecked out-of-bounds access.\n", 28 | "\n", 29 | "The vocabulary size is 30,000, as in the German BERT-Base:\n", 30 | "\n", 31 | "> If your task has a large domain-specific corpus available (e.g., \"movie reviews\" or \"scientific papers\"), it will likely be beneficial to run additional steps of pre-training on your corpus, starting from the BERT checkpoint.\n", 32 | "\n", 33 | "> The learning rate we used in the paper was 1e-4. However, if you are doing additional steps of pre-training starting from an existing BERT checkpoint, you should use a smaller learning rate (e.g., 2e-5).\n", 34 | "\n", 35 | "In our case, the learning rate was set to 2e-5. \n", 36 | "\n", 37 | "> Current BERT models are English-only, but we do plan to release a multilingual model which has been pre-trained on a lot of languages in the near future (hopefully by the end of November 2018).\n", 38 | "\n", 39 | "> Longer sequences are disproportionately expensive because attention is quadratic to the sequence length. In other words, a batch of 64 sequences of length 512 is much more expensive than a batch of 256 sequences of length 128. The fully-connected/convolutional cost is the same, but the attention cost is far greater for the 512-length sequences. Therefore, one good recipe is to pre-train for, say, 90,000 steps with a sequence length of 128 and then for 10,000 additional steps with a sequence length of 512. The very long sequences are mostly needed to learn positional embeddings, which can be learned fairly quickly. Note that this does require generating the data twice with different values of max_seq_length.\n", 40 | "\n", 41 | "> If you are pre-training from scratch, be prepared that pre-training is computationally expensive, especially on GPUs. If you are pre-training from scratch, our recommended recipe is to pre-train a BERT-Base on a single preemptible Cloud TPU v2, which takes about 2 weeks at a cost of about $500 USD (based on the pricing in October 2018). You will have to scale down the batch size when only training on a single Cloud TPU, compared to what was used in the paper. It is recommended to use the largest batch size that fits into TPU memory.\n", 42 | "\n", 43 | "However, this cannot be done in our case due to dealing with sensitive patient data. \n", 44 | "\n", 45 | "\n", 46 | "### How many steps should be used?\n", 47 | "Devlin et a. write: \n", 48 | "\n", 49 | "> We train with batch size of 256 sequences (256 sequences * 512 tokens = 128,000 tokens/batch) for 1,000,000 steps, which is approximately 40 epochs over the 3.3 billion word corpus. 
We use Adam with learning rate of 1e-4, β1 = 0.9, β2 = 0.999, L2 weight decay of 0.01, learning rate warm-up over the first 10,000 steps, and linear decay of the learning rate. \n", 50 | "\n", 51 | "Our batch size was 32 with a sequence length of 128 tokens (32 sequences * 128 tokens = 4,096 tokens/batch), whereby more was not possible with a single GPU(GTX2080ti). The resulting corpus consisted of 415,702,033 words. To achive 40 epochs, approximately 100,000 words were needed. Similar to Devlin et al, we used 1\\% warmup steps. \n", 52 | "For pretraining, files were copied from the Google Repository\n", 53 | "\n", 54 | "```\n", 55 | "run_pretraining.py\n", 56 | "modelin.py\n", 57 | "optimization.py\n", 58 | "```\n", 59 | "to the `/pretraining` folder. \n", 60 | "To execute pre-training, the required code was wrapped in a shell script. " 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "```bash\n", 68 | "#!/bin/bash\n", 69 | "\n", 70 | "# absolute path for pretraining-folder. makes it possible to call the script form every other folder\n", 71 | "cd ~/Documents/bert-for-radiology/pretraining\n", 72 | "\n", 73 | "# activate the anaconda enviroment\n", 74 | "eval \"$(conda shell.bash hook)\"\n", 75 | "conda activate bert-pretraining\n", 76 | "\n", 77 | "# give feedback\n", 78 | "for i in {10..1}\n", 79 | "do \n", 80 | "\tprintf $i\n", 81 | "\tprintf \"\\n\"\n", 82 | "\tsleep 0.5\n", 83 | "done\n", 84 | "printf \"gooo.... \\n\\n\"\n", 85 | "\n", 86 | "# run pre-training\n", 87 | "python run_pretraining.py \\\n", 88 | " --input_file=../tmp/tf_examples.tfrecord-* \\\n", 89 | " --output_dir=../tmp/pretraining_output \\\n", 90 | " --do_train=True \\\n", 91 | " --do_eval=True \\\n", 92 | " --bert_config_file=../models/bert-base-german-cased/bert_config.json \\\n", 93 | " --init_checkpoint=../models/bert-base-german-cased/bert_model.ckpt \\\n", 94 | " --train_batch_size=32 \\\n", 95 | " --max_seq_length=128 \\\n", 96 | " --max_predictions_per_seq=20 \\\n", 97 | " --num_train_steps=90000 \\\n", 98 | " --num_warmup_steps=10000 \\\n", 99 | " --learning_rate=2e-5\n", 100 | "```\n", 101 | "\n", 102 | "The shell script automatically activates the required environment:\n", 103 | "\n" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "!bash run_pretraining.sh" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Evaluate\n", 120 | "Call from other enviroment. 
\n", 121 | "\n", 122 | "```bash\n", 123 | "python run_pretraining.py \\\n", 124 | " --input_file=../tmp/tf_examples.tfrecord-* \\\n", 125 | " --output_dir=../tmp/pretraining_output \\\n", 126 | " --do_train=False \\\n", 127 | " --do_eval=True \\\n", 128 | " --do_predict=True \\\n", 129 | " --bert_config_file=../models/bert-base-german-cased/bert_config.json \\\n", 130 | " --init_checkpoint=../models/bert-base-german-cased/bert_model.ckpt \\\n", 131 | " --max_predictions_per_seq=20 \n", 132 | "```" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "bert-pretraining", 146 | "language": "python", 147 | "name": "bert-pretraining" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.6.9" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 4 164 | } 165 | -------------------------------------------------------------------------------- /pretraining/bert-vocab-builder/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A simple invertible tokenizer. 17 | 18 | Converts from a unicode string to a list of tokens 19 | (represented as Unicode strings). 20 | 21 | This tokenizer has the following desirable properties: 22 | - It is invertible. 23 | - Alphanumeric characters are broken away from non-alphanumeric characters. 24 | - A single space between words does not produce an extra token. 25 | - The full Unicode punctuation and separator set is recognized. 26 | 27 | The tokenization algorithm is as follows: 28 | 29 | 1. Split the text into a list of tokens, splitting at every boundary of an 30 | alphanumeric character and a non-alphanumeric character. This produces 31 | a list which alternates between "alphanumeric tokens" 32 | (strings of alphanumeric characters) and "non-alphanumeric tokens" 33 | (strings of non-alphanumeric characters). 34 | 35 | 2. Remove every token consisting of a single space, unless it is 36 | the very first or very last token in the list. These tokens are now 37 | implied by the fact that there are two adjacent alphanumeric tokens. 38 | 39 | e.g. u"Dude - that's so cool." 
40 | -> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."] 41 | """ 42 | 43 | from __future__ import absolute_import 44 | from __future__ import division 45 | from __future__ import print_function 46 | 47 | import collections 48 | import sys 49 | import unicodedata 50 | import six 51 | from six.moves import range # pylint: disable=redefined-builtin 52 | # from tensor2tensor.utils import mlperf_log 53 | import tensorflow as tf 54 | import time 55 | 56 | # Conversion between Unicode and UTF-8, if required (on Python2) 57 | _native_to_unicode = (lambda s: s.decode("utf-8")) if six.PY2 else (lambda s: s) 58 | 59 | 60 | # This set contains all letter and number characters. 61 | _ALPHANUMERIC_CHAR_SET = set( 62 | six.unichr(i) for i in range(sys.maxunicode) 63 | if (unicodedata.category(six.unichr(i)).startswith("L") or 64 | unicodedata.category(six.unichr(i)).startswith("N") or 65 | unicodedata.category(six.unichr(i)).startswith("P"))) 66 | # unicodedata.category(six.unichr(i)).startswith("S") 67 | 68 | 69 | def encode(text): 70 | """Encode a unicode string as a list of tokens. 71 | 72 | Args: 73 | text: a unicode string 74 | Returns: 75 | a list of tokens as Unicode strings 76 | """ 77 | if not text: 78 | return [] 79 | ret = [] 80 | token_start = 0 81 | # Classify each character in the input string 82 | is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text] 83 | add_remaining = False 84 | for pos in range(1, len(text)): 85 | add_remaining = False 86 | if is_alnum[pos] != is_alnum[pos - 1]: 87 | if not is_alnum[pos]: 88 | token = text[token_start:pos] 89 | if token != u" " or token_start == 0: 90 | add_remaining = False 91 | ret.append(token) 92 | else: 93 | add_remaining = True 94 | token_start = pos 95 | 96 | final_token = text[token_start:] if text[-1] in _ALPHANUMERIC_CHAR_SET else text[token_start:-1] 97 | if add_remaining: 98 | ret.append(final_token) 99 | return ret 100 | 101 | 102 | def decode(tokens): 103 | """Decode a list of tokens to a unicode string. 104 | 105 | Args: 106 | tokens: a list of Unicode strings 107 | Returns: 108 | a unicode string 109 | """ 110 | token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] 111 | ret = [] 112 | for i, token in enumerate(tokens): 113 | if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: 114 | ret.append(u" ") 115 | ret.append(token) 116 | return "".join(ret) 117 | 118 | 119 | def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True): 120 | """Reads files matching a wildcard pattern, yielding the contents. 121 | 122 | Args: 123 | filepattern: A wildcard pattern matching one or more files. 124 | max_lines: If set, stop reading after reading this many lines. 125 | split_on_newlines: A boolean. If true, then split files by lines and strip 126 | leading and trailing whitespace from each line. Otherwise, treat each 127 | file as a single string. 128 | 129 | Yields: 130 | The contents of the files as lines, if split_on_newlines is True, or 131 | the entire contents of each file if False. 
132 | """ 133 | filenames = sorted(tf.io.gfile.glob(filepattern)) 134 | print(filenames) 135 | lines_read = 0 136 | for filename in filenames: 137 | start = time.time() 138 | with tf.gfile.Open(filename) as f: 139 | if split_on_newlines: 140 | for line in f: 141 | yield line.strip() 142 | lines_read += 1 143 | if max_lines and lines_read >= max_lines: 144 | return 145 | if lines_read % 100000 == 0: 146 | print("read", lines_read, "lines,", time.time() - start, "secs elapsed") 147 | 148 | else: 149 | if max_lines: 150 | doc = [] 151 | for line in f: 152 | doc.append(line) 153 | lines_read += 1 154 | if max_lines and lines_read >= max_lines: 155 | yield "".join(doc) 156 | return 157 | yield "".join(doc) 158 | 159 | else: 160 | yield f.read() 161 | 162 | print(time.time() - start, "for reading read file :", filename) 163 | 164 | 165 | def corpus_token_counts( 166 | text_filepattern, corpus_max_lines, split_on_newlines=True, additional_chars=""): 167 | """Read the corpus and compute a dictionary of token counts. 168 | 169 | Args: 170 | text_filepattern: A pattern matching one or more files. 171 | corpus_max_lines: An integer; maximum total lines to read. 172 | split_on_newlines: A boolean. If true, then split files by lines and strip 173 | leading and trailing whitespace from each line. Otherwise, treat each 174 | file as a single string. 175 | additional_chars: A String. Each consisting characters will be treat as normal 176 | alphabets so that they will be included in each vocab. 177 | 178 | Returns: 179 | a dictionary mapping token to count. 180 | """ 181 | if additional_chars: 182 | _ALPHANUMERIC_CHAR_SET.add(additional_chars) 183 | 184 | counts = collections.Counter() 185 | for doc in _read_filepattern( 186 | text_filepattern, 187 | max_lines=corpus_max_lines, 188 | split_on_newlines=split_on_newlines): 189 | counts.update(encode(_native_to_unicode(doc))) 190 | print("read all files") 191 | return counts 192 | 193 | 194 | def vocab_token_counts(text_filepattern, max_lines): 195 | """Read a vocab file and return a dictionary of token counts. 196 | 197 | Reads a two-column CSV file of tokens and their frequency in a dataset. The 198 | tokens are presumed to be generated by encode() or the equivalent. 199 | 200 | Args: 201 | text_filepattern: A pattern matching one or more files. 202 | max_lines: An integer; maximum total lines to read. 203 | 204 | Returns: 205 | a dictionary mapping token to count. 206 | """ 207 | ret = {} 208 | for i, line in enumerate( 209 | _read_filepattern(text_filepattern, max_lines=max_lines)): 210 | if "," not in line: 211 | tf.logging.warning("Malformed vocab line #%d '%s'", i, line) 212 | continue 213 | 214 | token, count = line.rsplit(",", 1) 215 | ret[_native_to_unicode(token)] = int(count) 216 | 217 | return ret 218 | -------------------------------------------------------------------------------- /rule-based-classification/rule-based-algorithm.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Constructing a Rule Based Algorithm to Classify Radiology Text Reports" 3 | author: Keno Bressem 4 | output: github_document 5 | --- 6 | In order to ensure the most accurate possible comparability of the rule-based algorithm developed here, the entire algorithm is developed only on the annotated data sets. 7 | 8 | ### preliminary considerations 9 | Theoretically, it should be possible to create a very accurate (even more accurate then BERT based algorithms) rule-based algorithm for the classification of text-reports. 
However, this poses some difficulties. Exact knowledge of the text structure as well as of the findings structure is necessary to recognize pitfalls.
10 |
11 | #### Unclear statements
12 | Radiologists do not always express themselves clearly, but often use blurred formulations which cannot always be assigned to a clear statement.
13 |
14 | **Example**:
15 |
16 | > "Maskierte Infiltrate in den basalen, minderbelüfteten Arealen möglich"
17 |
18 | Which roughly translates to:
19 |
20 | > "Masked infiltrates possible in basal, poorly ventilated areas."
21 |
22 | This wording, which is often found in reports of ICU chest X-rays, expresses either the radiologist's uncertainty (he or she suspects infiltrates as the cause of the patient's elevated infection levels but cannot clearly delineate them because the lungs are not sufficiently ventilated), OR, as clinical practice unfortunately often shows, that he or she simply does not want to make a specific decision and possibly falsely rule out pneumonia, even though the X-ray does not actually show any clear indication of pneumonia.
23 | Clinical practice shows that the above formulation rather means that no infiltrates can be distinguished in the image.
24 |
25 | #### Insufficient and therefore unevaluable reports
26 | Findings are considered unevaluable if they do not contain sufficient information, e.g. if no explicit statement is made about the location of the foreign materials, but these are only described as "no change of foreign materials" (German: "Fremdmaterial idem"). However, the word "idem" should not always mark a report as unevaluable, as it can also be used if no change occurred.
27 | This is particularly tricky and no simple rule can be established, as even perfectly sufficient report texts might contain those expressions.
28 |
29 | **Examples**:
30 |
31 | > Kein Pneumothorax nach Drainagenanlage. Darüber hinaus kein Befundwandel
32 |
33 | > No pneumothorax after inserting a thoracic drain. No further change.
34 |
35 | Not evaluable.
36 |
37 | > Herz nicht verbreitert. Keine Stauung. Kein Erguss. Kein Infiltrat. Fremdmaterial idem.
38 |
39 | > Heart not enlarged. No congestion. No effusion. No infiltrate. Foreign material idem.
40 |
41 | Not evaluable. No statement regarding pneumothorax; still, it can be assumed that there is none, since such an important but rare finding is probably always reported if present but often not reported if not present. However, one cannot draw a conclusion on therapy aids such as catheters, drains etc.
42 |
43 | > Herz nicht verbreitert. Keine Stauung. Kein Erguss. Kein Infiltrat. Fremdmaterial idem (ZVK, TK).
44 |
45 | > Heart not enlarged. No congestion. No effusion. No infiltrate. Foreign material idem (central venous catheter, tracheal cannula).
46 |
47 | Evaluable. The foreign materials are named explicitly, but this does not occur very often.
48 |
49 | > Herz nicht verbreitert. Keine Stauung. Kein Erguss. Kein Infiltrat. Neuer ZVK von rechts, kein Pneumothorax. Sonst Fremdmaterial idem.
50 |
51 | > Heart not enlarged. No congestion. No effusion. No infiltrate. New central venous catheter from the right, no pneumothorax. Otherwise foreign material idem.
52 |
53 | Not evaluable, as there is no information about the other, unchanged foreign bodies.
54 | 55 | #### Unlikely differential diagnoses 56 | If a differntial diagnosis is mentioned, it is difficult to exclude this string: 57 | 58 | **Example**: 59 | 60 | > Rechts zentral betonte flächige Verschattung, vereinbar mit pneumonischer Infiltration DD bei vermehrter Gefäßzeichnung stauungsbedingt 61 | 62 | > Right central shading, compatible with pneumatic infiltration DD with increased vascular drawing due to congestion 63 | 64 | This is probably no congestion (congestion should be symmetrical) and should thus be annotated so. 65 | 66 | 67 | ```{r} 68 | library(tidyverse) 69 | library(magrittr) 70 | ``` 71 | 72 | ### Loading the text-files stored as single R file: 73 | ```{r} 74 | textreports <- read_csv("/media/bressekk/0B36170D0B36170D/Textdaten_csv/all_data.csv") 75 | ``` 76 | 77 | 78 | ### Removal of unevaluable reports 79 | Since the annotations already contained whether a finding was useful or not, this was done in advance. In the Final Model, however, this will be built in. 80 | 81 | Removal by fixed strings: 82 | ```{r} 83 | unevaluable = "Fremdmaterial idem|kurzfristigen Verlauf kein Befundwandel" 84 | textreports %<>% filter(!str_detect(text, unevaluable)) 85 | ``` 86 | 87 | ### Set up dataframe for annotations 88 | ```{r} 89 | set.seed(081219) 90 | train <- sample(1:nrow(textreports), round(0.8*nrow(textreports)), F) 91 | 92 | 93 | annotations <- tibble(stauung = rep(NA, nrow(textreports)), 94 | erguss = NA, 95 | verschattung = NA, 96 | pneumothorax = NA, 97 | zvk = NA, 98 | thoraxdrainage = NA, 99 | magensonde = NA, 100 | tubus = NA, 101 | fehllage = NA) 102 | ``` 103 | 104 | ### Congestion (German: Stauung) 105 | 106 | If the finding should not be mentioned, it is assumed the finding is not present (reasonable as unevaluable reports have previously been excluded) 107 | 108 | ```{r} 109 | congestion_names <- "stauung|dekompensat|volumen|flüssigkeitseinlag|gestaut|ödem" 110 | 111 | annotations$stauung <- ifelse( 112 | str_detect( 113 | str_to_lower(textreports$text), 114 | congestion_names), 115 | annotations$stauung, 116 | 0) 117 | ``` 118 | 119 | The findings are all printed out in the terminal, the text is then manually evaluated and the strings are then copied to the vectors for `negations` or `positive_finding`. 120 | ```{r message=FALSE, warning=FALSE} 121 | negations <- read_csv("negative_congestion.csv", col_names = F) %>% 122 | select("X1") %>% 123 | unlist() %>% 124 | str_to_lower() 125 | 126 | positive_finding <- read_csv("positive_congestion.csv", col_names = F) %>% 127 | select("X1") %>% 128 | unlist() %>% 129 | str_to_lower() 130 | ``` 131 | 132 | Positve findings should be evaluated first. "Pulmonale Stauung" and "Keine pulmonale Stauung" will be rated as positive, but the second then be again labeled as negative in the next loop. 
133 | ```{r message=FALSE, warning=FALSE} 134 | for (str in positive_finding) { 135 | annotations[train,]$stauung <- ifelse( 136 | str_detect( 137 | str_to_lower(textreports[train,]$text), 138 | str), 139 | 1, 140 | annotations[train,]$stauung) 141 | } 142 | 143 | for (str in negations) { 144 | annotations[train,]$stauung <- ifelse( 145 | str_detect( 146 | str_to_lower(textreports[train,]$text), 147 | str), 148 | 0, 149 | annotations[train,]$stauung) 150 | } 151 | ``` 152 | 153 | Evaluate accuracy of rule based algorithm on training data 154 | ```{r} 155 | mean(annotations[train,]$stauung == textreports[train, ]$Stauung, na.rm = T) 156 | ``` 157 | missed annotations on training data 158 | ```{r} 159 | mean(is.na(annotations[train,]$stauung)) 160 | ``` 161 | accuray with missed set to 0/FALSE 162 | ```{r} 163 | annotations[train,]$stauung <- ifelse(is.na(annotations[train,]$stauung), 164 | 0, 165 | annotations[train,]$stauung) 166 | 167 | mean(annotations[train,]$stauung == textreports[train, ]$Stauung, na.rm = T) 168 | ``` 169 | 170 | Evaluate accuracy of rule based algorithm on test data 171 | ```{r} 172 | for (str in positive_finding) { 173 | annotations[-train,]$stauung <- ifelse( 174 | str_detect( 175 | str_to_lower( 176 | textreports[-train,]$text), str), 177 | 1, 178 | annotations[-train,]$stauung) } 179 | 180 | for (str in negations) { 181 | annotations[-train,]$stauung <- ifelse( 182 | str_detect( 183 | str_to_lower( 184 | textreports[-train,]$text), str), 185 | 0, 186 | annotations[-train,]$stauung) } 187 | 188 | mean(annotations[-train,]$stauung == textreports[-train, ]$Stauung, na.rm = T) 189 | ``` 190 | 191 | missed annotations on test data 192 | ```{r} 193 | mean(is.na(annotations[-train,]$stauung)) 194 | ``` 195 | 196 | accuray with missed set to 0/FALSE 197 | ```{r} 198 | annotations[-train,]$stauung <- ifelse(is.na(annotations[-train,]$stauung), 199 | 1, 200 | annotations[-train,]$stauung) 201 | 202 | mean(annotations[-train,]$stauung == textreports[-train, ]$Stauung, na.rm = T) 203 | ``` 204 | 205 | 206 | -------------------------------------------------------------------------------- /finetuning/02_multilabel-classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# BERT Fine-tuning on Report Texts\n", 8 | "\n", 9 | "For fine-tuning, the 'simpletransformers' library is used, as it only requires a few lines of code for the training. The library can be downloaded from Github via: https://github.com/ThilinaRajapakse/simpletransformers. \n", 10 | "\n", 11 | "## Creating the enviroment\n", 12 | "\n", 13 | "\n", 14 | "```bash\n", 15 | "conda create --name=finetuning \n", 16 | "conda install tensorflow-gpu pytorch scikit-learn\n", 17 | "\n", 18 | "cd transformers \n", 19 | "pip install .\n", 20 | "\n", 21 | "git clone https://github.com/ThilinaRajapakse/simpletransformers\n", 22 | "cd simpletransformers\n", 23 | "pip install .\n", 24 | "\n", 25 | "git clone https://github.com/NVIDIA/apex\n", 26 | "cd apex\n", 27 | "pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n", 28 | "\n", 29 | "conda install ipykernel pandas\n", 30 | "ipython kernel install --user --name=finetuning\n", 31 | "\n", 32 | "pip install python-box ipywidgets\n", 33 | "jupyter nbextension enable --py widgetsnbextension\n", 34 | "```\n", 35 | "'nvidia-apex' raises an error about incompatible CUDA versions. 
The function to check for errors was commented out. \n", 36 | "\n", 37 | "\n", 38 | "\n", 39 | "## Data preparation\n", 40 | "Before fine-tuning, the data needs to be pre-processed. The 'simpletransformers' library requires the following data-structure:\n", 41 | "\n", 42 | "| text | labels |\n", 43 | "|------|--------|\n", 44 | "| 'some text for finetuning' | \\[1,0,0,1,1,0,1] |\n", 45 | "| 'some more texts for finetuning' | \\[0,0,0,0,0,1,1] |\n", 46 | "| 'even more texts for finetuning' | \\[0,1,0,1,0,0,1] |" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import pandas as pd\n", 56 | "from sklearn.model_selection import train_test_split\n", 57 | "from simpletransformers.classification import MultiLabelClassificationModel\n", 58 | "from statistics import mean, median, stdev" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Path to folder containing data file. " 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "DATADIR = '../data/'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "!ls $DATADIR" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "Load the test data-set and bring it into the desired format (as specified above). " 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "to_drop=['Filename', 'Annotator', 'Confidence']\n", 100 | "data = pd.read_csv(DATADIR + 'train.csv', header=0).drop(to_drop,axis=1)\n", 101 | "labels=['Stauung','Verschattung','Erguss','Pneumothorax','Thoraxdrainage','ZVK','Magensonde','Tubus','Materialfehllage']\n", 102 | "data['labels']=data[labels].values.tolist()\n", 103 | "data=data.drop(labels, axis = 1)\n", 104 | "data.shape" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "data.sample(10)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "test = pd.read_csv(DATADIR + 'test.csv', header=0).drop(to_drop,axis=1)\n", 123 | "test['labels']=test[labels].values.tolist()\n", 124 | "test=test.drop(labels, axis = 1)\n", 125 | "test.sample(10)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "Define functions for performance measurements. 
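The helper functions themselves are not shown in this notebook; the following is a minimal sketch of what they could look like, assuming `scikit-learn` is available and that predictions and ground-truth labels are passed as nested 0/1 lists or arrays (all function and variable names here are illustrative, not part of the original code).

```python
import numpy as np
from sklearn.metrics import f1_score, matthews_corrcoef

def per_label_f1(y_true, y_pred, label_names):
    """Return one F1 score per label for multilabel 0/1 arrays."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return dict(zip(label_names, f1_score(y_true, y_pred, average=None)))

def per_label_mcc(y_true, y_pred, label_names):
    """Return one Matthews correlation coefficient per label."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return {name: matthews_corrcoef(y_true[:, i], y_pred[:, i])
            for i, name in enumerate(label_names)}

# Hypothetical usage after prediction, e.g.:
# per_label_f1(list(test.labels), pred, labels)
```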
" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# Create a MultiLabelClassificationModel\n", 142 | "args={'output_dir': 'outputs/',\n", 143 | " 'cache_dir': 'cache_dir/',\n", 144 | " 'fp16': False,\n", 145 | " 'fp16_opt_level': 'O1',\n", 146 | " 'max_seq_length': 512, \n", 147 | " 'train_batch_size': 8,\n", 148 | " 'gradient_accumulation_steps': 1,\n", 149 | " 'eval_batch_size': 12,\n", 150 | " 'num_train_epochs': 4, \n", 151 | " 'weight_decay': 0,\n", 152 | " 'learning_rate': 4e-5,\n", 153 | " 'adam_epsilon': 1e-8,\n", 154 | " 'warmup_ratio': 0.06,\n", 155 | " 'warmup_steps': 0,\n", 156 | " 'max_grad_norm': 1.0,\n", 157 | " 'logging_steps': 50,\n", 158 | " 'save_steps': 2000, \n", 159 | " 'evaluate_during_training': False,\n", 160 | " 'overwrite_output_dir': True,\n", 161 | " 'reprocess_input_data': True,\n", 162 | " 'n_gpu': 2,\n", 163 | " 'use_multiprocessing': True,\n", 164 | " 'silent': False,\n", 165 | " 'threshold': 0.5,\n", 166 | " \n", 167 | " # for long texts \n", 168 | " 'sliding_window': True,\n", 169 | " 'tie_value': 1}\n", 170 | "\n", 171 | "model_names= ['../models/pt-radiobert-base-german-cased/', 'bert-base-german-cased', '../models/pt-radiobert-from-scratch/', 'bert-base-multilingual-cased']" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "with open(\"results.csv\", 'w+') as f:\n", 181 | " f.write('train_size,model,' + ','.join(map(str, range(1,501))) + ',\\n')\n", 182 | "\n", 183 | "for i in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2250, 2500, 2750, 3000, 3250, 3500, 3750, 4000, 4500]:\n", 184 | " train = data.sample(i)\n", 185 | " for model_name in model_names:\n", 186 | "\n", 187 | " model = MultiLabelClassificationModel('bert', model_name, num_labels=9, args=args)\n", 188 | " model.train_model(train)\n", 189 | " result, model_outputs, wrong_predictions = model.eval_model(test)\n", 190 | " pred, raw = model.predict(test.text) \n", 191 | " \n", 192 | " for rep in ['models', '..', \"/\"]:\n", 193 | " model_name=model_name.replace(rep, '')\n", 194 | " \n", 195 | " with open('results.csv', 'a') as f:\n", 196 | " f.write(str(train.shape[0]) + ',' + model_name + ',' + ','.join(map(str, raw)).replace('\\n', '') +'\\n') \n", 197 | " \n", 198 | " !git add *\n", 199 | " !git commit -m \"update accuracy\"\n", 200 | " !git push" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "Saving the final models, trained on the whole train data-set. 
" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "out_dirs = ['outputs/final/radbert/', 'outputs/final/gerbert/', 'outputs/final/fsbert/', 'outputs/final/multibert/']\n", 217 | "\n", 218 | "for i in range(3,4):\n", 219 | " args[\"output_dir\"] = out_dirs[i]\n", 220 | " model = MultiLabelClassificationModel('bert', model_names[i], args=args, num_labels=9)\n", 221 | " model.train_model(data)\n", 222 | " pred, raw = model.predict(test.text) \n", 223 | " \n", 224 | " with open('results.csv', 'a') as f:\n", 225 | " f.write(str(data.shape[0]) + ',' + model_names[i] + ',' + ','.join(map(str, raw)).replace('\\n', '') +'\\n') " 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "# Evaluation on long texts" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "On short report texts, our model does only outperform the standard german BERT model or the multilingual BERT model if the training-set for fine-tuning is very small. On larger train-sizes the value of pretraining is low. \n", 240 | "However, as the vocabulary of the models differs significantly, we believe, that our model will perform better on longer report texts, due to a more efficient tokenization e.g. in the context of text reports for computed tomography. " 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "test = pd.read_csv(DATADIR + 'ct.csv', header=0).drop(to_drop,axis=1)\n", 250 | "test['labels']=test[labels].values.tolist()\n", 251 | "test=test.drop(labels, axis = 1)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "from simpletransformers.classification import ClassificationModel\n", 261 | "\n", 262 | "with open('results-long-text.csv', 'w+') as f:\n", 263 | " f.write('train_size,model,' + ','.join(map(str, range(1,165))) + ',\\n')\n", 264 | "\n", 265 | "model_dirs = ['outputs/final/radbert/', 'outputs/final/fsbert/', 'outputs/final/gerbert/', 'outputs/final/multibert/']\n", 266 | "\n", 267 | "for model_dir in model_dirs:\n", 268 | " model = ClassificationModel('bert', model_dir, args=args)\n", 269 | " pred, raw = model.predict(test.text) \n", 270 | " \n", 271 | " for rep in ['outputs', 'final', '/']:\n", 272 | " model_dir=model_dir.replace(rep, '')\n", 273 | " \n", 274 | " with open('results-long-text.csv', 'a') as f:\n", 275 | " f.write(str(4000) + ',' + model_dir + ',' + ','.join(map(str, raw)).replace('\\n', '') +'\\n') " 276 | ] 277 | } 278 | ], 279 | "metadata": { 280 | "kernelspec": { 281 | "display_name": "fast-bert", 282 | "language": "python", 283 | "name": "fast-bert" 284 | }, 285 | "language_info": { 286 | "codemirror_mode": { 287 | "name": "ipython", 288 | "version": 3 289 | }, 290 | "file_extension": ".py", 291 | "mimetype": "text/x-python", 292 | "name": "python", 293 | "nbconvert_exporter": "python", 294 | "pygments_lexer": "ipython3", 295 | "version": "3.7.5" 296 | } 297 | }, 298 | "nbformat": 4, 299 | "nbformat_minor": 4 300 | } 301 | -------------------------------------------------------------------------------- /rule-based-classification/rule-based-algorithm.md: -------------------------------------------------------------------------------- 1 | Constructing a Rule-Based Algorithm to Classify Radiology Text 
Reports 2 | ================ 3 | 4 | In order to ensure the best possible comparability of the 5 | rule-based algorithm developed here, the entire algorithm is developed 6 | only on the annotated data sets. 7 | 8 | ### Preliminary considerations 9 | 10 | Theoretically, it should be possible to create a very accurate (even 11 | more accurate than BERT-based algorithms) rule-based algorithm for the 12 | classification of text reports. However, this is challenging. 13 | Accurate knowledge of the text structure as well as the findings 14 | structure is necessary to recognize pitfalls. 15 | 16 | #### Unclear statements 17 | 18 | Radiological text reports are not always clearly formulated and sometimes 19 | leave room for interpretation. 20 | 21 | **Example**: 22 | 23 | > “Maskierte Infiltrate in den basalen, minderbelüfteten Arealen 24 | > möglich” 25 | 26 | Roughly translating into: 27 | 28 | > “Masked infiltrates possible in basal, poorly ventilated areas.” 29 | 30 | This wording, which is often found in reports of intensive care chest X-rays, 31 | expresses either the radiologist's uncertainty about 32 | infiltrates as a cause of the patient's elevated infection levels due to insufficient 33 | ventilation of the lungs OR that she/he does not wish to make a clear decision in ruling out pneumonia, even though the X-ray does not actually show any clear indications for its presence. 34 | Clinical practice shows that the above-referenced formulation rather indicates that no clear signs of pneumonia can be seen in the chest X-ray. 35 | 36 | #### Insufficient and therefore unevaluable reports 37 | 38 | Findings are considered to be unevaluable if they do not contain 39 | sufficient information, e.g. if no explicit statement is made about the 40 | location of the foreign materials, but these are instead described as “no 41 | change of foreign materials” (German: “Fremdmaterial idem”). However, 42 | the word “idem” should not always mark a report as unevaluable, as it 43 | can also be used if no change occurred. This is particularly tricky and no 44 | simple rule can be established, as even perfectly sufficient report texts 45 | might contain this expression. 46 | 47 | **Examples**: 48 | 49 | > Kein Pneumothorax nach Drainagenanlage. Darüber hinaus kein 50 | > Befundwandel 51 | 52 | > No pneumothorax after inserting a chest tube. No further change. 53 | 54 | Not evaluable. 55 | 56 | > Herz nicht verbreitert. Keine Stauung. Kein Erguss. Kein Infiltrat. 57 | > Fremdmaterial idem. 58 | 59 | > Heart not enlarged. No congestion. No effusion. No infiltrate. Foreign 60 | > material idem. 61 | 62 | Not evaluable. No statement is made regarding pneumothorax; still, it can be 63 | assumed that there is none, since such an important but rare finding is 64 | probably always reported if present, but often not reported if not. However, one cannot draw a conclusion regarding therapy aids such as catheters, drains etc. 65 | 66 | > Herz nicht verbreitert. Keine Stauung. Kein Erguss. Kein Infiltrat. 67 | > Fremdmaterial idem (ZVK, TK). 68 | 69 | > Heart not enlarged. No congestion. No effusion. No infiltrate. Foreign 70 | > material idem (central venous catheter, tracheal cannula). 71 | 72 | Evaluable. The foreign materials are named explicitly, but this is not always the case. 73 | 74 | > Herz nicht verbreitert. Keine Stauung. Kein Erguss. Kein Infiltrat. 75 | > Neuer ZVK von rechts, kein Pneumothorax. Sonst Fremdmaterial idem. 76 | 77 | > Heart not enlarged. No congestion. No effusion. No infiltrate. 
Foreign 78 | > material idem (zentral venous catheter, tracheal cannula). 79 | 80 | Not evaluable, as there is no information about the other unchanged 81 | materials. 82 | 83 | #### Unlikely differential diagnoses 84 | 85 | If a differential diagnosis is mentioned, it is difficult to exclude this 86 | string: 87 | 88 | **Example**: 89 | 90 | > Rechts zentral betonte flächige Verschattung, vereinbar mit 91 | > pneumonischer Infiltration DD bei vermehrter Gefäßzeichnung 92 | > stauungsbedingt 93 | 94 | > Right central shading, compatible with pneumatic infiltration DD with 95 | > increased vascular drawing due to congestion 96 | 97 | In this case, there is probably no congestion (congestion should be symmetrical) and 98 | it should thus be annotated accordingly. 99 | 100 | 101 | ``` r 102 | library(tidyverse) 103 | ``` 104 | 105 | ## ── Attaching packages ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.2.1 ── 106 | 107 | ## ✔ ggplot2 3.2.1 ✔ purrr 0.3.2 108 | ## ✔ tibble 2.1.3 ✔ dplyr 0.8.3 109 | ## ✔ tidyr 1.0.0 ✔ stringr 1.4.0 110 | ## ✔ readr 1.3.1 ✔ forcats 0.4.0 111 | 112 | ## ── Conflicts ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ── 113 | ## ✖ dplyr::filter() masks stats::filter() 114 | ## ✖ dplyr::lag() masks stats::lag() 115 | 116 | ``` r 117 | library(magrittr) 118 | ``` 119 | 120 | ## 121 | ## Attaching package: 'magrittr' 122 | 123 | ## The following object is masked from 'package:purrr': 124 | ## 125 | ## set_names 126 | 127 | ## The following object is masked from 'package:tidyr': 128 | ## 129 | ## extract 130 | 131 | ### Loading the text files stored as single R file: 132 | 133 | ``` r 134 | textreports <- read_csv("/media/bressekk/0B36170D0B36170D/Textdaten_csv/all_data.csv") 135 | ``` 136 | 137 | ## Parsed with column specification: 138 | ## cols( 139 | ## Filename = col_character(), 140 | ## Stauung = col_double(), 141 | ## Verschattung = col_double(), 142 | ## Erguss = col_double(), 143 | ## Pneumothorax = col_double(), 144 | ## Thoraxdrainage = col_double(), 145 | ## ZVK = col_double(), 146 | ## Magensonde = col_double(), 147 | ## Tubus = col_double(), 148 | ## Materialfehllage = col_double(), 149 | ## Annotator = col_character(), 150 | ## Confidence = col_character(), 151 | ## text = col_character(), 152 | ## DICOM_path = col_character() 153 | ## ) 154 | 155 | ### Removal of unevaluable reports 156 | 157 | Since the annotations already contain if a finding was useful or 158 | not, this was done in advance. In the final model, however, this will be 159 | built in. 
160 | 161 | Removal by fixed 162 | strings: 163 | 164 | ``` r 165 | unevaluable = "Fremdmaterial idem|kurzfristigen Verlauf kein Befundwandel" 166 | textreports %<>% filter(!str_detect(text, unevaluable)) 167 | ``` 168 | 169 | ### Set up dataframe for annotations 170 | 171 | ``` r 172 | set.seed(081219) 173 | train <- sample(1:nrow(textreports), round(0.8*nrow(textreports)), F) 174 | 175 | 176 | annotations <- tibble(stauung = rep(NA, nrow(textreports)), 177 | erguss = NA, 178 | verschattung = NA, 179 | pneumothorax = NA, 180 | zvk = NA, 181 | thoraxdrainage = NA, 182 | magensonde = NA, 183 | tubus = NA, 184 | fehllage = NA) 185 | ``` 186 | 187 | ### Congestion (German: Stauung) 188 | 189 | In case the finding is not mentioned, it is assumed that it is not 190 | present (reasonable as unevaluable reports have previously been 191 | excluded. 192 | 193 | ``` r 194 | congestion_names <- "stauung|dekompensat|volumen|flüssigkeitseinlag|gestaut|ödem" 195 | 196 | annotations$stauung <- ifelse( 197 | str_detect( 198 | str_to_lower(textreports$text), 199 | congestion_names), 200 | annotations$stauung, 201 | 0) 202 | ``` 203 | 204 | The findings are all printed out in the terminal, the text is 205 | manually evaluated and the strings are then copied to the vectors for 206 | `negations` or `positive_finding`. 207 | 208 | ``` r 209 | negations <- read_csv("negative_congestion.csv", col_names = F) %>% 210 | select("X1") %>% 211 | unlist() %>% 212 | str_to_lower() 213 | 214 | positive_finding <- read_csv("positive_congestion.csv", col_names = F) %>% 215 | select("X1") %>% 216 | unlist() %>% 217 | str_to_lower() 218 | ``` 219 | 220 | Positive findings should be evaluated first. “congestion” and 221 | “no congestion” will be rated as positive initially, but the second 222 | finding will then be labeled as negative in the next loop. 
223 | 224 | ``` r 225 | for (str in positive_finding) { 226 | annotations[train,]$stauung <- ifelse( 227 | str_detect( 228 | str_to_lower(textreports[train,]$text), 229 | str), 230 | 1, 231 | annotations[train,]$stauung) 232 | } 233 | 234 | for (str in negations) { 235 | annotations[train,]$stauung <- ifelse( 236 | str_detect( 237 | str_to_lower(textreports[train,]$text), 238 | str), 239 | 0, 240 | annotations[train,]$stauung) 241 | } 242 | ``` 243 | 244 | Evaluate accuracy of rule-based algorithm on training 245 | data 246 | 247 | ``` r 248 | mean(annotations[train,]$stauung == textreports[train, ]$Stauung, na.rm = T) 249 | ``` 250 | 251 | ## [1] 0.9010582 252 | 253 | Missed annotations on training data 254 | 255 | ``` r 256 | mean(is.na(annotations[train,]$stauung)) 257 | ``` 258 | 259 | ## [1] 0.012023 260 | 261 | Accuray with missed set to 262 | 0/FALSE 263 | 264 | ``` r 265 | annotations[train,]$stauung <- ifelse(is.na(annotations[train,]$stauung), 266 | 0, 267 | annotations[train,]$stauung) 268 | 269 | mean(annotations[train,]$stauung == textreports[train, ]$Stauung, na.rm = T) 270 | ``` 271 | 272 | ## [1] 0.9012023 273 | 274 | Evaluate accuracy of rule based algorithm on test data 275 | 276 | ``` r 277 | for (str in positive_finding) { 278 | annotations[-train,]$stauung <- ifelse( 279 | str_detect( 280 | str_to_lower( 281 | textreports[-train,]$text), str), 282 | 1, 283 | annotations[-train,]$stauung) } 284 | 285 | for (str in negations) { 286 | annotations[-train,]$stauung <- ifelse( 287 | str_detect( 288 | str_to_lower( 289 | textreports[-train,]$text), str), 290 | 0, 291 | annotations[-train,]$stauung) } 292 | 293 | mean(annotations[-train,]$stauung == textreports[-train, ]$Stauung, na.rm = T) 294 | ``` 295 | 296 | ## [1] 0.9304933 297 | 298 | Missed annotations on test data 299 | 300 | ``` r 301 | mean(is.na(annotations[-train,]$stauung)) 302 | ``` 303 | 304 | ## [1] 0.06694561 305 | 306 | Accuracy with missed annotations set to 307 | 0/FALSE 308 | 309 | ``` r 310 | annotations[-train,]$stauung <- ifelse(is.na(annotations[-train,]$stauung), 311 | 1, 312 | annotations[-train,]$stauung) 313 | 314 | mean(annotations[-train,]$stauung == textreports[-train, ]$Stauung, na.rm = T) 315 | ``` 316 | 317 | ## [1] 0.916318 318 | -------------------------------------------------------------------------------- /pretraining/notebooks/01_sentencizing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Preprocessing Report Texts with Spacy\n", 8 | "link to [spaCy](https://spacy.io/models/de#de_trf_bertbasecased_lg)\n", 9 | "\n", 10 | "In order to pre-train BERT, the text needs to be pre-processed. As a starting point, there were many individual text files within the report texts. Using 'R', the texts were cleaned (white-space stripping), whereby unusable reports were excluded and put into one csv-file. The structure of the csv-file is as follows: \n", 11 | "\n", 12 | "| FILE_ADDRESS_TEXT_REPORT | TEXT |\n", 13 | "|--------------------------|------|\n", 14 | "| path/to/text/report/1234_51324_8571_341_.txt | Klinik, Fragestellung, Indikation: Beispieltext für Thorax erbeten. Befund und Beurteilung: Keine Voraufnahmen. Keine Stauung, kein Erguss, kein Pneu, keine Infiltrate. Knöcherner Thorax unauffällig |\n", 15 | "| path/to/text/report/61246_523424_85245_62345_.txt | Anamnese: Weiterer Beispieltext erbeten: Befund: Keine Voraufnahmen. 
Hier ist ein weiterer Beispieltext. Unauffälliger Befund |\n", 16 | "| ... | ... |\n", 17 | "\n", 18 | "BERT requires a specific text format for pre-training, which can be created with scripts from [Google](https://github.com/google-research/bert/blob/master/create_pretraining_data.py). \n", 19 | "However, for the scripts to work, even the raw-input data needs a specific format. In the Google-research Git Repository it reads: \n", 20 | "\n", 21 | "> \"The input is a plain text file, with one sentence per line. (It is important that these be actual sentences for the \"next sentence prediction\" task). Documents are delimited by empty lines.\"\n", 22 | "\n", 23 | "In our case, the text data is not in the desired format, because of which it needs to be pre-processed. This can be done as specified below. \n", 24 | "\n", 25 | "\n", 26 | "## Notebook summary\n", 27 | "As different computers and operating systems were used, the code to the set-up of the environment was always provided as a first step. \n", 28 | "\n", 29 | "Two functions were defined: The `sentencizer` function, which splits each report text into sentences and the `fix_wrong_splits` function, which fixes wrong splits with the `sentencizer`. \n", 30 | "\n", 31 | "An example of wrong splitting is provided below: \n", 32 | "\n", 33 | "Original text:\n", 34 | "\n", 35 | "```python\n", 36 | "['Thorax Bedside vom 12.01.2016 Klinik, Fragestellung, Indikation: Z.n. Drainagenanlage. Frage nach Drainagenlage. Pneu? Befund und Beurteilung: Keine Voraufnahmen. 1. Kein Pneumothorax. 2. Drainage Regelrecht. 3. Zunehmende Infiltrate links. Darüber hinaus keine Befundänderung'] \n", 37 | "``` \n", 38 | " \n", 39 | "This will be splitted into: \n", 40 | "\n", 41 | "```python\n", 42 | "['Thorax Bedside vom 12.01.2016 Klinik, Fragestellung, Indikation: Z.n.',\n", 43 | " 'Drainagenanlage.',\n", 44 | " 'Frage nach Drainagenlage.',\n", 45 | " 'Pneu?',\n", 46 | " 'Befund und Beurteilung: Keine Voraufnahmen.',\n", 47 | " '1.',\n", 48 | " 'Kein Pneumothorax.',\n", 49 | " '2.',\n", 50 | " 'Drainage Regelrecht.',\n", 51 | " '3.',\n", 52 | " 'Zunehmende Infiltrate links.',\n", 53 | " 'Darüber hinaus keine Befundänderung'] \n", 54 | "``` \n", 55 | "As can be seen, this is not optimal. After using `fix_wrong_splits`, it will instead be converted into: \n", 56 | "\n", 57 | "```python\n", 58 | "['Thorax Bedside vom 12.01.2016 Klinik, Fragestellung, Indikation: Z.n. Drainagenanlage.',\n", 59 | " 'Frage nach Drainagenlage.',\n", 60 | " 'Pneu? Befund und Beurteilung: Keine Voraufnahmen.',\n", 61 | " '1. Kein Pneumothorax.',\n", 62 | " '2. Drainage Regelrecht.',\n", 63 | " '3. Zunehmende Infiltrate links.',\n", 64 | " 'Darüber hinaus keine Befundänderung']\n", 65 | "```\n", 66 | "\n", 67 | "Even though this still leaves some splits unfixed, if they appear too close after each other, it greatly improves the overall performance. \n", 68 | "Evaluation of the notebook took approximately 10 hours." 
69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## Initializing the enviroment\n", 76 | "\n", 77 | "```bash\n", 78 | "conda create --name=text-preprocessing spacy\n", 79 | "conda activate text-preprocessing\n", 80 | "conda install ipykernel pandas\n", 81 | "ipython kernel install --user --name=spacy\n", 82 | "```" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## Import packages\n", 90 | "`spacy` - workhorse for sentencizing \n", 91 | "`pandas` - for importing the csv file \n", 92 | "`time` - for monitoring time of sentencizing " 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "import spacy\n", 102 | "from spacy.lang.de import German\n", 103 | "import pandas as pd\n", 104 | "import time" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "nlp = German()\n", 114 | "nlp.add_pipe(nlp.create_pipe('sentencizer')) " 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "texts = pd.read_csv('../data/cleaned-text-dump.csv', low_memory=False) " 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "def sentencizer(raw_text, nlp):\n", 133 | " doc = nlp(raw_text)\n", 134 | " sentences = [sent.string.strip() for sent in doc.sents]\n", 135 | " return(sentences)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## Fixing wrong splits\n", 143 | "Sentences with specific endings were glued together and hardcoded into an if-statement. Then 'elif' was used to check if a sentence was very short (e.g. _'1.'_ ) and in that case to also glue it to the next sentence. \n", 144 | "As the length of the document varys depending on the number of fixes, a while-loop was used instead of a for-loop. " 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "def fix_wrong_splits(sentences): \n", 154 | " i=0\n", 155 | " \n", 156 | " while i < (len(sentences)-2): \n", 157 | " if sentences[i].endswith(('Z.n.','V.a.','v.a.', 'Vd.a.' 
'i.v', ' re.', \n", 158 | " ' li.', 'und 4.', 'bds.', 'Bds.', 'Pat.', \n", 159 | " 'i.p.', 'i.P.', 'b.w.', 'i.e.L.', ' pect.', \n", 160 | " 'Ggfs.', 'ggf.', 'Ggf.', 'z.B.', 'a.e.'\n", 161 | " 'I.', 'II.', 'III.', 'IV.', 'V.', 'VI.', 'VII.', \n", 162 | " 'VIII.', 'IX.', 'X.', 'XI.', 'XII.')):\n", 163 | " sentences[i:i+2] = [' '.join(sentences[i:i+2])]\n", 164 | "\n", 165 | " elif len(sentences[i]) < 10: \n", 166 | " sentences[i:i+2] = [' '.join(sentences[i:i+2])]\n", 167 | "\n", 168 | " i+=1\n", 169 | " return(sentences)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "loggingstep = []\n", 179 | "for i in range(1000): \n", 180 | " loggingstep.append(i*10000)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "raw", 185 | "metadata": {}, 186 | "source": [] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "We used the standard sentencizer from spaCy, as it perfomes similar to other natural language processing modules, such as `de_trf_bertbasecased_lg`. If more complex text-processing is required, e.g. tokenization, the `de_trf_bertbasecased_lg` natural language processing module could be used, which can be installed via: \n", 193 | "\n", 194 | "```bash\n", 195 | "conda activate text-preprocessing\n", 196 | "pip install spacy-transformers\n", 197 | "python -m spacy download de_trf_bertbasecased_lg\n", 198 | "```\n", 199 | "However, only using `de_trf_bertbasecased_lg` for sentencizing is extremely slow (aprox. 10-100 times slower), because of which ist was not used in this notebook. " 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "tic = time.clock()\n", 209 | "for i in range(len(texts)):\n", 210 | " text = texts.TEXT[i]\n", 211 | " sentences = sentencizer(text, nlp)\n", 212 | " sentences = fix_wrong_splits(sentences)\n", 213 | " with open('../data/report-dump.txt', 'a+') as file:\n", 214 | " for sent in sentences:\n", 215 | " file.write(sent + '\\n')\n", 216 | " file.write('\\n') \n", 217 | " if i in loggingstep:\n", 218 | " toc = time.clock()\n", 219 | " print('dumped the ' + str(i) + \"th report. \" + str(toc - tic) + \"seconds passed.\")\n", 220 | "toc = time.clock()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "All of the above-referenced steps may be executed by running the run-sentencizing.py file:\n", 228 | "\n", 229 | "```bash\n", 230 | "python run-sentencizing.py\n", 231 | "```" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "## Summary statistics\n", 239 | "Goal is to get extract the number of words as a word-frequency list. To split each string by words, `string.split()` can be used, but it only split by spaces and ignores special characters like colons, periods, brackets etc..\n", 240 | "A tokenizer can be used as a more robust method but this is very slow and therefore probably not worth it." 
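A possible middle ground, sketched below but not used for the counts in this notebook, is a regular-expression based word count that drops punctuation without the overhead of a full tokenizer (the function name is illustrative):

```python
import re
from collections import Counter

def regex_word_counts(path):
    """Count words as runs of letters/digits, ignoring punctuation."""
    with open(path, 'r', encoding='utf-8-sig') as f:
        return Counter(re.findall(r'\w+', f.read().lower()))

# e.g. regex_word_counts('../data/report-dump.txt').most_common(10)
```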
241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 4, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "## count all words\n", 250 | "n = 0 \n", 251 | "file = open(r'../data/report-dump.txt', 'r', encoding=\"utf-8-sig\")\n", 252 | "for word in file.read().split():\n", 253 | " n += 1" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 20, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "data": { 263 | "text/plain": [ 264 | "54068691" 265 | ] 266 | }, 267 | "execution_count": 20, 268 | "metadata": {}, 269 | "output_type": "execute_result" 270 | } 271 | ], 272 | "source": [ 273 | "## count lines\n", 274 | "lines = 0 \n", 275 | "file = open(r'../data/report-dump.txt', 'r', encoding=\"utf-8-sig\")\n", 276 | "for line in file:\n", 277 | " lines += 1\n", 278 | "lines" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 13, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "## Count individual words of file\n", 288 | "file = open(r'../data/report-dump.txt', 'r', encoding=\"utf-8-sig\")\n", 289 | "from collections import Counter\n", 290 | "wordcount = Counter(file.read().split())" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "counts = {}\n", 300 | "counts['__Overall count__'] = []\n", 301 | "counts['__Overall count__'].append(['overall', n])\n", 302 | "counts['__individual count__'] = []\n", 303 | "for item in wordcount.items():\n", 304 | " counts['__individual count__'].append(item)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "import json\n", 314 | "with open('../statistics/word-count-report-dump.json', 'w') as outfile:\n", 315 | " json.dump(counts, outfile)" 316 | ] 317 | } 318 | ], 319 | "metadata": { 320 | "kernelspec": { 321 | "display_name": "spacy", 322 | "language": "python", 323 | "name": "spacy" 324 | }, 325 | "language_info": { 326 | "codemirror_mode": { 327 | "name": "ipython", 328 | "version": 3 329 | }, 330 | "file_extension": ".py", 331 | "mimetype": "text/x-python", 332 | "name": "python", 333 | "nbconvert_exporter": "python", 334 | "pygments_lexer": "ipython3", 335 | "version": "3.7.5" 336 | } 337 | }, 338 | "nbformat": 4, 339 | "nbformat_minor": 4 340 | } 341 | -------------------------------------------------------------------------------- /pretraining/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /finetuning/loss-logs.txt: -------------------------------------------------------------------------------- 1 | bert-base-german-cased: Epoch: 1 eval_loss: 0.37738367915153503 LRAP: 0.8854953557909468 2 | bert-base-uncased: Epoch: 1 eval_loss: 0.46002553900082904 LRAP: 0.8302707569333674 3 | bert-from-scratch: Epoch: 1 eval_loss: 0.3931102724302383 LRAP: 0.8726838597830577 4 | radiobert: Epoch: 1 eval_loss: 0.3678681658846991 LRAP: 0.8865736234373502 5 | bert-base-multi-cased: Epoch: 1 eval_loss: 0.42807909278642564 LRAP: 0.8467341031268879 6 | bert-base-german-cased: Epoch: 2 eval_loss: 0.273111719815504 LRAP: 0.9348588446734739 7 | bert-base-uncased: Epoch: 2 eval_loss: 0.35254334268115817 LRAP: 0.8991740624105351 8 | bert-from-scratch: Epoch: 2 eval_loss: 0.33306089717717396 LRAP: 0.8995104267491901 9 | radiobert: Epoch: 2 eval_loss: 0.27575823700144175 LRAP: 0.9290520695586181 10 | bert-base-multi-cased: Epoch: 2 eval_loss: 0.3477191619929813 LRAP: 0.9035753820566301 11 | bert-base-german-cased: Epoch: 3 eval_loss: 0.2418065280431793 LRAP: 0.9407908873302163 12 | bert-base-uncased: Epoch: 3 eval_loss: 0.3107446109255155 LRAP: 0.9143921970925976 13 | bert-from-scratch: Epoch: 3 eval_loss: 0.2807560857562792 LRAP: 0.9248646272136103 14 | radiobert: Epoch: 3 eval_loss: 0.21279331225724446 LRAP: 0.953917676623087 15 | bert-base-multi-cased: Epoch: 3 eval_loss: 0.26383966420378 LRAP: 0.9385347679485954 16 | bert-base-german-cased: Epoch: 4 eval_loss: 0.16241605083147684 LRAP: 0.9677653720138688 17 | bert-base-uncased: Epoch: 4 eval_loss: 0.25924643306505113 LRAP: 0.9336645513248718 18 | bert-from-scratch: Epoch: 4 eval_loss: 0.2643148849407832 LRAP: 0.9312193207276234 19 | radiobert: Epoch: 4 eval_loss: 0.17551609288368905 LRAP: 0.9625508954416767 20 | bert-base-multi-cased: Epoch: 4 eval_loss: 0.21564464271068573 LRAP: 0.9501933231542451 21 | bert-base-german-cased: Epoch: 5 eval_loss: 0.16713754779526166 LRAP: 0.9666629290326685 22 | bert-base-uncased: Epoch: 5 eval_loss: 0.2688524506631352 LRAP: 0.9272981677640992 23 | bert-from-scratch: Epoch: 5 eval_loss: 0.2452188524461928 LRAP: 0.9374443899136141 24 | radiobert: Epoch: 5 eval_loss: 0.16296553647234327 LRAP: 0.9661508731749214 25 | bert-base-multi-cased: Epoch: 5 eval_loss: 0.22625258564949036 LRAP: 0.9469061933390591 26 | bert-base-german-cased: Epoch: 6 eval_loss: 0.15501453088862555 LRAP: 0.9668124343925945 27 | bert-base-uncased: Epoch: 6 eval_loss: 0.2315209720815931 LRAP: 0.9450754683971117 28 | bert-from-scratch: Epoch: 6 eval_loss: 0.21837366159473146 LRAP: 0.9561507709296143 29 | radiobert: Epoch: 6 eval_loss: 0.15413154741483076 LRAP: 0.9698181283201325 30 | bert-base-multi-cased: Epoch: 6 eval_loss: 0.16400381230882236 LRAP: 0.965716671438114 31 | bert-base-german-cased: Epoch: 7 eval_loss: 0.15106006321452914 LRAP: 0.9700912936985081 32 | bert-base-uncased: Epoch: 7 eval_loss: 0.21289043767111643 LRAP: 0.9505527722110888 33 | bert-from-scratch: Epoch: 7 eval_loss: 0.21141792141965457 LRAP: 0.9525027606232873 34 | radiobert: Epoch: 7 eval_loss: 0.15417248853260562 LRAP: 0.970292807837898 35 | bert-base-multi-cased: Epoch: 7 eval_loss: 0.15746014263658298 LRAP: 0.96919362534593 36 | 
bert-base-german-cased: Epoch: 8 eval_loss: 0.15152388842155537 LRAP: 0.9727821516047971 37 | bert-base-uncased: Epoch: 8 eval_loss: 0.16133819431776092 LRAP: 0.9685995801126063 38 | bert-from-scratch: Epoch: 8 eval_loss: 0.1897809077941236 LRAP: 0.9600738779145594 39 | radiobert: Epoch: 8 eval_loss: 0.146962927299596 LRAP: 0.9699139548939147 40 | bert-base-multi-cased: Epoch: 8 eval_loss: 0.15182095047618663 LRAP: 0.9721448452460476 41 | bert-base-german-cased: Epoch: 9 eval_loss: 0.15111761367214577 LRAP: 0.9675270382033908 42 | bert-base-uncased: Epoch: 9 eval_loss: 0.15390420616382644 LRAP: 0.9711866590323502 43 | bert-from-scratch: Epoch: 9 eval_loss: 0.19475181782174678 LRAP: 0.9586678118141042 44 | radiobert: Epoch: 9 eval_loss: 0.15292102036376795 LRAP: 0.9694420587206157 45 | bert-base-multi-cased: Epoch: 9 eval_loss: 0.15269366252635205 LRAP: 0.9647968158539304 46 | bert-base-german-cased: Epoch: 10 eval_loss: 0.14237195789991391 LRAP: 0.9723406336482485 47 | bert-base-uncased: Epoch: 10 eval_loss: 0.15266327666384832 LRAP: 0.9735808124184875 48 | bert-from-scratch: Epoch: 10 eval_loss: 0.1966530332075698 LRAP: 0.9581785793809844 49 | radiobert: Epoch: 10 eval_loss: 0.15518440998026303 LRAP: 0.9726885517065877 50 | bert-base-multi-cased: Epoch: 10 eval_loss: 0.1519093079758542 LRAP: 0.9709439513948536 51 | bert-base-german-cased: Epoch: 11 eval_loss: 0.15005431337548153 LRAP: 0.9696514457486406 52 | bert-base-uncased: Epoch: 11 eval_loss: 0.15361339731940202 LRAP: 0.9685100359449055 53 | bert-from-scratch: Epoch: 11 eval_loss: 0.1894041532207103 LRAP: 0.9586643127524884 54 | radiobert: Epoch: 11 eval_loss: 0.15031190487068324 LRAP: 0.9682959569933518 55 | bert-base-multi-cased: Epoch: 11 eval_loss: 0.14860336225302445 LRAP: 0.9677443776441773 56 | bert-base-german-cased: Epoch: 12 eval_loss: 0.1545414896238418 LRAP: 0.9679345198333174 57 | bert-base-uncased: Epoch: 12 eval_loss: 0.15655041024798438 LRAP: 0.972311289245157 58 | bert-from-scratch: Epoch: 12 eval_loss: 0.21134661333191962 LRAP: 0.9549089448738743 59 | radiobert: Epoch: 12 eval_loss: 0.1627808337853778 LRAP: 0.9666993510831184 60 | bert-base-multi-cased: Epoch: 12 eval_loss: 0.15798173578722136 LRAP: 0.9655724146706113 61 | bert-base-german-cased: Epoch: 13 eval_loss: 0.15525878744111174 LRAP: 0.9688069790374402 62 | bert-base-uncased: Epoch: 13 eval_loss: 0.1576278673573619 LRAP: 0.9664257880841046 63 | bert-from-scratch: Epoch: 13 eval_loss: 0.20611229281695115 LRAP: 0.9622040679999452 64 | radiobert: Epoch: 13 eval_loss: 0.15856089039395252 LRAP: 0.9695825778541209 65 | bert-base-multi-cased: Epoch: 13 eval_loss: 0.15409610593425377 LRAP: 0.9680887966409013 66 | bert-base-german-cased: Epoch: 14 eval_loss: 0.1582624748171795 LRAP: 0.969817174030601 67 | bert-base-uncased: Epoch: 14 eval_loss: 0.15438756351137445 LRAP: 0.9651115723510515 68 | bert-from-scratch: Epoch: 14 eval_loss: 0.2014889476288642 LRAP: 0.9556420778064062 69 | radiobert: Epoch: 14 eval_loss: 0.16261111856216476 LRAP: 0.9682790183541686 70 | bert-base-multi-cased: Epoch: 14 eval_loss: 0.16319013351485842 LRAP: 0.967951776569011 71 | bert-base-german-cased: Epoch: 15 eval_loss: 0.15205234553044042 LRAP: 0.9728965868244425 72 | bert-base-uncased: Epoch: 15 eval_loss: 0.14669035610166334 LRAP: 0.9748491427299041 73 | bert-from-scratch: Epoch: 15 eval_loss: 0.20134492449107624 LRAP: 0.9624582498329993 74 | radiobert: Epoch: 15 eval_loss: 0.1707158871660275 LRAP: 0.971990409390209 75 | bert-base-multi-cased: Epoch: 15 eval_loss: 0.16853193415417558 
LRAP: 0.9632596144670293 76 | bert-base-german-cased: Epoch: 16 eval_loss: 0.17020675839324081 LRAP: 0.9717931100295828 77 | bert-base-uncased: Epoch: 16 eval_loss: 0.16662229566524425 LRAP: 0.966641457518211 78 | bert-from-scratch: Epoch: 16 eval_loss: 0.20900509329069228 LRAP: 0.9577617139039982 79 | radiobert: Epoch: 16 eval_loss: 0.17589069490454026 LRAP: 0.9670525177338802 80 | bert-base-multi-cased: Epoch: 16 eval_loss: 0.15851649446856408 LRAP: 0.9686665394280624 81 | bert-base-german-cased: Epoch: 17 eval_loss: 0.16294916182579028 LRAP: 0.9714770811464197 82 | bert-base-uncased: Epoch: 17 eval_loss: 0.1581862277927853 LRAP: 0.9711868976047332 83 | bert-from-scratch: Epoch: 17 eval_loss: 0.2101698698742049 LRAP: 0.9574218277825491 84 | radiobert: Epoch: 17 eval_loss: 0.1716835054214157 LRAP: 0.9691616566466267 85 | bert-base-multi-cased: Epoch: 17 eval_loss: 0.16067363210909424 LRAP: 0.9718499697808315 86 | bert-base-german-cased: Epoch: 18 eval_loss: 0.1623041349729257 LRAP: 0.9712910742119157 87 | bert-base-uncased: Epoch: 18 eval_loss: 0.1628969760079469 LRAP: 0.9693105258135317 88 | bert-from-scratch: Epoch: 18 eval_loss: 0.2121481958421923 LRAP: 0.9618388364029649 89 | radiobert: Epoch: 18 eval_loss: 0.16712544860673093 LRAP: 0.9714473391226897 90 | bert-base-multi-cased: Epoch: 18 eval_loss: 0.17396594218111464 LRAP: 0.9701581734898366 91 | bert-base-german-cased: Epoch: 19 eval_loss: 0.17168369651993826 LRAP: 0.9709962782708272 92 | bert-base-uncased: Epoch: 19 eval_loss: 0.1646124915264192 LRAP: 0.9690655119763336 93 | bert-from-scratch: Epoch: 19 eval_loss: 0.21209525849137986 LRAP: 0.9628088717116774 94 | radiobert: Epoch: 19 eval_loss: 0.17850070579775743 LRAP: 0.9749778922925215 95 | bert-base-multi-cased: Epoch: 19 eval_loss: 0.16472308251208492 LRAP: 0.9638364029646597 96 | bert-base-german-cased: Epoch: 20 eval_loss: 0.17295368634430425 LRAP: 0.9703713776759866 97 | bert-base-uncased: Epoch: 20 eval_loss: 0.17220934285294442 LRAP: 0.9693312020867132 98 | bert-from-scratch: Epoch: 20 eval_loss: 0.22089290241932585 LRAP: 0.9591468651588888 99 | radiobert: Epoch: 20 eval_loss: 0.18175501288801788 LRAP: 0.9693824156249007 100 | bert-base-multi-cased: Epoch: 20 eval_loss: 0.16391916147300176 LRAP: 0.9698085059006903 101 | bert-base-german-cased: Epoch: 21 eval_loss: 0.16501706198877877 LRAP: 0.9716269841269841 102 | bert-base-uncased: Epoch: 21 eval_loss: 0.16373097772399584 LRAP: 0.9732027547157808 103 | bert-from-scratch: Epoch: 21 eval_loss: 0.22055147605992498 LRAP: 0.9623498584470528 104 | radiobert: Epoch: 21 eval_loss: 0.18235731021767215 LRAP: 0.9710518656360337 105 | bert-base-multi-cased: Epoch: 21 eval_loss: 0.174191768613777 LRAP: 0.9704132868912431 106 | bert-base-german-cased: Epoch: 22 eval_loss: 0.17212903515125313 LRAP: 0.9740045169704484 107 | bert-base-uncased: Epoch: 22 eval_loss: 0.17848874091924655 LRAP: 0.9672803543595127 108 | bert-from-scratch: Epoch: 22 eval_loss: 0.2177671328896568 LRAP: 0.9629411998600372 109 | radiobert: Epoch: 22 eval_loss: 0.18369798152707517 LRAP: 0.9669294143843243 110 | bert-base-multi-cased: Epoch: 22 eval_loss: 0.1676064087964949 LRAP: 0.9696685434360784 111 | bert-base-german-cased: Epoch: 23 eval_loss: 0.18019469285250775 LRAP: 0.9699307344848427 112 | bert-base-uncased: Epoch: 23 eval_loss: 0.17032801818900875 LRAP: 0.969125950313325 113 | bert-from-scratch: Epoch: 23 eval_loss: 0.22617384888941333 LRAP: 0.9619644845246046 114 | radiobert: Epoch: 23 eval_loss: 0.18968782642678844 LRAP: 0.9683074084677297 115 | 
bert-base-multi-cased: Epoch: 23 eval_loss: 0.18918187715600998 LRAP: 0.9688964436810129 116 | bert-base-german-cased: Epoch: 24 eval_loss: 0.18413732462518273 LRAP: 0.9724672360594203 117 | bert-base-uncased: Epoch: 24 eval_loss: 0.17867767722124145 LRAP: 0.9691266660304738 118 | bert-from-scratch: Epoch: 24 eval_loss: 0.22166201192885637 LRAP: 0.9613896046060374 119 | radiobert: Epoch: 24 eval_loss: 0.19431156206077763 LRAP: 0.9710515475395235 120 | bert-base-multi-cased: Epoch: 24 eval_loss: 0.17786366935996784 LRAP: 0.9698679899481504 121 | bert-base-german-cased: Epoch: 25 eval_loss: 0.18112100435731313 LRAP: 0.9733135318255559 122 | bert-base-uncased: Epoch: 25 eval_loss: 0.19090583419338578 LRAP: 0.9674867194706874 123 | bert-from-scratch: Epoch: 25 eval_loss: 0.22332054400993956 LRAP: 0.9598794414225275 124 | radiobert: Epoch: 25 eval_loss: 0.19171774955022902 LRAP: 0.9695645258771513 125 | bert-base-multi-cased: Epoch: 25 eval_loss: 0.19130818848498166 LRAP: 0.9665919139867036 126 | bert-base-german-cased: Epoch: 26 eval_loss: 0.1803901589786013 LRAP: 0.9725872379679994 127 | bert-base-uncased: Epoch: 26 eval_loss: 0.18585480884870603 LRAP: 0.9720601520501322 128 | bert-from-scratch: Epoch: 26 eval_loss: 0.22729128149027625 LRAP: 0.9647203136431591 129 | radiobert: Epoch: 26 eval_loss: 0.20151529865827233 LRAP: 0.9698760218850397 130 | bert-base-multi-cased: Epoch: 26 eval_loss: 0.1994623862473028 LRAP: 0.9684764767630498 131 | bert-base-german-cased: Epoch: 27 eval_loss: 0.18977484903076575 LRAP: 0.9704779400069982 132 | bert-base-uncased: Epoch: 27 eval_loss: 0.18411304443586796 LRAP: 0.9680294716416961 133 | bert-from-scratch: Epoch: 27 eval_loss: 0.23743001358317478 LRAP: 0.9608352419123958 134 | radiobert: Epoch: 27 eval_loss: 0.1989677923986511 LRAP: 0.970253443394726 135 | bert-base-multi-cased: Epoch: 27 eval_loss: 0.1986075373743439 LRAP: 0.9666333460571936 136 | bert-base-german-cased: Epoch: 28 eval_loss: 0.184378188203222 LRAP: 0.9752437414511562 137 | bert-base-uncased: Epoch: 28 eval_loss: 0.18611900178006008 LRAP: 0.9706644240862677 138 | bert-from-scratch: Epoch: 28 eval_loss: 0.2536356446244532 LRAP: 0.9594425358653816 139 | radiobert: Epoch: 28 eval_loss: 0.20101390031188549 LRAP: 0.9686336164392275 140 | bert-base-multi-cased: Epoch: 28 eval_loss: 0.19318729069172627 LRAP: 0.9695208671310874 141 | bert-base-german-cased: Epoch: 29 eval_loss: 0.19697545419469298 LRAP: 0.9715402233037503 142 | bert-base-uncased: Epoch: 29 eval_loss: 0.19614915585234052 LRAP: 0.9679107421191591 143 | bert-from-scratch: Epoch: 29 eval_loss: 0.2351191276019173 LRAP: 0.964257006075643 144 | radiobert: Epoch: 29 eval_loss: 0.2166949431023871 LRAP: 0.9716075007157169 145 | bert-base-multi-cased: Epoch: 29 eval_loss: 0.20279614426129097 LRAP: 0.9691274612717498 146 | bert-base-german-cased: Epoch: 30 eval_loss: 0.20590784007661223 LRAP: 0.9710508318223752 147 | bert-base-uncased: Epoch: 30 eval_loss: 0.19089158558996305 LRAP: 0.9713848331583803 148 | bert-from-scratch: Epoch: 30 eval_loss: 0.2409736585936376 LRAP: 0.9660377898654452 149 | radiobert: Epoch: 30 eval_loss: 0.20549893781121464 LRAP: 0.9651357476858479 150 | bert-base-multi-cased: Epoch: 30 eval_loss: 0.1889554770230981 LRAP: 0.9742850780926935 151 | bert-base-german-cased: Epoch: 31 eval_loss: 0.19395719461941294 LRAP: 0.9735629194897732 152 | bert-base-uncased: Epoch: 31 eval_loss: 0.20157154973241545 LRAP: 0.968639262652289 153 | bert-from-scratch: Epoch: 31 eval_loss: 0.24884129722513967 LRAP: 0.9646954225912141 154 | 
radiobert: Epoch: 31 eval_loss: 0.21918271846758822 LRAP: 0.9699072748671947 155 | bert-base-multi-cased: Epoch: 31 eval_loss: 0.21008045551189708 LRAP: 0.9682246238508764 156 | bert-base-german-cased: Epoch: 32 eval_loss: 0.19830482786277398 LRAP: 0.9688128638228838 157 | bert-base-uncased: Epoch: 32 eval_loss: 0.19991262170619198 LRAP: 0.9710949677132041 158 | bert-from-scratch: Epoch: 32 eval_loss: 0.2583969286421225 LRAP: 0.9635765181155962 159 | radiobert: Epoch: 32 eval_loss: 0.20419193635184674 LRAP: 0.970823313293253 160 | bert-base-multi-cased: Epoch: 32 eval_loss: 0.2026938864728436 LRAP: 0.9675660845500524 161 | bert-base-german-cased: Epoch: 33 eval_loss: 0.2051657560569722 LRAP: 0.9711783089989504 162 | bert-base-uncased: Epoch: 33 eval_loss: 0.20149473585410133 LRAP: 0.9691354136845121 163 | -------------------------------------------------------------------------------- /pretraining/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import tokenization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("input_file", None, 31 | "Input raw text file (or comma-separated list of files).") 32 | 33 | flags.DEFINE_string( 34 | "output_file", None, 35 | "Output TF example file (or comma-separated list of files).") 36 | 37 | flags.DEFINE_string("vocab_file", None, 38 | "The vocabulary file that the BERT model was trained on.") 39 | 40 | flags.DEFINE_bool( 41 | "do_lower_case", True, 42 | "Whether to lower case the input text. 
Should be True for uncased " 43 | "models and False for cased models.") 44 | 45 | flags.DEFINE_bool( 46 | "do_whole_word_mask", False, 47 | "Whether to use whole word masking rather than per-WordPiece masking.") 48 | 49 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 50 | 51 | flags.DEFINE_integer("max_predictions_per_seq", 20, 52 | "Maximum number of masked LM predictions per sequence.") 53 | 54 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 55 | 56 | flags.DEFINE_integer( 57 | "dupe_factor", 10, 58 | "Number of times to duplicate the input data (with different masks).") 59 | 60 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 61 | 62 | flags.DEFINE_float( 63 | "short_seq_prob", 0.1, 64 | "Probability of creating sequences which are shorter than the " 65 | "maximum length.") 66 | 67 | 68 | class TrainingInstance(object): 69 | """A single training instance (sentence pair).""" 70 | 71 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 72 | is_random_next): 73 | self.tokens = tokens 74 | self.segment_ids = segment_ids 75 | self.is_random_next = is_random_next 76 | self.masked_lm_positions = masked_lm_positions 77 | self.masked_lm_labels = masked_lm_labels 78 | 79 | def __str__(self): 80 | s = "" 81 | s += "tokens: %s\n" % (" ".join( 82 | [tokenization.printable_text(x) for x in self.tokens])) 83 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 84 | s += "is_random_next: %s\n" % self.is_random_next 85 | s += "masked_lm_positions: %s\n" % (" ".join( 86 | [str(x) for x in self.masked_lm_positions])) 87 | s += "masked_lm_labels: %s\n" % (" ".join( 88 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 89 | s += "\n" 90 | return s 91 | 92 | def __repr__(self): 93 | return self.__str__() 94 | 95 | 96 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 97 | max_predictions_per_seq, output_files): 98 | """Create TF example files from `TrainingInstance`s.""" 99 | writers = [] 100 | for output_file in output_files: 101 | writers.append(tf.python_io.TFRecordWriter(output_file)) 102 | 103 | writer_index = 0 104 | 105 | total_written = 0 106 | for (inst_index, instance) in enumerate(instances): 107 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 108 | input_mask = [1] * len(input_ids) 109 | segment_ids = list(instance.segment_ids) 110 | assert len(input_ids) <= max_seq_length 111 | 112 | while len(input_ids) < max_seq_length: 113 | input_ids.append(0) 114 | input_mask.append(0) 115 | segment_ids.append(0) 116 | 117 | assert len(input_ids) == max_seq_length 118 | assert len(input_mask) == max_seq_length 119 | assert len(segment_ids) == max_seq_length 120 | 121 | masked_lm_positions = list(instance.masked_lm_positions) 122 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 123 | masked_lm_weights = [1.0] * len(masked_lm_ids) 124 | 125 | while len(masked_lm_positions) < max_predictions_per_seq: 126 | masked_lm_positions.append(0) 127 | masked_lm_ids.append(0) 128 | masked_lm_weights.append(0.0) 129 | 130 | next_sentence_label = 1 if instance.is_random_next else 0 131 | 132 | features = collections.OrderedDict() 133 | features["input_ids"] = create_int_feature(input_ids) 134 | features["input_mask"] = create_int_feature(input_mask) 135 | features["segment_ids"] = create_int_feature(segment_ids) 136 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 137 | 
features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 138 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 139 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 140 | 141 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 142 | 143 | writers[writer_index].write(tf_example.SerializeToString()) 144 | writer_index = (writer_index + 1) % len(writers) 145 | 146 | total_written += 1 147 | 148 | if inst_index < 20: 149 | tf.logging.info("*** Example ***") 150 | tf.logging.info("tokens: %s" % " ".join( 151 | [tokenization.printable_text(x) for x in instance.tokens])) 152 | 153 | for feature_name in features.keys(): 154 | feature = features[feature_name] 155 | values = [] 156 | if feature.int64_list.value: 157 | values = feature.int64_list.value 158 | elif feature.float_list.value: 159 | values = feature.float_list.value 160 | tf.logging.info( 161 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 162 | 163 | for writer in writers: 164 | writer.close() 165 | 166 | tf.logging.info("Wrote %d total instances", total_written) 167 | 168 | 169 | def create_int_feature(values): 170 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 171 | return feature 172 | 173 | 174 | def create_float_feature(values): 175 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 176 | return feature 177 | 178 | 179 | def create_training_instances(input_files, tokenizer, max_seq_length, 180 | dupe_factor, short_seq_prob, masked_lm_prob, 181 | max_predictions_per_seq, rng): 182 | """Create `TrainingInstance`s from raw text.""" 183 | all_documents = [[]] 184 | 185 | # Input file format: 186 | # (1) One sentence per line. These should ideally be actual sentences, not 187 | # entire paragraphs or arbitrary spans of text. (Because we use the 188 | # sentence boundaries for the "next sentence prediction" task). 189 | # (2) Blank lines between documents. Document boundaries are needed so 190 | # that the "next sentence prediction" task doesn't span between documents. 
191 | for input_file in input_files: 192 | with tf.gfile.GFile(input_file, "r") as reader: 193 | while True: 194 | line = tokenization.convert_to_unicode(reader.readline()) 195 | if not line: 196 | break 197 | line = line.strip() 198 | 199 | # Empty lines are used as document delimiters 200 | if not line: 201 | all_documents.append([]) 202 | tokens = tokenizer.tokenize(line) 203 | if tokens: 204 | all_documents[-1].append(tokens) 205 | 206 | # Remove empty documents 207 | all_documents = [x for x in all_documents if x] 208 | rng.shuffle(all_documents) 209 | 210 | vocab_words = list(tokenizer.vocab.keys()) 211 | instances = [] 212 | for _ in range(dupe_factor): 213 | for document_index in range(len(all_documents)): 214 | instances.extend( 215 | create_instances_from_document( 216 | all_documents, document_index, max_seq_length, short_seq_prob, 217 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 218 | 219 | rng.shuffle(instances) 220 | return instances 221 | 222 | 223 | def create_instances_from_document( 224 | all_documents, document_index, max_seq_length, short_seq_prob, 225 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 226 | """Creates `TrainingInstance`s for a single document.""" 227 | document = all_documents[document_index] 228 | 229 | # Account for [CLS], [SEP], [SEP] 230 | max_num_tokens = max_seq_length - 3 231 | 232 | # We *usually* want to fill up the entire sequence since we are padding 233 | # to `max_seq_length` anyways, so short sequences are generally wasted 234 | # computation. However, we *sometimes* 235 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 236 | # sequences to minimize the mismatch between pre-training and fine-tuning. 237 | # The `target_seq_length` is just a rough target however, whereas 238 | # `max_seq_length` is a hard limit. 239 | target_seq_length = max_num_tokens 240 | if rng.random() < short_seq_prob: 241 | target_seq_length = rng.randint(2, max_num_tokens) 242 | 243 | # We DON'T just concatenate all of the tokens from a document into a long 244 | # sequence and choose an arbitrary split point because this would make the 245 | # next sentence prediction task too easy. Instead, we split the input into 246 | # segments "A" and "B" based on the actual "sentences" provided by the user 247 | # input. 248 | instances = [] 249 | current_chunk = [] 250 | current_length = 0 251 | i = 0 252 | while i < len(document): 253 | segment = document[i] 254 | current_chunk.append(segment) 255 | current_length += len(segment) 256 | if i == len(document) - 1 or current_length >= target_seq_length: 257 | if current_chunk: 258 | # `a_end` is how many segments from `current_chunk` go into the `A` 259 | # (first) sentence. 260 | a_end = 1 261 | if len(current_chunk) >= 2: 262 | a_end = rng.randint(1, len(current_chunk) - 1) 263 | 264 | tokens_a = [] 265 | for j in range(a_end): 266 | tokens_a.extend(current_chunk[j]) 267 | 268 | tokens_b = [] 269 | # Random next 270 | is_random_next = False 271 | if len(current_chunk) == 1 or rng.random() < 0.5: 272 | is_random_next = True 273 | target_b_length = target_seq_length - len(tokens_a) 274 | 275 | # This should rarely go for more than one iteration for large 276 | # corpora. However, just to be careful, we try to make sure that 277 | # the random document is not the same as the document 278 | # we're processing. 
279 | for _ in range(10): 280 | random_document_index = rng.randint(0, len(all_documents) - 1) 281 | if random_document_index != document_index: 282 | break 283 | 284 | random_document = all_documents[random_document_index] 285 | random_start = rng.randint(0, len(random_document) - 1) 286 | for j in range(random_start, len(random_document)): 287 | tokens_b.extend(random_document[j]) 288 | if len(tokens_b) >= target_b_length: 289 | break 290 | # We didn't actually use these segments so we "put them back" so 291 | # they don't go to waste. 292 | num_unused_segments = len(current_chunk) - a_end 293 | i -= num_unused_segments 294 | # Actual next 295 | else: 296 | is_random_next = False 297 | for j in range(a_end, len(current_chunk)): 298 | tokens_b.extend(current_chunk[j]) 299 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 300 | 301 | assert len(tokens_a) >= 1 302 | assert len(tokens_b) >= 1 303 | 304 | tokens = [] 305 | segment_ids = [] 306 | tokens.append("[CLS]") 307 | segment_ids.append(0) 308 | for token in tokens_a: 309 | tokens.append(token) 310 | segment_ids.append(0) 311 | 312 | tokens.append("[SEP]") 313 | segment_ids.append(0) 314 | 315 | for token in tokens_b: 316 | tokens.append(token) 317 | segment_ids.append(1) 318 | tokens.append("[SEP]") 319 | segment_ids.append(1) 320 | 321 | (tokens, masked_lm_positions, 322 | masked_lm_labels) = create_masked_lm_predictions( 323 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 324 | instance = TrainingInstance( 325 | tokens=tokens, 326 | segment_ids=segment_ids, 327 | is_random_next=is_random_next, 328 | masked_lm_positions=masked_lm_positions, 329 | masked_lm_labels=masked_lm_labels) 330 | instances.append(instance) 331 | current_chunk = [] 332 | current_length = 0 333 | i += 1 334 | 335 | return instances 336 | 337 | 338 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 339 | ["index", "label"]) 340 | 341 | 342 | def create_masked_lm_predictions(tokens, masked_lm_prob, 343 | max_predictions_per_seq, vocab_words, rng): 344 | """Creates the predictions for the masked LM objective.""" 345 | 346 | cand_indexes = [] 347 | for (i, token) in enumerate(tokens): 348 | if token == "[CLS]" or token == "[SEP]": 349 | continue 350 | # Whole Word Masking means that if we mask all of the wordpieces 351 | # corresponding to an original word. When a word has been split into 352 | # WordPieces, the first token does not have any marker and any subsequence 353 | # tokens are prefixed with ##. So whenever we see the ## token, we 354 | # append it to the previous set of word indexes. 355 | # 356 | # Note that Whole Word Masking does *not* change the training code 357 | # at all -- we still predict each WordPiece independently, softmaxed 358 | # over the entire vocabulary. 359 | if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and 360 | token.startswith("##")): 361 | cand_indexes[-1].append(i) 362 | else: 363 | cand_indexes.append([i]) 364 | 365 | rng.shuffle(cand_indexes) 366 | 367 | output_tokens = list(tokens) 368 | 369 | num_to_predict = min(max_predictions_per_seq, 370 | max(1, int(round(len(tokens) * masked_lm_prob)))) 371 | 372 | masked_lms = [] 373 | covered_indexes = set() 374 | for index_set in cand_indexes: 375 | if len(masked_lms) >= num_to_predict: 376 | break 377 | # If adding a whole-word mask would exceed the maximum number of 378 | # predictions, then just skip this candidate. 
379 | if len(masked_lms) + len(index_set) > num_to_predict: 380 | continue 381 | is_any_index_covered = False 382 | for index in index_set: 383 | if index in covered_indexes: 384 | is_any_index_covered = True 385 | break 386 | if is_any_index_covered: 387 | continue 388 | for index in index_set: 389 | covered_indexes.add(index) 390 | 391 | masked_token = None 392 | # 80% of the time, replace with [MASK] 393 | if rng.random() < 0.8: 394 | masked_token = "[MASK]" 395 | else: 396 | # 10% of the time, keep original 397 | if rng.random() < 0.5: 398 | masked_token = tokens[index] 399 | # 10% of the time, replace with random word 400 | else: 401 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 402 | 403 | output_tokens[index] = masked_token 404 | 405 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 406 | assert len(masked_lms) <= num_to_predict 407 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 408 | 409 | masked_lm_positions = [] 410 | masked_lm_labels = [] 411 | for p in masked_lms: 412 | masked_lm_positions.append(p.index) 413 | masked_lm_labels.append(p.label) 414 | 415 | return (output_tokens, masked_lm_positions, masked_lm_labels) 416 | 417 | 418 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 419 | """Truncates a pair of sequences to a maximum sequence length.""" 420 | while True: 421 | total_length = len(tokens_a) + len(tokens_b) 422 | if total_length <= max_num_tokens: 423 | break 424 | 425 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 426 | assert len(trunc_tokens) >= 1 427 | 428 | # We want to sometimes truncate from the front and sometimes from the 429 | # back to add more randomness and avoid biases. 430 | if rng.random() < 0.5: 431 | del trunc_tokens[0] 432 | else: 433 | trunc_tokens.pop() 434 | 435 | 436 | def main(_): 437 | tf.logging.set_verbosity(tf.logging.INFO) 438 | 439 | tokenizer = tokenization.FullTokenizer( 440 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 441 | 442 | input_files = [] 443 | for input_pattern in FLAGS.input_file.split(","): 444 | input_files.extend(tf.gfile.Glob(input_pattern)) 445 | 446 | tf.logging.info("*** Reading from input files ***") 447 | for input_file in input_files: 448 | tf.logging.info(" %s", input_file) 449 | 450 | rng = random.Random(FLAGS.random_seed) 451 | instances = create_training_instances( 452 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 453 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 454 | rng) 455 | 456 | output_files = FLAGS.output_file.split(",") 457 | tf.logging.info("*** Writing to output files ***") 458 | for output_file in output_files: 459 | tf.logging.info(" %s", output_file) 460 | 461 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 462 | FLAGS.max_predictions_per_seq, output_files) 463 | 464 | 465 | if __name__ == "__main__": 466 | flags.mark_flag_as_required("input_file") 467 | flags.mark_flag_as_required("output_file") 468 | flags.mark_flag_as_required("vocab_file") 469 | tf.app.run() 470 | -------------------------------------------------------------------------------- /evaluation/accuracy.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Evaluate Model Performances" 3 | output: github_document 4 | --- 5 | 6 | ```{r setup, include=FALSE} 7 | knitr::opts_chunk$set(echo = TRUE, 8 | message = FALSE, 9 | warning = FALSE) 10 | 11 | library(tidyverse) 12 | library(magrittr) 
13 | library(cowplot) 14 | library(ggradar) 15 | library(gridExtra) 16 | library(grid) 17 | ``` 18 | 19 | ## 1. Data Wrangling 20 | 21 | ### Define helper functions 22 | Function to convert the raw predictions to 0 or 1, depending on a chosen threshold 23 | ```{r conv to int} 24 | int_from_threshold <- function(.data, threshold) { 25 | mutate(.data, 26 | Stauung = as.integer(Stauung > threshold), 27 | Verschattung = as.integer(Verschattung > threshold), 28 | Erguss = as.integer(Erguss > threshold), 29 | Pneumothorax = as.integer(Pneumothorax > threshold), 30 | Thoraxdrainage = as.integer(Thoraxdrainage > threshold), 31 | ZVK = as.integer(ZVK > threshold), 32 | Magensonde = as.integer(Magensonde > threshold), 33 | Tubus = as.integer(Tubus > threshold), 34 | Materialfehllage = as.integer(Materialfehllage > threshold)) 35 | } 36 | ``` 37 | 38 | Function to convert vectors of 0 and 1 to confusion-matrix counts (TP, TN, FP, FN) 39 | ```{r cm function} 40 | confusionmatrix_measurements <- function(data) { 41 | 42 | get_x = function(pred, finding, ref=test) { 43 | ref=ref[finding] 44 | pred=pred[finding] 45 | tp <- sum(ref == pred & ref == 1) 46 | tn <- sum(ref == pred & ref == 0) 47 | fp <- sum(ref != pred & ref == 0) # predicted positive, reference negative 48 | fn <- sum(ref != pred & ref == 1) # predicted negative, reference positive 49 | c(tp, tn, fp, fn) 50 | } 51 | lapply(names(data), get_x, pred = data) %>% 52 | do.call(rbind, .) %>% 53 | as_tibble() %>% 54 | set_names(c("tp", "tn", "fp", "fn")) %>% 55 | mutate(finding = c("Stauung", "Verschattung", "Erguss", "Pneumothorax", "Thoraxdrainage", "ZVK", "Magensonde", "Tubus", "Materialfehllage")) 56 | } 57 | 58 | ``` 59 | 60 | Function to extract various performance measurements 61 | ```{r perf measurements} 62 | numericcharacters <- function(x) !any(is.na(suppressWarnings(as.numeric(x)))) & is.character(x) 63 | 64 | get_perf_measurements <- function(raw, threshold) { 65 | 66 | raw %>% 67 | int_from_threshold(threshold) %>% 68 | group_by(model, train_size) %>% 69 | nest() %>% 70 | mutate(data = map(data, confusionmatrix_measurements)) %>% 71 | unnest(cols = data) %>% 72 | mutate(prec = tp / (tp + fp), 73 | rec = tp / (tp + fn), 74 | spec = tn / (fp + tn), 75 | acc = (tp + tn)/(tp + fp + tn + fn), 76 | mcc = mltools::mcc(TP=tp, FP=fp, TN=tn, FN=fn)) %>% 77 | mutate(f1 = 2 * ((prec * rec) / (prec + rec)), 78 | j_stat = rec + spec - 1) %>% 79 | mutate(f1 = ifelse(is.na(f1), 0, f1), 80 | finding = str_replace_all(finding, c("Stauung" = "Congestion", 81 | "Verschattung" = "Opacity", 82 | "Erguss" = "Effusion", 83 | "Tubus" = "Tracheal Tube/Cannula", 84 | "ZVK" = "Venous Catheter", 85 | "Thoraxdrainage" = "Thoracic Drain", 86 | "Magensonde" = "Gastric Tube", 87 | "Materialfehllage" = "Misplaced Medical Device"))) %>% 88 | mutate_if(numericcharacters, as.numeric) %>% 89 | ungroup() 90 | } 91 | 92 | ``` 93 | 94 | Function to read and format the raw results data 95 | 96 | ```{r} 97 | read_and_format_data <- function(path) { 98 | read_csv(path) %>% 99 | pivot_longer(-c(train_size, model)) %>% # similar to reshape2::melt, makes a long dataframe 100 | na.omit() %>% # some empty columns have been added while loading the dataset 101 | select(-name) %>% 102 | mutate(value = str_remove_all(value, "\\[|\\]") %>% str_squish()) %>% # remove '[' and ']' from value string.
Needs escape character to work with regrex 103 | separate(value, 104 | c("Stauung", "Verschattung", "Erguss", "Pneumothorax", "Thoraxdrainage", "ZVK", "Magensonde", "Tubus", "Materialfehllage"), 105 | sep = " ", 106 | convert = TRUE) # converts to numeric 107 | } 108 | ``` 109 | 110 | ### Load data 111 | Load the test data 112 | ```{r test data} 113 | test <- read_csv("../data/test.csv") 114 | ``` 115 | 116 | Load results from finetuning. Ignore warnings 117 | ```{r load data} 118 | raw <- read_and_format_data("../finetuning/results.csv") 119 | ``` 120 | 121 | 122 | ## 2. Plots 123 | ### Radarplots 124 | 125 | Function to simplify plotting 126 | ```{r plot fun} 127 | radar_plot_fun <- function(...,size = 1000, title="NONE", min = 0, max = 1, data=acc) { 128 | mid = mean(c(min, max)) %>% round(2) 129 | lab <- paste(c(min, mid, max)*100, "%", sep ="") 130 | 131 | filter(data, train_size == size) %>% 132 | select(c(..., "finding", "model")) %>% 133 | spread(key = finding, value =...) %>% 134 | rename(group = model) %>% 135 | select(-"Misplaced Medical Device") %>% 136 | # filter(group != "pt-radiobert-from-scratch") %>% 137 | ggradar(group.point.size = 0, centre.y = 0, grid.min = min, grid.mid = mid, grid.max = max, 138 | values.radar = lab, legend.title = "", legend.position = "bottom", plot.title = title, 139 | axis.label.size = 3, grid.label.size = 3, group.line.width = 1) + 140 | theme(legend.position = 0) + 141 | scale_color_manual(values = c("#ca0020", "#0571b0", "#f4a582", "#92c5de"), 142 | labels = c("GER-BERT", "MULTI-BERT", "RAD-BERT", "FS-BERT"), 143 | breaks = c("bert-base-german-cased", 144 | "bert-base-multilingual-cased", 145 | "pt-radiobert-base-german-cased", 146 | "pt-radiobert-from-scratch")) 147 | 148 | 149 | } 150 | ``` 151 | 152 | 153 | ```{r radarplots, fig.width=12, fig.height=12} 154 | acc <- get_perf_measurements(raw, 0.5) %>% replace(is.na(.) | . < 0, 0) 155 | 156 | grid.arrange( 157 | grobs = list( 158 | radar_plot_fun("f1", title="F1-Score (200)",size=500), 159 | radar_plot_fun("j_stat", title="Youden's index",size=500), 160 | radar_plot_fun("mcc", title="MCC",size=500), 161 | 162 | radar_plot_fun("f1", title="F1-Score (1000)",size=1000), 163 | radar_plot_fun("j_stat", title="Youden's index",size=1000), 164 | radar_plot_fun("mcc", title="MCC",size=1000), 165 | 166 | radar_plot_fun("f1", title="F1-Score (2000)",size=2000), 167 | radar_plot_fun("j_stat", title="Youden's index",size=2000), 168 | radar_plot_fun("mcc", title="MCC",size=2000)), 169 | layout_matrix = rbind(1:3, 4:6, 7:9)) 170 | 171 | ``` 172 | 173 | 174 | ### Plot the loss and LRAP against epochs/steps 175 | 176 | ```{r loss-log, fig.width=12, fig.height=10} 177 | 178 | loss <- read_csv("../finetuning/epochs.csv") %>% 179 | mutate(model = ifelse(str_detect(model, "rad"), "RAD-BERT", 180 | ifelse(str_detect(model, "ger"), "GER-BERT", 181 | ifelse(str_detect(model, "fs"), "FS-BERT", 182 | ifelse(str_detect(model, "multi"), "MULTI-BERT", model) 183 | ) 184 | ) 185 | ) 186 | ) 187 | 188 | # 500 - 2000 train-size, number of steps/epoch was smaller than 50. Steps represent Epoch-Number in the dataframe. 
189 | 190 | style <- list(theme(panel.grid = element_blank(), 191 | axis.line.x = element_line(size = 0.5, linetype = "solid", colour = "black"), 192 | axis.line.y = element_line(size = 0.5, linetype = "solid", colour = "black"), 193 | panel.border = element_blank(), 194 | panel.background = element_blank(), 195 | legend.title = element_blank(), 196 | legend.position = "bottom"), 197 | geom_line(size = 1), 198 | ylim(0.1, 0.7), 199 | scale_color_manual(values = c("#92c5de", "#ca0020", "#0571b0", "#f4a582"), 200 | breaks = c("GER-BERT", "MULTI-BERT", "RAD-BERT", "FS-BERT"))) 201 | 202 | p_loss <- ggplot(filter(loss, data == "full"), aes(x = step, y = eval_loss, color = model)) + 203 | style + 204 | scale_x_continuous(breaks = c(588, 1176, 1764, 2352, 2940, 3528, 4116, 4704, 5292, 5880), # step reached at epoch 205 | labels = 1:10) + 206 | labs(x="Epoch",y="Loss",title = "Loss on the test dataset (full dataset)", tag= "A") 207 | 208 | legend <- get_legend(p_loss) 209 | p_loss <- p_loss + theme(legend.position = 0) 210 | 211 | p_lrap <- ggplot(filter(loss, data == "full"), aes(x = step, y = lrap, color = model)) + 212 | style + 213 | scale_x_continuous(breaks = c(588, 1176, 1764, 2352, 2940, 3528, 4116, 4704, 5292, 5880), 214 | labels = 1:10) + ylim(0.5, 1) + 215 | labs(x="Epoch",y="LRAP",title = "LRAP on the test dataset (full dataset)", tag= "B") + theme(legend.position = 0) 216 | 217 | loss_500 <- ggplot(filter(loss, data == "500"), aes(x = step, y = eval_loss, color = model)) + 218 | style + theme(legend.position = 0) + scale_x_continuous(breaks = 1:10) + 219 | labs(x="Epoch",y="Loss",title = "Evaluation loss - train size 500", tag= "C") 220 | 221 | loss_1000 <- ggplot(filter(loss, data == "1000"), aes(x = step, y = eval_loss, color = model)) + 222 | style + theme(legend.position = 0) + scale_x_continuous(breaks = 1:10) + 223 | labs(x="Epoch",y="Loss",title = "Evaluation loss - train size 1000", tag= "D") 224 | 225 | loss_2000 <- ggplot(filter(loss, data == "2000"), aes(x = step, y = eval_loss, color = model)) + 226 | style + theme(legend.position = 0) + scale_x_continuous(breaks = 1:10) + 227 | labs(x="Epoch",y="Loss",title = "Evaluation loss - train size 2000", tag= "E") 228 | 229 | ggdraw( 230 | plot_grid( 231 | plot_grid(p_loss, p_lrap, ncol=2, align='v'), 232 | plot_grid(loss_500, loss_1000, loss_2000, ncol = 3, nrow = 1, rel_widths = c(1, 1, 1)), 233 | plot_grid(NA, legend, NA, ncol = 3, nrow = 1, rel_widths = c(0.1, 1, 0.1)), 234 | nrow = 3, 235 | rel_heights=c(1,0.7,0.2))) 236 | ``` 237 | 238 | ### Plot Accuray on different train-size 239 | 240 | ```{r f1, fig.width=12, fig.height=12} 241 | style <- list(theme(panel.grid = element_blank(), 242 | axis.line.x = element_line(size = 0.5, linetype = "solid", colour = "black"), 243 | axis.line.y = element_line(size = 0.5, linetype = "solid", colour = "black"), 244 | panel.border = element_blank(), 245 | panel.background = element_blank(), 246 | legend.title = element_blank(), 247 | legend.position = "bottom"), 248 | geom_line(size = 1), 249 | scale_color_manual(values = c("#ca0020", "#0571b0", "#f4a582", "#92c5de"), 250 | labels = c("GER-BERT", "MULTI-BERT", "RAD-BERT", "FS-BERT"))) 251 | 252 | 253 | ggplot(acc, aes(x=train_size, y=f1, color=model)) + 254 | style + 255 | facet_wrap(~finding) + 256 | xlim(100,4000) + 257 | ggtitle("F1 Score") 258 | ``` 259 | ```{r acc, fig.width=12, fig.height=12} 260 | ggplot(acc, aes(x=train_size, y=acc, color=model)) + 261 | style + 262 | facet_wrap(~finding) + 263 | xlim(100,4000) + 264 | 
ggtitle("Accuracy Score") 265 | ``` 266 | 267 | ```{r mcc, fig.height=12, fig.width=12} 268 | 269 | ggplot(acc, aes(x=train_size, y=mcc, color=model)) + 270 | style + 271 | facet_wrap(~finding) + 272 | xlim(100,4000) + 273 | ggtitle("MCC Score") 274 | 275 | ``` 276 | 277 | ```{r plot pooled, fig.width=12, fig.height=5} 278 | acc %>% 279 | select(train_size, model, f1, mcc, j_stat) %>% 280 | group_by(train_size, model) %>% 281 | nest() %>% 282 | mutate(data = map(data, mutate_if, is.numeric, mean)) %>% 283 | unnest() %>% 284 | ungroup() %>% 285 | filter(!duplicated(.)) %>% 286 | pivot_longer(-c(train_size, model)) %>% 287 | mutate(name = str_replace_all(name, 288 | c("f1|j_stat|mcc"), 289 | c("F1 Score", "Youden's statistik", "MCC"))) %>% 290 | ggplot(aes(x=train_size, y=value, color=model)) + 291 | style + 292 | facet_wrap(~name) + 293 | xlim(100,4500) + 294 | scale_y_continuous(breaks = 0:10/10) + 295 | labs(x="Size of train dataset", title = "Performance Gain on Increasing Train Data") + 296 | theme(axis.title.y = element_blank()) 297 | ``` 298 | 299 | ### ROC Curves and Precision-Recall Curves 300 | 301 | ```{r, fig.width=12, fig.height=12} 302 | get_values <- function(threshold, size, raw) { 303 | filter(raw, train_size == size) %>% 304 | get_perf_measurements(threshold) %>% 305 | select(c(finding, model, prec, rec, spec)) %>% 306 | mutate(thrshld = threshold) 307 | } 308 | 309 | auc <- function(sen, spe){ 310 | sen = sort(sen) 311 | spe = sort(1-spe) 312 | dsen <- c(diff(sen), 0) 313 | dspe <- c(diff(spe), 0) 314 | sum(sen * dspe) + sum(dsen * dspe)/2 315 | } 316 | 317 | 318 | lapply(0:15/15, get_values, size = 4703, raw=raw) %>% 319 | do.call(rbind, .) %>% 320 | replace_na(list(spec = 1, rec = 1)) %>% 321 | rbind(filter(., thrshld == 0) %>% mutate(prec = 0, spec = 0, rec = 1), # extrapolate Plotlines 322 | filter(., thrshld == 0) %>% mutate(prec = 1, spec = 1, rec = 0)) -> plot_df 323 | ``` 324 | 325 | ```{r roc, fig.width=12, fig.height=12} 326 | plot_fun <- function(FINDING, df, type="AUC") { 327 | df %<>% filter(finding == FINDING) 328 | 329 | add <- list(theme_cowplot(), 330 | theme(legend.position = 0), 331 | scale_color_manual(values = c("#ca0020", "#0571b0", "#f4a582", "#92c5de"), 332 | labels = c("GER-BERT", "MULTI-BERT", "RAD-BERT", "FS-BERT"))) 333 | 334 | if (type == "AUC") { 335 | main <- 336 | ggplot(df, aes(x = 1-spec, y = rec, color = model)) + 337 | labs(title = FINDING, y = "Sensitivity", x = "1-Specificty") + 338 | add} 339 | 340 | if (type == "AUPRC") { 341 | main <- 342 | ggplot(df, aes(x = 1-prec, y = rec, color = model)) + 343 | labs(title = FINDING, y = "Recall", x = "1-Precisison") + 344 | add} 345 | 346 | sub <- 347 | main + geom_step(direction = "vh", size = 2, alpha = 0.5) + 348 | coord_cartesian(ylim=c(0.75,1), xlim=c(0, 0.25)) + 349 | theme_nothing() 350 | 351 | main = main + 352 | geom_rect(xmin=-0.025,xmax=0.25,ymin=0.75,ymax=1.025, fill = "#efedf5", linetype = 1, color = "#000000") + 353 | geom_rect(xmin=0.3,xmax=1.02,ymin=-0.03,ymax=0.75, fill = "#efedf5", linetype = 1, color = "#000000") + 354 | geom_step(direction = "vh", size = 1, alpha = 0.5) 355 | 356 | 357 | ggdraw() + 358 | draw_plot(main) + 359 | draw_plot(sub, x = 0.45, y = 0.17, width = .5, height = .5) 360 | } 361 | 362 | grid.arrange( 363 | grobs = lapply(unique(plot_df$finding), plot_fun, plot_df, "AUC"), 364 | layout_matrix = matrix(1:9, 3), 365 | top=textGrob("Receiver Operating Characteristic Curves \n",gp=gpar(fontsize=20)) 366 | ) 367 | 368 | ``` 369 | ```{r prc, fig.width=12, 
fig.height=12} 370 | l <- lapply(unique(plot_df$finding)[1:8], plot_fun, plot_df, "AUPRC") 371 | l[[9]] <- plot_df %>% 372 | filter(finding == "Misplaced Medical Device") %>% 373 | ggplot(aes(x = 1-prec, y = rec, color = model)) + 374 | labs(title = "Misplaced Medical Device", y = "Recall", x = "1-Precisison") + 375 | theme_cowplot() + 376 | theme(legend.position = 0) + 377 | scale_color_manual(values = c("#ca0020", "#0571b0", "#f4a582", "#92c5de"), 378 | labels = c("GER-BERT", "MULTI-BERT", "RAD-BERT", "FS-BERT")) + 379 | geom_step(direction = "vh", size = 1, alpha = 0.5) 380 | 381 | grid.arrange( 382 | grobs = l, 383 | layout_matrix = matrix(1:9, 3), 384 | top=textGrob("Precision-Recall Curves\n",gp=gpar(fontsize=20)) 385 | ) 386 | 387 | ``` 388 | 389 | ## 3. Tables 390 | ### AUC 391 | ```{r} 392 | writeLines("train_size,Model,Finding,AUPRC,AUC", "auc.csv") 393 | 394 | loop_auc <- function(df, train_size, out_file) { 395 | for (i in unique(df$model)) { 396 | for (j in unique(df$finding)) { 397 | 398 | x = filter(df, model == i & finding == j) %>% select(rec) %>% unlist() 399 | y = filter(df, model == i & finding == j) %>% select(prec) %>% unlist() 400 | z = filter(df, model == i & finding == j) %>% select(spec) %>% unlist() 401 | 402 | 403 | paste(train_size, 404 | i, 405 | j, 406 | auc(x, y) %>% round(2), 407 | auc(x, z) %>% round(2), 408 | sep=",") %>% paste("\n", sep="") %>% cat(file=out_file, append = T) 409 | } 410 | } 411 | } 412 | 413 | for (size in c(200, 400, 600, 800, 1000, 1500, 2000, 4000, 4703)) { 414 | 415 | lapply(0:15/15, get_values, size = size, raw=raw) %>% 416 | do.call(rbind, .) %>% 417 | replace_na(list(spec = 1, rec = 1)) %>% 418 | rbind(filter(., thrshld == 0) %>% mutate(prec = 0, spec = 0, rec = 1), # extrapolate Plotlines 419 | filter(., thrshld == 0) %>% mutate(prec = 1, spec = 1, rec = 0)) %>% 420 | loop_auc(size, out_file="auc.csv") 421 | } 422 | 423 | 424 | read_csv("auc.csv") 425 | ``` 426 | ## 4. Performance on long texts 427 | 428 | ```{r prc long, fig.width=6, fig.height=6} 429 | test <- read_csv("../data/ct.csv") 430 | raw <-read_and_format_data("../finetuning/results-long-text.csv") 431 | acc <- get_perf_measurements(raw, 0) %>% replace(is.na(.) | . < 0, 0) 432 | 433 | radar_plot_fun("mcc", title="mcc",size=4000) + 434 | theme(legend.position = "bottom") + 435 | scale_color_manual(values = c("#92c5de", "#ca0020", "#0571b0", "#f4a582"), 436 | labels = c("GER-BERT", "MULTI-BERT", "RAD-BERT", "FS-BERT"), 437 | breaks = c("gerbert", "multibert", "radbert", "fsbert")) + 438 | ggtitle("MCC for CT-Reports") 439 | ``` 440 | 441 | ```{r} 442 | writeLines("train_size,Model,Finding,AUPRC,AUC", "auc-for-ct.csv") 443 | 444 | lapply(0:15/15, get_values, size = 4000, raw=raw) %>% 445 | do.call(rbind, .) %>% 446 | replace_na(list(spec = 1, rec = 1)) %>% 447 | rbind(filter(., thrshld == 0) %>% mutate(prec = 0, spec = 0, rec = 1), # extrapolate Plotlines 448 | filter(., thrshld == 0) %>% mutate(prec = 1, spec = 1, rec = 0)) %>% 449 | loop_auc(size,out_file = "auc-for-ct.csv") 450 | 451 | read_csv("auc-for-ct.csv") 452 | ``` 453 | 454 | 455 | 456 | 457 | 458 | --------------------------------------------------------------------------------
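As a quick, self-contained sanity check of the trapezoidal `auc()` helper defined in `accuracy.Rmd`, the sketch below reproduces that function and evaluates it on toy sensitivity/specificity vectors (invented for illustration, not taken from the study data): a perfect classifier should score close to 1 and a chance-level one 0.5.

```r
# Stand-alone copy of the trapezoidal AUC helper from accuracy.Rmd,
# evaluated on toy values to illustrate its behaviour.
auc <- function(sen, spe) {
  sen  <- sort(sen)       # sensitivity, ascending
  spe  <- sort(1 - spe)   # 1 - specificity, ascending
  dsen <- c(diff(sen), 0)
  dspe <- c(diff(spe), 0)
  sum(sen * dspe) + sum(dsen * dspe) / 2
}

auc(sen = c(1, 1, 1),   spe = c(0, 0.5, 1))  # perfect classifier -> 1
auc(sen = c(0, 0.5, 1), spe = c(1, 0.5, 0))  # chance-level classifier -> 0.5
```

Because the helper sorts sensitivity and 1 - specificity independently, it implicitly assumes both change monotonically with the threshold, which should hold for the `0:15/15` threshold sweep used in the Rmd.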