├── data ├── transformer ├── __init__.py ├── 05_evaluate.sh ├── 03_train_full.sh ├── 04_train_contexts.sh ├── 07_official_metrics.sh ├── define_problem.py ├── alternative │ └── define_problem.py ├── 06_decode.sh ├── 02_count_subwords.ipynb └── 01_generate_data.ipynb ├── transformer_moe ├── __init__.py ├── 04_evaluate.sh ├── 03_train_full.sh ├── 06_official_metrics.sh ├── define_problem.py ├── alternative │ └── define_problem.py ├── 05_decode.sh ├── 02_count_subwords.ipynb └── 01_generate_data.ipynb ├── gpt2 ├── 02_encode.sh ├── 03_train.sh ├── README.md ├── 04_decode.sh ├── 05_official_metrics.sh ├── generate_conditional_samples.py └── 01_preprocess.ipynb ├── transformer-xl ├── 02_train.sh ├── 01_preprocess.ipynb ├── run_lm_finetuning.py └── 03_decode.ipynb ├── README.md ├── ROUGE.md ├── low_resource └── 01_format.ipynb ├── environment.yml ├── analysis ├── scispacy.ipynb └── discharge_summary.ipynb └── preprocess └── 04_context_data_split.ipynb /data: -------------------------------------------------------------------------------- 1 | /mimic/data/ -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import define_problem 2 | -------------------------------------------------------------------------------- /transformer_moe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import define_problem 2 | -------------------------------------------------------------------------------- /gpt2/02_encode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA=$HOME/project/text-generation/data/gpt2/low_resource 4 | 5 | cd ./gpt-2 6 | 7 | PYTHONPATH=src ./encode.py $DATA/input-text.txt $DATA/input-text.txt.npz 8 | 9 | PYTHONPATH=src ./encode.py $DATA/val-input-text.txt $DATA/val-input-text.txt.npz -------------------------------------------------------------------------------- /gpt2/03_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | DATA=$HOME/project/text-generation/data/gpt2/low_resource 6 | 7 | cd ./gpt-2 8 | 9 | PYTHONPATH=src ./train.py --dataset $DATA/input-text.txt.npz --val_dataset $DATA/val-input-text.txt.npz --sample_every=10000 --save_every=10000 --optimizer='adam' --model_name=117M --val_every=1000 --batch_size=2 10 | -------------------------------------------------------------------------------- /gpt2/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | First, clone the gpt-2 repo and download the small model: 4 | 5 | ``` 6 | git clone https://github.com/nshepperd/gpt-2 7 | cd gpt-2 8 | python download_model.py 117M 9 | ``` 10 | 11 | Back in the directory with the scripts: 12 | 13 | 1. Run Preprocessing Notebook 14 | 2. Run encoding bash script: `./02_encode.sh` 15 | 3. 
Run training bash script: `nohup ./03_train.sh > train.out &` -------------------------------------------------------------------------------- /transformer-xl/02_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | DATA_DIR=$HOME/project/text-generation/data/transformer-xl/low_resource 6 | OUT_DIR=$DATA_DIR/output 7 | 8 | mkdir $OUT_DIR 9 | 10 | python run_lm_finetuning.py --train_data_file $DATA_DIR/input-text.txt --output_dir $OUT_DIR --model_type 'transfo-xl-wt103' --model_name_or_path 'transfo-xl-wt103' --do_train --block_size 512 --per_gpu_train_batch_size 1 --num_train_epochs 4 --logging_steps 250 --save_steps 1000 11 | -------------------------------------------------------------------------------- /gpt2/04_decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | cd ./gpt-2 6 | 7 | MODEL='run1' 8 | DATA_DIR=../../data/gpt2/low_resource 9 | INPUT_FILE=$DATA_DIR/test-input-text.txt 10 | OUTPUT_FILE=$DATA_DIR/test-output-text.txt 11 | TOP_K=0 12 | TEMP=1 13 | BS=1 14 | 15 | ln -sf ../../models/117M/encoder.json checkpoint/$MODEL/ 16 | ln -sf ../../models/117M/hparams.json checkpoint/$MODEL/ 17 | ln -sf ../../models/117M/vocab.bpe checkpoint/$MODEL/ 18 | 19 | ln -sf ../checkpoint/$MODEL/ models/ 20 | 21 | python src/generate_conditional_samples.py --input_file $INPUT_FILE --output_file $OUTPUT_FILE --model_name $MODEL --top_k $TOP_K --temperature $TEMP --batch_size $BS 22 | -------------------------------------------------------------------------------- /transformer_moe/04_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | BASE=$HOME/project/text-generation # change this as necessary 6 | 7 | PROBLEM=mimic_discharge_summaries 8 | MODEL=transformer_moe 9 | HPARAMS=transformer_moe_base 10 | DATA_DIR=$BASE/data/t2t_experiments/transformer_moe/full_context/data 11 | TRAIN_DIR=$BASE/data/t2t_experiments/transformer_moe/full_context/output 12 | USR_DIR=$BASE/transformer_moe 13 | 14 | t2t-trainer \ 15 | --t2t_usr_dir=$USR_DIR \ 16 | --problem=$PROBLEM \ 17 | --model=$MODEL \ 18 | --hparams_set=$HPARAMS \ 19 | --data_dir=$DATA_DIR \ 20 | --output_dir=$TRAIN_DIR \ 21 | --eval_use_test_set=True \ 22 | --eval_steps=1000 \ 23 | --schedule=evaluate \ 24 | --worker_gpu=4 25 | -------------------------------------------------------------------------------- /transformer/05_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | BASE=$HOME/project/text-generation # change this as necessary 6 | 7 | PROBLEM=mimic_discharge_summaries 8 | MODEL=transformer 9 | HPARAMS=transformer_base 10 | DATA_DIR=$BASE/data/t2t_experiments/transformer/low_resource/full_context/data 11 | TRAIN_DIR=$BASE/data/t2t_experiments/transformer/low_resource/full_context/output 12 | USR_DIR=$BASE/transformer 13 | 14 | t2t-trainer \ 15 | --t2t_usr_dir=$USR_DIR \ 16 | --problem=$PROBLEM \ 17 | --model=$MODEL \ 18 | --hparams_set=$HPARAMS \ 19 | --data_dir=$DATA_DIR \ 20 | --output_dir=$TRAIN_DIR \ 21 | --eval_use_test_set=True \ 22 | --eval_steps=100 \ 23 | --schedule=evaluate \ 24 | --worker_gpu=4 25 | -------------------------------------------------------------------------------- /transformer_moe/03_train_full.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | BASE=$HOME/project/text-generation # change this as necessary 6 | 7 | # transformer DMCA 8 | MODEL=transformer_moe 9 | HPARAMS=transformer_moe_base 10 | HPARAMS_OVERRIDE="max_length=10000,max_target_seq_length=512,max_input_seq_length=512" 11 | 12 | PROBLEM=mimic_discharge_summaries 13 | DATA_DIR=$BASE/data/t2t_experiments/transformer_moe/full_context/data 14 | TRAIN_DIR=$BASE/data/t2t_experiments/transformer_moe/full_context/output 15 | USR_DIR=$BASE/transformer_moe 16 | 17 | t2t-trainer \ 18 | --data_dir=$DATA_DIR \ 19 | --problem=$PROBLEM \ 20 | --model=$MODEL \ 21 | --hparams_set=$HPARAMS \ 22 | --hparams=$HPARAMS_OVERRIDE \ 23 | --output_dir=$TRAIN_DIR \ 24 | --t2t_usr_dir=$USR_DIR \ 25 | --eval_steps=100 \ 26 | --local_eval_frequency=200 \ 27 | --train_steps=5000 \ 28 | --worker_gpu=4 29 | -------------------------------------------------------------------------------- /transformer/03_train_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | BASE=$HOME/project/text-generation # change this as necessary 6 | 7 | # transformer 8 | MODEL=transformer 9 | HPARAMS=transformer_base 10 | HPARAMS_OVERRIDE="max_length=10000,max_target_seq_length=512,max_input_seq_length=512" 11 | 12 | PROBLEM=mimic_discharge_summaries 13 | DATA_DIR=$BASE/data/t2t_experiments/$MODEL/low_resource/full_context/data 14 | TRAIN_DIR=$BASE/data/t2t_experiments/$MODEL/low_resource/full_context/output 15 | USR_DIR=$BASE/transformer 16 | 17 | t2t-trainer \ 18 | --data_dir=$DATA_DIR \ 19 | --problem=$PROBLEM \ 20 | --model=$MODEL \ 21 | --hparams_set=$HPARAMS \ 22 | --hparams=$HPARAMS_OVERRIDE \ 23 | --output_dir=$TRAIN_DIR \ 24 | --t2t_usr_dir=$USR_DIR \ 25 | --eval_steps=50 \ 26 | --local_eval_frequency=200 \ 27 | --train_steps=4000 \ 28 | --worker_gpu=4 \ 29 | --warm_start_from=$TRAIN_DIR 30 | -------------------------------------------------------------------------------- /transformer/04_train_contexts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | BASE=$HOME/project # change this as necessary 5 | 6 | PROBLEM=mimic_discharge_summaries 7 | MODEL=transformer 8 | HPARAMS=transformer_base 9 | USR_DIR=$BASE/t2t 10 | 11 | for i in {h,h-gae,h-gae-d,h-gae-p,h-gae-d-p,h-gae-d-p-m,h-gae-d-p-m-t,h-gae-d-p-m-l} 12 | do 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | DATA_DIR=$BASE/data/t2t_experiments/other_contexts/$i/data 15 | TRAIN_DIR=$BASE/data/t2t_experiments/other_contexts/$i/output 16 | 17 | t2t-trainer \ 18 | --data_dir=$DATA_DIR \ 19 | --problem=$PROBLEM \ 20 | --model=$MODEL \ 21 | --hparams_set=$HPARAMS \ 22 | --hparams="max_length=10000,max_target_seq_length=512,max_input_seq_length=512" \ 23 | --output_dir=$TRAIN_DIR \ 24 | --t2t_usr_dir=$USR_DIR \ 25 | --train_steps=5000 \ 26 | --eval_steps=50 \ 27 | --worker_gpu=4 28 | # --warm_start_from=$TRAIN_DIR 29 | done -------------------------------------------------------------------------------- /gpt2/05_official_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE=$HOME/project/text-generation # change this as necessary 4 | PREPROC=$BASE/data/preprocessed/low_resource 5 | TRAIN_DIR=$BASE/data/gpt2/low_resource 6 | 7 | mkdir $PREPROC/gold 8 | ln -s 
$PREPROC/tgt-test.txt $PREPROC/gold/tgt-test.A.001.txt 9 | GOLD=$PREPROC/gold/tgt-test.A.001.txt 10 | 11 | mkdir $TRAIN_DIR/gpt2_decoded 12 | ln -s $TRAIN_DIR/test-output-text.txt $TRAIN_DIR/gpt2_decoded/tgt-test.001.txt 13 | PREDICTION=$TRAIN_DIR/gpt2_decoded/tgt-test.001.txt 14 | 15 | wc -l $PREDICTION 16 | wc -l $GOLD 17 | 18 | # Evaluate the official ROUGE score 19 | # Note: Report this ROUGE score in papers, not the internal approx_rouge metric. 20 | 21 | pyrouge_evaluate_plain_text_files -s $TRAIN_DIR/gpt2_decoded/ -sfp "tgt-test.(\d+).txt" -m $PREPROC/gold/ -mfp tgt-test.[A-Z].#ID#.txt > $TRAIN_DIR/gpt2_decoded/rouge.txt 22 | 23 | # Evaluate the official BLEU score 24 | # Note: Report this BLEU score in papers, not the internal approx_bleu metric. 25 | 26 | t2t-bleu --translation=$PREDICTION --reference=$GOLD > $TRAIN_DIR/gpt2_decoded/bleu.txt 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mimic-text-generation 2 | 3 | This repository contains the code to recreate all our text generation models. This is part of my Master's thesis researching transformer-based Natural Language Generation methods for NLP data augmentation, specifically in the medical field. We use the MIMIC-III dataset. 4 | 5 | Follow the steps below: 6 | 7 | 8 | 1. Install the environment 9 | ``` 10 | conda env create -f environment.yml 11 | ``` 12 | 2. Run the notebooks in the `preprocess` directory sequentially. This assumes you have already downloaded the MIMIC-III dataset and configured it into a Postgres database. If not, follow the instructions on https://mimic.physionet.org 13 | 3. Ensure you install ROUGE separately. Instructions are provided in `ROUGE.md` 14 | 4. For each model (e.g. `transformer`), enter the directory and run the notebooks or Python files sequentially. These directories contain all the necessary files for further preprocessing, training, decoding and evaluating each individual model. Some files may need one or two variables changed to switch between the low-resource setting and the full-resource setting 15 | 5. Voilà 16 | -------------------------------------------------------------------------------- /transformer_moe/06_official_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE=$HOME/project/text-generation # change this as necessary 4 | PREPROC=$BASE/data/preprocessed 5 | TRAIN_DIR=$BASE/data/t2t_experiments/full_context/output 6 | 7 | mkdir $PREPROC/gold 8 | ln -s $PREPROC/tgt-test.txt $PREPROC/gold/tgt-test.A.001.txt 9 | GOLD=$PREPROC/gold/tgt-test.A.001.txt 10 | mkdir $TRAIN_DIR/transformer_decoded 11 | PREDICTION=$TRAIN_DIR/transformer_decoded/tgt-test.001.txt 12 | 13 | rm $PREPROC/temp_* 14 | cat $TRAIN_DIR/tgt-decoded-0.txt $TRAIN_DIR/tgt-decoded-1.txt $TRAIN_DIR/tgt-decoded-2.txt $TRAIN_DIR/tgt-decoded-3.txt > $PREDICTION 15 | rm $TRAIN_DIR/tgt-decoded* 16 | 17 | # Remove blank lines 18 | 19 | wc -l $PREDICTION 20 | sed -i '/^$/d' $PREDICTION 21 | wc -l $PREDICTION 22 | wc -l $GOLD 23 | 24 | # Evaluate the official ROUGE score 25 | # Note: Report this ROUGE score in papers, not the internal approx_rouge metric. 
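# (pyrouge pairs each system file with its gold reference by filename: the capture group in -sfp fills the #ID# placeholder in -mfp, which is why the files above must be named tgt-test.001.txt and tgt-test.A.001.txt.)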
26 | 27 | pyrouge_evaluate_plain_text_files -s $TRAIN_DIR/transformer_decoded/ -sfp "tgt-test.(\d+).txt" -m $PREPROC/gold/ -mfp tgt-test.[A-Z].#ID#.txt > $TRAIN_DIR/transformer_decoded/rouge.txt 28 | 29 | # Evaluate the official BLEU score 30 | # Note: Report this BLEU score in papers, not the internal approx_bleu metric. 31 | 32 | t2t-bleu --translation=$PREDICTION --reference=$GOLD > $TRAIN_DIR/transformer_decoded/bleu.txt -------------------------------------------------------------------------------- /ROUGE.md: -------------------------------------------------------------------------------- 1 | ### ROUGE installation 2 | 3 | First, we install ROUGE-1.5.5 - the definitive ROUGE implementation: 4 | ``` 5 | sudo apt install subversion 6 | svn checkout https://github.com/andersjo/pyrouge/trunk/tools/ROUGE-1.5.5 7 | sudo cpan App::cpanminus 8 | sudo cpanm XML::DOM 9 | ROUGE_EVAL_HOME=/absolute/path/to/ROUGE-1.5.5/data/ 10 | export ROUGE_EVAL_HOME 11 | ``` 12 | 13 | Next, we set up the easy-to-use `pyrouge` Python wrapper for ROUGE-1.5.5 and test it to ensure it is working properly: 14 | ``` 15 | git clone https://github.com/bheinzerling/pyrouge 16 | cd pyrouge 17 | python setup.py install 18 | pyrouge_set_rouge_path /absolute/path/to/ROUGE-1.5.5/ 19 | python -m pyrouge.test 20 | ``` 21 | 22 | If the test passes, you should see something like: 23 | 24 | ``` 25 | Ran 11 tests in 6.322s 26 | OK 27 | ``` 28 | 29 | A common occurrence is that it fails with the `"Cannot open exception db file for reading: data/WordNet-2.0.exc.db"` error message. If this is the case, follow the instructions below: 30 | ``` 31 | cd /absolute/path/to/ROUGE-1.5.5/ 32 | 33 | cd data/WordNet-2.0-Exceptions/ 34 | rm WordNet-2.0.exc.db # only if it exists 35 | ./buildExeptionDB.pl . exc WordNet-2.0.exc.db 36 | 37 | cd ../ 38 | rm WordNet-2.0.exc.db # only if it exists 39 | ln -s WordNet-2.0-Exceptions/WordNet-2.0.exc.db WordNet-2.0.exc.db 40 | ``` 41 | -------------------------------------------------------------------------------- /transformer/07_official_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE=$HOME/project/text-generation # change this as necessary 4 | PREPROC=$BASE/data/preprocessed/low_resource 5 | TRAIN_DIR=$BASE/data/t2t_experiments/transformer/low_resource/full_context/output 6 | 7 | mkdir $PREPROC/gold 8 | ln -s $PREPROC/tgt-test.txt $PREPROC/gold/tgt-test.A.001.txt 9 | GOLD=$PREPROC/gold/tgt-test.A.001.txt 10 | mkdir $TRAIN_DIR/transformer_decoded 11 | PREDICTION=$TRAIN_DIR/transformer_decoded/tgt-test.001.txt 12 | 13 | rm $PREPROC/temp_* 14 | cat $TRAIN_DIR/tgt-decoded-0.txt $TRAIN_DIR/tgt-decoded-1.txt $TRAIN_DIR/tgt-decoded-2.txt $TRAIN_DIR/tgt-decoded-3.txt > $PREDICTION 15 | #rm $TRAIN_DIR/tgt-decoded* 16 | 17 | # Remove blank lines 18 | 19 | wc -l $PREDICTION 20 | sed -i '/^$/d' $PREDICTION 21 | wc -l $PREDICTION 22 | wc -l $GOLD 23 | 24 | # Evaluate the official ROUGE score 25 | # Note: Report this ROUGE score in papers, not the internal approx_rouge metric. 26 | 27 | pyrouge_evaluate_plain_text_files -s $TRAIN_DIR/transformer_decoded/ -sfp "tgt-test.(\d+).txt" -m $PREPROC/gold/ -mfp tgt-test.[A-Z].#ID#.txt > $TRAIN_DIR/transformer_decoded/rouge.txt 28 | 29 | # Evaluate the official BLEU score 30 | # Note: Report this BLEU score in papers, not the internal approx_bleu metric. 
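# (t2t-bleu reports both cased and uncased BLEU over whole files and expects the translation and reference to have the same number of lines - hence the blank-line stripping and wc -l checks above.)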
31 | 32 | t2t-bleu --translation=$PREDICTION --reference=$GOLD > $TRAIN_DIR/transformer_decoded/bleu.txt 33 | -------------------------------------------------------------------------------- /transformer/define_problem.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #tf.enable_eager_execution() 3 | from tensor2tensor.utils import trainer_lib 4 | RANDOM_SEED = 301 5 | trainer_lib.set_random_seed(RANDOM_SEED) 6 | from tensor2tensor.data_generators import problem 7 | from tensor2tensor.data_generators import text_problems 8 | from tensor2tensor.utils import registry 9 | 10 | 11 | @registry.register_problem 12 | 13 | class MimicDischargeSummaries(text_problems.Text2TextProblem): 14 | 15 | @property 16 | def is_generate_per_split(self): 17 | # our data already has pre-existing splits 18 | return True 19 | 20 | def generate_samples(self, data_dir, tmp_dir, dataset_split): 21 | 22 | del tmp_dir 23 | 24 | directory = "data/preprocessed/" 25 | 26 | _train = (dataset_split == problem.DatasetSplit.TRAIN) 27 | _eval = (dataset_split == problem.DatasetSplit.EVAL) 28 | 29 | dataset = "train" if _train else "val" if _eval else "test" 30 | 31 | src = directory + "src-" + dataset + ".txt" 32 | tgt = directory + "tgt-" + dataset + ".txt" 33 | 34 | f_src = open(src,'r') 35 | f_tgt = open(tgt,'r') 36 | 37 | context_data = f_src.readline() 38 | discharge_summary = f_tgt.readline() 39 | 40 | while context_data: 41 | yield { 42 | "inputs" : context_data, 43 | "targets" : discharge_summary, 44 | } 45 | 46 | context_data = f_src.readline() 47 | discharge_summary = f_tgt.readline() 48 | 49 | f_src.close() 50 | f_tgt.close() 51 | 52 | @property 53 | def vocab_type(self): 54 | # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good 55 | # tradeoff between CHARACTER and TOKEN. 56 | return text_problems.VocabType.SUBWORD 57 | 58 | @property 59 | def approx_vocab_size(self): 60 | # Approximate vocab size to generate. Only for VocabType.SUBWORD. 
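# (tensor2tensor grows the SubwordTextEncoder from the training corpus toward this target, so the realized vocab size is approximate rather than exact.)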
61 | return 2**15 # ~32k - this is the default setting 62 | 63 | @property 64 | def dataset_splits(self): 65 | return [{ 66 | "split": problem.DatasetSplit.TRAIN, 67 | "shards": 80 68 | }, { 69 | "split": problem.DatasetSplit.EVAL, 70 | "shards": 10 71 | }, { 72 | "split": problem.DatasetSplit.TEST, 73 | "shards": 10 74 | }] 75 | -------------------------------------------------------------------------------- /transformer_moe/define_problem.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #tf.enable_eager_execution() 3 | from tensor2tensor.utils import trainer_lib 4 | RANDOM_SEED = 301 5 | trainer_lib.set_random_seed(RANDOM_SEED) 6 | from tensor2tensor.data_generators import problem 7 | from tensor2tensor.data_generators import text_problems 8 | from tensor2tensor.utils import registry 9 | 10 | 11 | @registry.register_problem 12 | 13 | class MimicDischargeSummaries(text_problems.Text2TextProblem): 14 | 15 | @property 16 | def is_generate_per_split(self): 17 | # our data already has pre-existing splits 18 | return True 19 | 20 | def generate_samples(self, data_dir, tmp_dir, dataset_split): 21 | 22 | del tmp_dir 23 | 24 | directory = "data/preprocessed/" 25 | 26 | _train = (dataset_split == problem.DatasetSplit.TRAIN) 27 | _eval = (dataset_split == problem.DatasetSplit.EVAL) 28 | 29 | dataset = "train" if _train else "val" if _eval else "test" 30 | 31 | src = directory + "src-" + dataset + ".txt" 32 | tgt = directory + "tgt-" + dataset + ".txt" 33 | 34 | f_src = open(src,'r') 35 | f_tgt = open(tgt,'r') 36 | 37 | context_data = f_src.readline() 38 | discharge_summary = f_tgt.readline() 39 | 40 | while context_data: 41 | yield { 42 | "inputs" : context_data, 43 | "targets" : discharge_summary, 44 | } 45 | 46 | context_data = f_src.readline() 47 | discharge_summary = f_tgt.readline() 48 | 49 | f_src.close() 50 | f_tgt.close() 51 | 52 | @property 53 | def vocab_type(self): 54 | # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good 55 | # tradeoff between CHARACTER and TOKEN. 56 | return text_problems.VocabType.SUBWORD 57 | 58 | @property 59 | def approx_vocab_size(self): 60 | # Approximate vocab size to generate. Only for VocabType.SUBWORD. 
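# (tensor2tensor grows the SubwordTextEncoder from the training corpus toward this target, so the realized vocab size is approximate rather than exact.)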
61 | return 2**15 # ~32k - this is the default setting 62 | 63 | @property 64 | def dataset_splits(self): 65 | return [{ 66 | "split": problem.DatasetSplit.TRAIN, 67 | "shards": 80 68 | }, { 69 | "split": problem.DatasetSplit.EVAL, 70 | "shards": 10 71 | }, { 72 | "split": problem.DatasetSplit.TEST, 73 | "shards": 10 74 | }] 75 | -------------------------------------------------------------------------------- /transformer/alternative/define_problem.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #tf.enable_eager_execution() 3 | from tensor2tensor.utils import trainer_lib 4 | RANDOM_SEED = 301 5 | trainer_lib.set_random_seed(RANDOM_SEED) 6 | from tensor2tensor.data_generators import problem 7 | from tensor2tensor.data_generators import text_problems 8 | from tensor2tensor.utils import registry 9 | 10 | 11 | @registry.register_problem 12 | 13 | class MimicDischargeSummaries(text_problems.Text2TextProblem): 14 | 15 | @property 16 | def is_generate_per_split(self): 17 | # our data already has pre-existing splits 18 | return True 19 | 20 | def generate_samples(self, data_dir, tmp_dir, dataset_split): 21 | 22 | del tmp_dir 23 | 24 | directory = "data/preprocessed/" 25 | 26 | _train = (dataset_split == problem.DatasetSplit.TRAIN) 27 | _eval = (dataset_split == problem.DatasetSplit.EVAL) 28 | 29 | dataset = "train" if _train else "val" if _eval else "test" 30 | 31 | src = directory + "src-" + dataset + ".txt" 32 | tgt = directory + "tgt-" + dataset + ".txt" 33 | 34 | f_src = open(src,'r') 35 | f_tgt = open(tgt,'r') 36 | 37 | context_data = f_src.readline() 38 | discharge_summary = f_tgt.readline() 39 | 40 | while context_data: 41 | yield { 42 | "inputs" : context_data, 43 | "targets" : discharge_summary, 44 | } 45 | 46 | context_data = f_src.readline() 47 | discharge_summary = f_tgt.readline() 48 | 49 | f_src.close() 50 | f_tgt.close() 51 | 52 | @property 53 | def vocab_type(self): 54 | # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good 55 | # tradeoff between CHARACTER and TOKEN. 56 | return text_problems.VocabType.SUBWORD 57 | 58 | @property 59 | def approx_vocab_size(self): 60 | # Approximate vocab size to generate. Only for VocabType.SUBWORD. 
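# (tensor2tensor grows the SubwordTextEncoder from the training corpus toward this target, so the realized vocab size is approximate rather than exact.)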
61 | return 2**15 # ~32k - this is the default setting 62 | 63 | @property 64 | def dataset_splits(self): 65 | return [{ 66 | "split": problem.DatasetSplit.TRAIN, 67 | "shards": 80 68 | }, { 69 | "split": problem.DatasetSplit.EVAL, 70 | "shards": 10 71 | }, { 72 | "split": problem.DatasetSplit.TEST, 73 | "shards": 10 74 | }] 75 | -------------------------------------------------------------------------------- /transformer_moe/alternative/define_problem.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #tf.enable_eager_execution() 3 | from tensor2tensor.utils import trainer_lib 4 | RANDOM_SEED = 301 5 | trainer_lib.set_random_seed(RANDOM_SEED) 6 | from tensor2tensor.data_generators import problem 7 | from tensor2tensor.data_generators import text_problems 8 | from tensor2tensor.utils import registry 9 | 10 | 11 | @registry.register_problem 12 | 13 | class MimicDischargeSummaries(text_problems.Text2TextProblem): 14 | 15 | @property 16 | def is_generate_per_split(self): 17 | # our data already has pre-existing splits 18 | return True 19 | 20 | def generate_samples(self, data_dir, tmp_dir, dataset_split): 21 | 22 | del tmp_dir 23 | 24 | directory = "data/preprocessed/" 25 | 26 | _train = (dataset_split == problem.DatasetSplit.TRAIN) 27 | _eval = (dataset_split == problem.DatasetSplit.EVAL) 28 | 29 | dataset = "train" if _train else "val" if _eval else "test" 30 | 31 | src = directory + "src-" + dataset + ".txt" 32 | tgt = directory + "tgt-" + dataset + ".txt" 33 | 34 | f_src = open(src,'r') 35 | f_tgt = open(tgt,'r') 36 | 37 | context_data = f_src.readline() 38 | discharge_summary = f_tgt.readline() 39 | 40 | while context_data: 41 | yield { 42 | "inputs" : context_data, 43 | "targets" : discharge_summary, 44 | } 45 | 46 | context_data = f_src.readline() 47 | discharge_summary = f_tgt.readline() 48 | 49 | f_src.close() 50 | f_tgt.close() 51 | 52 | @property 53 | def vocab_type(self): 54 | # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good 55 | # tradeoff between CHARACTER and TOKEN. 56 | return text_problems.VocabType.SUBWORD 57 | 58 | @property 59 | def approx_vocab_size(self): 60 | # Approximate vocab size to generate. Only for VocabType.SUBWORD. 
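# (tensor2tensor grows the SubwordTextEncoder from the training corpus toward this target, so the realized vocab size is approximate rather than exact.)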
61 | return 2**15 # ~32k - this is the default setting 62 | 63 | @property 64 | def dataset_splits(self): 65 | return [{ 66 | "split": problem.DatasetSplit.TRAIN, 67 | "shards": 80 68 | }, { 69 | "split": problem.DatasetSplit.EVAL, 70 | "shards": 10 71 | }, { 72 | "split": problem.DatasetSplit.TEST, 73 | "shards": 10 74 | }] 75 | -------------------------------------------------------------------------------- /transformer_moe/05_decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE=$HOME/project/text-generation # change this as necessary 4 | PREPROC=$BASE/data/preprocessed 5 | TEST_FILE=$PREPROC/src-test.txt 6 | 7 | split -n l/4 $TEST_FILE $PREPROC/temp_ 8 | for f in $PREPROC/temp_*; do mv $f $PREPROC/`basename $f `.txt; done; 9 | 10 | PROBLEM=mimic_discharge_summaries 11 | MODEL=transformer_moe 12 | HPARAMS=transformer_moe_base 13 | DATA_DIR=$BASE/data/t2t_experiments/transformer_moe/full_context/data 14 | TRAIN_DIR=$BASE/data/t2t_experiments/transformer_moe/full_context/output 15 | USR_DIR=$BASE/transformer_moe 16 | 17 | DECODE_FILE_0=$PREPROC/temp_aa.txt 18 | DECODE_FILE_1=$PREPROC/temp_ab.txt 19 | DECODE_FILE_2=$PREPROC/temp_ac.txt 20 | DECODE_FILE_3=$PREPROC/temp_ad.txt 21 | 22 | BEAM_SIZE=3 23 | ALPHA=0.6 24 | DBS=1 25 | EXTRA_LEN=50 26 | HPARAMS_OVERRIDE="" 27 | 28 | CUDA_VISIBLE_DEVICES=0 t2t-decoder \ 29 | --t2t_usr_dir=$USR_DIR \ 30 | --data_dir=$DATA_DIR \ 31 | --problem=$PROBLEM \ 32 | --model=$MODEL \ 33 | --hparams_set=$HPARAMS \ 34 | --hparams=$HPARAMS_OVERRIDE \ 35 | --output_dir=$TRAIN_DIR \ 36 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 37 | --decode_from_file=$DECODE_FILE_0 \ 38 | --decode_to_file=$TRAIN_DIR/tgt-decoded-0.txt & 39 | 40 | CUDA_VISIBLE_DEVICES=1 t2t-decoder \ 41 | --t2t_usr_dir=$USR_DIR \ 42 | --data_dir=$DATA_DIR \ 43 | --problem=$PROBLEM \ 44 | --model=$MODEL \ 45 | --hparams_set=$HPARAMS \ 46 | --hparams=$HPARAMS_OVERRIDE \ 47 | --output_dir=$TRAIN_DIR \ 48 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 49 | --decode_from_file=$DECODE_FILE_1 \ 50 | --decode_to_file=$TRAIN_DIR/tgt-decoded-1.txt & 51 | 52 | CUDA_VISIBLE_DEVICES=2 t2t-decoder \ 53 | --t2t_usr_dir=$USR_DIR \ 54 | --data_dir=$DATA_DIR \ 55 | --problem=$PROBLEM \ 56 | --model=$MODEL \ 57 | --hparams_set=$HPARAMS \ 58 | --hparams=$HPARAMS_OVERRIDE \ 59 | --output_dir=$TRAIN_DIR \ 60 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 61 | --decode_from_file=$DECODE_FILE_2 \ 62 | --decode_to_file=$TRAIN_DIR/tgt-decoded-2.txt & 63 | 64 | CUDA_VISIBLE_DEVICES=3 t2t-decoder \ 65 | --t2t_usr_dir=$USR_DIR \ 66 | --data_dir=$DATA_DIR \ 67 | --problem=$PROBLEM \ 68 | --model=$MODEL \ 69 | --hparams_set=$HPARAMS \ 70 | --hparams=$HPARAMS_OVERRIDE \ 71 | --output_dir=$TRAIN_DIR \ 72 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 73 | --decode_from_file=$DECODE_FILE_3 \ 74 | --decode_to_file=$TRAIN_DIR/tgt-decoded-3.txt & 75 | -------------------------------------------------------------------------------- /transformer/06_decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE=$HOME/project/text-generation # change this as necessary 4 | PREPROC=$BASE/data/preprocessed/low_resource 5 | TEST_FILE=$PREPROC/src-test.txt 6 | 7 | split -n l/4 $TEST_FILE $PREPROC/temp_ 8 | for f in 
$PREPROC/temp_*; do mv $f $PREPROC/`basename $f `.txt; done; 9 | 10 | PROBLEM=mimic_discharge_summaries 11 | MODEL=transformer 12 | HPARAMS=transformer_base 13 | DATA_DIR=$BASE/data/t2t_experiments/transformer/low_resource/full_context/data 14 | TRAIN_DIR=$BASE/data/t2t_experiments/transformer/low_resource/full_context/output 15 | USR_DIR=$BASE/transformer 16 | 17 | DECODE_FILE_0=$PREPROC/temp_aa.txt 18 | DECODE_FILE_1=$PREPROC/temp_ab.txt 19 | DECODE_FILE_2=$PREPROC/temp_ac.txt 20 | DECODE_FILE_3=$PREPROC/temp_ad.txt 21 | 22 | BEAM_SIZE=4 23 | ALPHA=0.6 24 | DBS=4 25 | EXTRA_LEN=50 26 | HPARAMS_OVERRIDE="max_length=10000,max_target_seq_length=512,max_input_seq_length=512" 27 | 28 | CUDA_VISIBLE_DEVICES=0 t2t-decoder \ 29 | --t2t_usr_dir=$USR_DIR \ 30 | --data_dir=$DATA_DIR \ 31 | --problem=$PROBLEM \ 32 | --model=$MODEL \ 33 | --hparams_set=$HPARAMS \ 34 | --hparams=$HPARAMS_OVERRIDE \ 35 | --output_dir=$TRAIN_DIR \ 36 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 37 | --decode_from_file=$DECODE_FILE_0 \ 38 | --decode_to_file=$TRAIN_DIR/tgt-decoded-0.txt & 39 | 40 | CUDA_VISIBLE_DEVICES=1 t2t-decoder \ 41 | --t2t_usr_dir=$USR_DIR \ 42 | --data_dir=$DATA_DIR \ 43 | --problem=$PROBLEM \ 44 | --model=$MODEL \ 45 | --hparams_set=$HPARAMS \ 46 | --hparams=$HPARAMS_OVERRIDE \ 47 | --output_dir=$TRAIN_DIR \ 48 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 49 | --decode_from_file=$DECODE_FILE_1 \ 50 | --decode_to_file=$TRAIN_DIR/tgt-decoded-1.txt & 51 | 52 | CUDA_VISIBLE_DEVICES=2 t2t-decoder \ 53 | --t2t_usr_dir=$USR_DIR \ 54 | --data_dir=$DATA_DIR \ 55 | --problem=$PROBLEM \ 56 | --model=$MODEL \ 57 | --hparams_set=$HPARAMS \ 58 | --hparams=$HPARAMS_OVERRIDE \ 59 | --output_dir=$TRAIN_DIR \ 60 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 61 | --decode_from_file=$DECODE_FILE_2 \ 62 | --decode_to_file=$TRAIN_DIR/tgt-decoded-2.txt & 63 | 64 | CUDA_VISIBLE_DEVICES=3 t2t-decoder \ 65 | --t2t_usr_dir=$USR_DIR \ 66 | --data_dir=$DATA_DIR \ 67 | --problem=$PROBLEM \ 68 | --model=$MODEL \ 69 | --hparams_set=$HPARAMS \ 70 | --hparams=$HPARAMS_OVERRIDE \ 71 | --output_dir=$TRAIN_DIR \ 72 | --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,batch_size=$DBS,extra_length=$EXTRA_LEN" \ 73 | --decode_from_file=$DECODE_FILE_3 \ 74 | --decode_to_file=$TRAIN_DIR/tgt-decoded-3.txt & 75 | -------------------------------------------------------------------------------- /low_resource/01_format.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "##### We need to get our previously processed data into the right format for the low resource scenario" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import os\n", 19 | "from pathlib import Path" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "DATA = Path('../data/preprocessed')\n", 29 | "LOW_RESOURCE=DATA/'low_resource'\n", 30 | "LOW_RESOURCE.mkdir(parents=True, exist_ok=True)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "We don't need to change the test and validation files so we can just copy them across" 
38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "!cp ../data/preprocessed/ref_test.tsv ../data/preprocessed/low_resource/\n", 47 | "!cp ../data/preprocessed/ref_val.tsv ../data/preprocessed/low_resource/\n", 48 | "!cp ../data/preprocessed/src-test.txt ../data/preprocessed/low_resource/\n", 49 | "!cp ../data/preprocessed/src-val.txt ../data/preprocessed/low_resource/\n", 50 | "!cp ../data/preprocessed/tgt-test.txt ../data/preprocessed/low_resource/\n", 51 | "!cp ../data/preprocessed/tgt-val.txt ../data/preprocessed/low_resource/" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "However, we will have to reduce the size of the train dataset to ~2m words in accordance with Wikitext2. Since our models cannot process more than 512 words per sample, we can simply divide 2m by 512 to get an approximate number of samples." 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Wikitext2 is 2,088,628 words, to be exact." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "4079.3515625" 77 | ] 78 | }, 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "2088628/512" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "For simplicity, let's round this down to the first 4000 samples:" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 7, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "!head -4000 ../data/preprocessed/ref_train.tsv > ../data/preprocessed/low_resource/ref_train.tsv \n", 102 | "!head -4000 ../data/preprocessed/src-train.txt > ../data/preprocessed/low_resource/src-train.txt \n", 103 | "!head -4000 ../data/preprocessed/tgt-train.txt > ../data/preprocessed/low_resource/tgt-train.txt " 104 | ] 105 | } 106 | ], 107 | "metadata": { 108 | "kernelspec": { 109 | "display_name": "Python 3", 110 | "language": "python", 111 | "name": "python3" 112 | }, 113 | "language_info": { 114 | "codemirror_mode": { 115 | "name": "ipython", 116 | "version": 3 117 | }, 118 | "file_extension": ".py", 119 | "mimetype": "text/x-python", 120 | "name": "python", 121 | "nbconvert_exporter": "python", 122 | "pygments_lexer": "ipython3", 123 | "version": "3.7.3" 124 | } 125 | }, 126 | "nbformat": 4, 127 | "nbformat_minor": 2 128 | } 129 | -------------------------------------------------------------------------------- /gpt2/generate_conditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import model, sample, encoder 10 | 11 | def sample_model( 12 | model_name='117M', 13 | seed=None, 14 | nsamples=1, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=0, 19 | top_p=0.0, 20 | input_file=None, 21 | output_file="conditional_sample_output.txt" 22 | ): 23 | """ 24 | Run the sample_model 25 | :model_name=117M : String, which model to use 26 | :seed=None : Integer seed for random number generators, fix seed to 27 | reproduce results 28 | :nsamples=1 : Number of samples to return, if 0, continues to 29 | generate samples indefinitely. 
30 | :batch_size=1 : Number of batches (only affects speed/memory). 31 | :length=None : Number of tokens in generated text, if None (default), is 32 | determined by model hyperparameters 33 | :temperature=1 : Float value controlling randomness in Boltzmann 34 | distribution. Lower temperature results in less random completions. As the 35 | temperature approaches zero, the model will become deterministic and 36 | repetitive. Higher temperature results in more random completions. 37 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 38 | considered for each step (token), resulting in deterministic completions, 39 | while 40 means 40 words are considered at each step. 0 (default) is a 40 | special setting meaning no restrictions. 40 generally is a good value. 41 | :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling, 42 | overriding top_k if set to a value > 0. A good setting is 0.9. 43 | :input_file=None : Input file as a path to a text file with each line 44 | containing a context string for the model to generate upon. Must be provided. 45 | :output_file="conditional_sample_output.txt" : Output file as an optional 46 | file path for the output samples. One sample saved per line. 47 | """ 48 | 49 | # Check for a missing input_file first: os.path.isfile(None) would raise a TypeError 50 | if input_file is None: 51 | raise ValueError("Please provide an input file") 52 | if not os.path.isfile(input_file): 53 | raise ValueError("Input file not found") 54 | 55 | enc = encoder.get_encoder(model_name) 56 | hparams = model.default_hparams() 57 | with open(os.path.join('models', model_name, 'hparams.json')) as f: 58 | hparams.override_from_dict(json.load(f)) 59 | 60 | if length is None: 61 | length = hparams.n_ctx // 2 62 | elif length > hparams.n_ctx: 63 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 64 | 65 | with tf.Session(graph=tf.Graph()) as sess: 66 | context = tf.placeholder(tf.int32, [batch_size, None]) 67 | np.random.seed(seed) 68 | tf.set_random_seed(seed) 69 | 70 | output = sample.sample_sequence( 71 | hparams=hparams, length=length, 72 | context=context, 73 | batch_size=batch_size, 74 | temperature=temperature, top_k=top_k, top_p=top_p 75 | ) 76 | 77 | saver = tf.train.Saver() 78 | ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) 79 | saver.restore(sess, ckpt) 80 | 81 | output_file = open(output_file,"w+") 82 | 83 | with open(input_file) as f: 84 | for line in f: 85 | context_tokens = enc.encode(line) 86 | for _ in range(nsamples // batch_size): 87 | out = sess.run(output, feed_dict={ 88 | context: [context_tokens for _ in range(batch_size)] 89 | })[:, len(context_tokens):] 90 | for i in range(batch_size): 91 | text = enc.decode(out[i]) 92 | output_file.write(repr(text) + "\n") 93 | 94 | output_file.close() 95 | 96 | if __name__ == '__main__': 97 | fire.Fire(sample_model) 98 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tf 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _tflow_select=2.1.0=gpu 7 | - absl-py=0.7.1=py37_0 8 | - astor=0.8.0=py37_0 9 | - attrs=19.1.0=py37_1 10 | - backcall=0.1.0=py37_0 11 | - blas=1.0=mkl 12 | - bleach=3.1.0=py37_0 13 | - c-ares=1.15.0=h7b6447c_1001 14 | - ca-certificates=2019.5.15=1 15 | - certifi=2019.6.16=py37_1 16 | - cudatoolkit=10.1.168=0 17 | - cudnn=7.6.0=cuda10.1_0 18 | - cupti=10.1.168=0 19 | - dbus=1.13.6=h746ee38_0 20 | - decorator=4.4.0=py37_1 
21 | - defusedxml=0.6.0=py_0 22 | - entrypoints=0.3=py37_0 23 | - expat=2.2.6=he6710b0_0 24 | - fontconfig=2.13.0=h9420a91_0 25 | - freetype=2.9.1=h8a8886c_1 26 | - gast=0.2.2=py37_0 27 | - glib=2.56.2=hd408876_0 28 | - gmp=6.1.2=h6c8ec71_1 29 | - google-pasta=0.1.7=py_0 30 | - grpcio=1.16.1=py37hf8bcb03_1 31 | - gst-plugins-base=1.14.0=hbbd80ab_1 32 | - gstreamer=1.14.0=hb453b48_1 33 | - h5py=2.9.0=py37h7918eee_0 34 | - hdf5=1.10.4=hb1b8bf9_0 35 | - icu=58.2=h9c2bf20_1 36 | - intel-openmp=2019.4=243 37 | - ipykernel=5.1.2=py37h39e3cac_0 38 | - ipython=7.7.0=py37h39e3cac_0 39 | - ipython_genutils=0.2.0=py37_0 40 | - ipywidgets=7.5.1=py_0 41 | - jedi=0.15.1=py37_0 42 | - jinja2=2.10.1=py37_0 43 | - jpeg=9b=h024ee3a_2 44 | - jsonschema=3.0.2=py37_0 45 | - jupyter=1.0.0=py37_7 46 | - jupyter_client=5.3.1=py_0 47 | - jupyter_console=6.0.0=py37_0 48 | - jupyter_core=4.5.0=py_0 49 | - keras-applications=1.0.8=py_0 50 | - keras-preprocessing=1.1.0=py_1 51 | - libedit=3.1.20181209=hc058e9b_0 52 | - libffi=3.2.1=hd88cf55_4 53 | - libgcc-ng=9.1.0=hdf63c60_0 54 | - libgfortran-ng=7.3.0=hdf63c60_0 55 | - libpng=1.6.37=hbc83047_0 56 | - libprotobuf=3.8.0=hd408876_0 57 | - libsodium=1.0.16=h1bed415_0 58 | - libstdcxx-ng=9.1.0=hdf63c60_0 59 | - libuuid=1.0.3=h1bed415_2 60 | - libxcb=1.13=h1bed415_1 61 | - libxml2=2.9.9=hea5a465_1 62 | - markdown=3.1.1=py37_0 63 | - markupsafe=1.1.1=py37h7b6447c_0 64 | - mistune=0.8.4=py37h7b6447c_0 65 | - mkl=2019.4=243 66 | - mkl-service=2.0.2=py37h7b6447c_0 67 | - mkl_fft=1.0.14=py37ha843d7b_0 68 | - mkl_random=1.0.2=py37hd81dba3_0 69 | - nb_conda=2.2.1=py37_0 70 | - nb_conda_kernels=2.2.2=py37_0 71 | - nbconvert=5.5.0=py_0 72 | - nbformat=4.4.0=py37_0 73 | - ncurses=6.1=he6710b0_1 74 | - notebook=6.0.0=py37_0 75 | - numpy=1.16.4=py37h7e9f1db_0 76 | - numpy-base=1.16.4=py37hde5b4d6_0 77 | - openssl=1.1.1c=h7b6447c_1 78 | - pandoc=2.2.3.2=0 79 | - pandocfilters=1.4.2=py37_1 80 | - parso=0.5.1=py_0 81 | - pcre=8.43=he6710b0_0 82 | - pexpect=4.7.0=py37_0 83 | - pickleshare=0.7.5=py37_0 84 | - pip=19.2.2=py37_0 85 | - prometheus_client=0.7.1=py_0 86 | - prompt_toolkit=2.0.9=py37_0 87 | - protobuf=3.8.0=py37he6710b0_0 88 | - ptyprocess=0.6.0=py37_0 89 | - pygments=2.4.2=py_0 90 | - pyqt=5.9.2=py37h05f1152_2 91 | - pyrsistent=0.14.11=py37h7b6447c_0 92 | - python=3.7.3=h0371630_0 93 | - python-dateutil=2.8.0=py37_0 94 | - pyzmq=18.1.0=py37he6710b0_0 95 | - qt=5.9.7=h5867ecd_1 96 | - qtconsole=4.5.3=py_0 97 | - readline=7.0=h7b6447c_5 98 | - scipy=1.3.1=py37h7c811a0_0 99 | - send2trash=1.5.0=py37_0 100 | - setuptools=41.0.1=py37_0 101 | - sip=4.19.8=py37hf484d3e_0 102 | - six=1.12.0=py37_0 103 | - sqlite=3.29.0=h7b6447c_0 104 | - tensorboard=1.14.0=py37hf484d3e_0 105 | - tensorflow=1.14.0=gpu_py37h74c33d7_0 106 | - tensorflow-base=1.14.0=gpu_py37he45bfe2_0 107 | - tensorflow-estimator=1.14.0=py_0 108 | - tensorflow-gpu=1.14.0=h0d30ee6_0 109 | - termcolor=1.1.0=py37_1 110 | - terminado=0.8.2=py37_0 111 | - testpath=0.4.2=py37_0 112 | - tk=8.6.8=hbc83047_0 113 | - tornado=6.0.3=py37h7b6447c_0 114 | - traitlets=4.3.2=py37_0 115 | - wcwidth=0.1.7=py37_0 116 | - webencodings=0.5.1=py37_1 117 | - werkzeug=0.15.5=py_0 118 | - wheel=0.33.4=py37_0 119 | - widgetsnbextension=3.5.1=py37_0 120 | - wrapt=1.11.2=py37h7b6447c_0 121 | - xz=5.2.4=h14c3975_4 122 | - zeromq=4.3.1=he6710b0_3 123 | - zlib=1.2.11=h7b6447c_3 124 | - pip: 125 | - blessings==1.7 126 | - bz2file==0.98 127 | - cachetools==3.1.1 128 | - chardet==3.0.4 129 | - click==7.0 130 | - cloudpickle==1.2.1 131 | - dill==0.3.0 132 
| - dopamine-rl==2.0.5 133 | - fire==0.2.1 134 | - flask==1.1.1 135 | - future==0.17.1 136 | - gevent==1.4.0 137 | - gin-config==0.2.0 138 | - google-api-python-client==1.7.10 139 | - google-auth==1.6.3 140 | - google-auth-httplib2==0.0.3 141 | - googleapis-common-protos==1.6.0 142 | - gpustat==0.6.0 143 | - greenlet==0.4.15 144 | - gunicorn==19.9.0 145 | - gym==0.14.0 146 | - httplib2==0.13.1 147 | - idna==2.8 148 | - itsdangerous==1.1.0 149 | - kfac==0.1.4 150 | - mesh-tensorflow==0.0.5 151 | - mpmath==1.1.0 152 | - nvidia-ml-py3==7.352.0 153 | - oauth2client==4.1.3 154 | - opencv-python==4.1.0.25 155 | - pandas==0.25.0 156 | - pillow==6.1.0 157 | - promise==2.2.1 158 | - psutil==5.6.3 159 | - pyasn1==0.4.5 160 | - pyasn1-modules==0.2.5 161 | - pyglet==1.3.2 162 | - pypng==0.0.20 163 | - pyrouge==0.1.3 164 | - pytz==2019.2 165 | - regex==2017.4.5 166 | - requests==2.21.0 167 | - rsa==4.0 168 | - sympy==1.4 169 | - tensor2tensor==1.13.4 170 | - tensorflow-datasets==1.2.0 171 | - tensorflow-metadata==0.14.0 172 | - tensorflow-probability==0.7.0 173 | - toposort==1.5 174 | - tqdm==4.31.1 175 | - uritemplate==3.0.0 176 | - urllib3==1.24.3 177 | prefix: /home/aa5118/anaconda3/envs/tf 178 | 179 | -------------------------------------------------------------------------------- /gpt2/01_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Preprocess" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import os\n", 19 | "from pathlib import Path" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "low_resource=True # change as appropriate\n", 29 | "\n", 30 | "if low_resource == True:\n", 31 | " DATA = Path('../data/preprocessed/low_resource/')\n", 32 | " GPT2 = Path('../data/gpt2/low_resource')\n", 33 | "else:\n", 34 | " DATA = Path('../data/preprocessed/')\n", 35 | " GPT2 = Path('../data/gpt2')\n", 36 | " \n", 37 | "GPT2.mkdir(parents=True, exist_ok=True)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Training set" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "with open(DATA/'src-train.txt', 'r') as f:\n", 54 | " train_src = f.readlines()\n", 55 | "train_src=pd.DataFrame({'text':train_src})" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "with open(DATA/'tgt-train.txt', 'r') as f:\n", 65 | " train_tgt = f.readlines()\n", 66 | "train_tgt=pd.DataFrame({'text':train_tgt})" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 5, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "for i, row in train_src.iterrows():\n", 76 | " src = row['text'][:-1]\n", 77 | " src = src.split()[:512]\n", 78 | " src_len = len(src)\n", 79 | " tgt_len = 1024 - src_len #gpt2 can only process inputs of max 1024 tokens in length\n", 80 | " tgt = train_tgt['text'][i]\n", 81 | " tgt = tgt.split()[:tgt_len]\n", 82 | " combined = \"<|startoftext|>\" + \" \".join(src) + \" = \" + \" \".join(tgt) + \"<|endoftext|>\"\n", 83 | " row['text'] = combined" 84 | ] 85 | }, 86 | { 87 | "cell_type": 
"code", 88 | "execution_count": 15, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "train_src['text'][10]" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 7, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "np.savetxt(GPT2/'input-text.txt', train_src, fmt='%s', newline=os.linesep)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "### Validation set" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 8, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "with open(DATA/'src-val.txt', 'r') as f:\n", 118 | " val_src = f.readlines()\n", 119 | "val_src=pd.DataFrame({'text':val_src})" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 9, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "with open(DATA/'tgt-val.txt', 'r') as f:\n", 129 | " val_tgt = f.readlines()\n", 130 | "val_tgt=pd.DataFrame({'text':val_tgt})" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 10, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "for i, row in val_src.iterrows():\n", 140 | " src = row['text'][:-1]\n", 141 | " src = src.split()[:512]\n", 142 | " src_len = len(src)\n", 143 | " tgt_len = 1024 - src_len #gpt2 can only process inputs of max 1024 tokens in length\n", 144 | " tgt = val_tgt['text'][i]\n", 145 | " tgt = tgt.split()[:tgt_len]\n", 146 | " combined = \"<|startoftext|>\" + \" \".join(src) + \" = \" + \" \".join(tgt) + \"<|endoftext|>\"\n", 147 | " row['text'] = combined" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 11, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "np.savetxt(GPT2/'val-input-text.txt', val_src, fmt='%s', newline=os.linesep)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### Test set" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 12, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "with open(DATA/'src-test.txt', 'r') as f:\n", 173 | " test_src = f.readlines()\n", 174 | "test_src=pd.DataFrame({'text':test_src})" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 13, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "for i, row in test_src.iterrows():\n", 184 | " src = row['text'][:-1]\n", 185 | " src = src.split()[:512]\n", 186 | " src_len = len(src)\n", 187 | " combined = \"<|startoftext|>\" + \" \".join(src) + \" = \"\n", 188 | " row['text'] = combined" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 14, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "np.savetxt(GPT2/'test-input-text.txt', test_src, fmt='%s', newline=os.linesep)" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python [conda env:tf] *", 204 | "language": "python", 205 | "name": "conda-env-tf-py" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.7.3" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 2 222 | } 223 | -------------------------------------------------------------------------------- /analysis/scispacy.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import spacy\n", 10 | "import scispacy\n", 11 | "from scispacy.abbreviation import AbbreviationDetector\n", 12 | "from scispacy.umls_linking import UmlsEntityLinker" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "nlp1 = spacy.load(\"en\")\n", 22 | "nlp2 = spacy.load(\"en_core_sci_md\")" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "abbreviation_pipe = AbbreviationDetector(nlp2)\n", 32 | "nlp2.add_pipe(abbreviation_pipe)\n", 33 | "linker = UmlsEntityLinker()\n", 34 | "nlp2.add_pipe(linker)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 11, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "text = \"Alterations in the France Paris Obama hypocretin receptor 2 and preprohypocretin genes produce narcolepsy in some animals\"" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 12, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "doc1 = nlp1.tokenizer(text)\n", 53 | "doc2 = nlp2.tokenizer(text)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 13, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Alterations\n", 66 | "in\n", 67 | "the\n", 68 | "France\n", 69 | "Paris\n", 70 | "Obama\n", 71 | "hypocretin\n", 72 | "receptor\n", 73 | "2\n", 74 | "and\n", 75 | "preprohypocretin\n", 76 | "genes\n", 77 | "produce\n", 78 | "narcolepsy\n", 79 | "in\n", 80 | "some\n", 81 | "animals\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "for token in doc1:\n", 87 | " print (token)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 14, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Alterations\n", 100 | "in\n", 101 | "the\n", 102 | "France\n", 103 | "Paris\n", 104 | "Obama\n", 105 | "hypocretin\n", 106 | "receptor\n", 107 | "2\n", 108 | "and\n", 109 | "preprohypocretin\n", 110 | "genes\n", 111 | "produce\n", 112 | "narcolepsy\n", 113 | "in\n", 114 | "some\n", 115 | "animals\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "for token in doc2:\n", 121 | " print (token)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 15, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "for ent in doc1.ents:\n", 131 | " print(ent.text, ent.label_)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 16, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "for ent in doc2.ents:\n", 141 | " print(ent.text, ent.label_)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 10, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "ename": "IndexError", 151 | "evalue": "tuple index out of range", 152 | "output_type": "error", 153 | "traceback": [ 154 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 155 | "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", 156 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mentity\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mdoc2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ments\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 157 | "\u001b[0;31mIndexError\u001b[0m: tuple index out of range" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "entity = doc2.ents[0]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "print (entity)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "print(entity._.umls_ents)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "for umls_ent in entity._.umls_ents:\n", 190 | " print(linker.umls.cui_to_entity[umls_ent[0]].definition)\n", 191 | " print ('\\n')" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "for ent in doc2.ents:\n", 201 | " umls_ent = ent._.umls_ents\n", 202 | " if umls_ent:\n", 203 | " print (ent, \":\", linker.umls.cui_to_entity[umls_ent[0][0]].definition, \"\\n\")" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "for token in nlp1.vocab.strings:\n", 213 | "    pass" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "Python 3", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.7.3" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 2 245 | } 246 | -------------------------------------------------------------------------------- /transformer_moe/02_count_subwords.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Tensor2tensor uses subwords as tokens by default, which results in better performance. It also uses steps for determining the length of training as opposed to epochs. Converting between number of steps and number of epochs is based on the effective batch size (i.e. `effective_batch_size = batch_size * num_of_gpus`) and the number of subwords in a batch, such that `epochs = steps * effective_batch_size / training_subwords`." 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "However, rather unhelpfully, t2t does not provide us with the number of subwords in our dataset or in an individual sample, so in order to convert between steps and epochs, this must be done manually using the vocabulary."
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from tensor2tensor.data_generators import text_encoder\n", 24 | "import subprocess as sp" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "W0802 19:09:26.700704 140678470321984 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:940: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n", 37 | "\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "vocab_filepath = '../data/t2t_experiments/full_context/data/vocab.mimic_discharge_summaries.32768.subwords'\n", 43 | "vocab = text_encoder.SubwordTextEncoder(vocab_filepath)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "text = sp.getoutput('head -1 ../data/preprocessed/src-train.txt')" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 6, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "509\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "print(len(vocab.encode(text)))  # each id corresponds to one subword" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "509 subwords in that example - the first input context in our training set. Let's use this method to work out the total number of subwords in our training dataset." 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 15, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "25,740,563\n", 89 | "18,527,996\n", 90 | "23,455,357\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "total_subword_count = 0\n", 96 | "max_512_subword_count = 0\n", 97 | "max_1024_subword_count = 0\n", 98 | "\n", 99 | "f = open(\"../data/preprocessed/src-train.txt\",\"r\")\n", 100 | "\n", 101 | "line = f.readline()\n", 102 | "\n", 103 | "while line:\n", 104 | "    \n", 105 | "    count = len(vocab.encode(line))  # subwords in this line\n", 106 | "    total_subword_count += count\n", 107 | "    max_512_subword_count += min(512,count)\n", 108 | "    max_1024_subword_count += min(1024,count)\n", 109 | "    line = f.readline()\n", 110 | "\n", 111 | "f.close()\n", 112 | "\n", 113 | "print ('{:,}'.format(total_subword_count))\n", 114 | "print ('{:,}'.format(max_512_subword_count))\n", 115 | "print ('{:,}'.format(max_1024_subword_count))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "~25.7m subwords in the source training file, consisting of ~40k discharge summaries. 
We can now use this number to calculate how many training steps correspond to one epoch" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 22, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "num_of_gpus = 4\n", 132 | "batch_size = 4096\n", 133 | "effective_batch_size = num_of_gpus * batch_size\n", 134 | "epochs = 1\n", 135 | "\n", 136 | "def epoch2steps(subword_count):\n", 137 | "    steps = epochs / (effective_batch_size/subword_count)\n", 138 | "    print (\"1 epoch corresponds to\", '{:,}'.format(int(steps)), \"steps\")\n", 139 | "    return steps  # return the step count so later cells can reuse it" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 24, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "1 epoch corresponds to 1,571 steps\n", 152 | "1 epoch corresponds to 1,130 steps\n", 153 | "1 epoch corresponds to 1,431 steps\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "steps = epoch2steps(total_subword_count)  # keep the full-count figure for the cell below\n", 159 | "epoch2steps(max_512_subword_count)\n", 160 | "epoch2steps(max_1024_subword_count)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Liu et al (2018) use 400,000 steps when training their transformer model. With this setup, this would correspond to:" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 27, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "255 epochs\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "print('{0:,}'.format(round(400000 / steps)), \"epochs\") " 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "However, this is not a valid comparison because we are only focusing on discharge summaries, whereas Liu et al were looking at the entire dataset of notes, of which discharge summaries are only a small percentage." 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "Another important consideration is that Liu et al (2018) limit/truncate both input and output tokens (subwords) to 512. As shown above, this ends up removing ~7m subwords from our input. Doubling this to 1024 is worth considering, as it only removes ~2m subwords." 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "Python 3", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.7.3" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 2 223 | } 224 | -------------------------------------------------------------------------------- /transformer/02_count_subwords.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Tensor2tensor uses subwords as tokens by default, which results in better performance. It also uses steps for determining the length of training, as opposed to epochs. Converting between the number of steps and the number of epochs is based on the effective batch size (i.e. `effective_batch_size = batch_size * num_of_gpus`) and the number of subwords in a batch, such that `epochs = steps * effective_batch_size / training_subwords`"
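, "\n", "\n", "For example, with the 4 GPUs and `batch_size = 4096` configured below (an effective batch of 16,384 subwords per step), the ~2.4m-subword low-resource training set works out to roughly 143 steps per epoch - the exact figure is computed below."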
 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "However, rather unhelpfully, t2t does not provide us with the number of subwords in our dataset or in an individual sample, so in order to convert between steps and epochs, this must be done manually using the vocabulary." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from tensor2tensor.data_generators import text_encoder\n", 24 | "import subprocess as sp" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "WARNING: Logging before flag parsing goes to stderr.\n", 37 | "W0904 18:32:46.797372 139836362250048 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:938: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.\n", 38 | "\n", 39 | "W0904 18:32:46.798849 139836362250048 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:940: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n", 40 | "\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "vocab_filepath = '../data/t2t_experiments/transformer/low_resource/full_context/data/vocab.mimic_discharge_summaries.32768.subwords'\n", 46 | "vocab = text_encoder.SubwordTextEncoder(vocab_filepath)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "text = sp.getoutput('head -1 ../data/preprocessed/src-train.txt')" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "510\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "print(len(vocab.encode(text)))  # each id corresponds to one subword" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "510 subwords in that example - the first input context in our training set. 
Let's use this method to work out the total number of subwords in our training dataset." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "2,357,469\n", 92 | "1,726,779\n", 93 | "2,196,080\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "total_subword_count = 0\n", 99 | "max_512_subword_count = 0\n", 100 | "max_1024_subword_count = 0\n", 101 | "\n", 102 | "f = open(\"../data/preprocessed/low_resource/src-train.txt\",\"r\")\n", 103 | "\n", 104 | "line = f.readline()\n", 105 | "\n", 106 | "while line:\n", 107 | "    \n", 108 | "    count = len(vocab.encode(line))  # subwords in this line\n", 109 | "    total_subword_count += count\n", 110 | "    max_512_subword_count += min(512,count)\n", 111 | "    max_1024_subword_count += min(1024,count)\n", 112 | "    line = f.readline()\n", 113 | "\n", 114 | "f.close()\n", 115 | "\n", 116 | "print ('{:,}'.format(total_subword_count))\n", 117 | "print ('{:,}'.format(max_512_subword_count))\n", 118 | "print ('{:,}'.format(max_1024_subword_count))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "~2.4m subwords in the low-resource source training file. We can now use this number to calculate how many training steps correspond to one epoch" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 7, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "num_of_gpus = 4\n", 135 | "batch_size = 4096\n", 136 | "effective_batch_size = num_of_gpus * batch_size\n", 137 | "epochs = 1\n", 138 | "\n", 139 | "def epoch2steps(subword_count):\n", 140 | "    steps = epochs / (effective_batch_size/subword_count)\n", 141 | "    print (\"1 epoch corresponds to\", '{:,}'.format(int(steps)), \"steps\")\n", 142 | "    return steps  # return the step count so later cells can reuse it" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 8, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "1 epoch corresponds to 143 steps\n", 155 | "1 epoch corresponds to 105 steps\n", 156 | "1 epoch corresponds to 134 steps\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "steps = epoch2steps(total_subword_count)  # keep the full-count figure for the cell below\n", 162 | "epoch2steps(max_512_subword_count)\n", 163 | "epoch2steps(max_1024_subword_count)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "Liu et al (2018) use 400,000 steps when training their transformer model. With this setup, this would correspond to:" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 27, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "2,780 epochs\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "print('{0:,}'.format(round(400000 / steps)), \"epochs\") " 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "However, this is not a valid comparison because we are only focusing on discharge summaries, whereas Liu et al were looking at the entire dataset of notes, of which discharge summaries are only a small percentage." 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "Another important consideration is that Liu et al (2018) limit/truncate both input and output tokens (subwords) to 512. As shown above, this ends up removing ~0.6m subwords from our input. Doubling this to 1024 is worth considering, as it only removes ~0.2m subwords." 202 | ] 203 | }
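, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick arithmetic check of those truncation losses, derived from the totals computed above (illustrative cell, not executed):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# subwords discarded by each input-length cap, from the counts above\n", "print('{:,} subwords lost with a 512 cap'.format(total_subword_count - max_512_subword_count))\n", "print('{:,} subwords lost with a 1024 cap'.format(total_subword_count - max_1024_subword_count))" ] }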
 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python [conda env:tf]", 208 | "language": "python", 209 | "name": "conda-env-tf-py" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.7.3" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 2 226 | } 227 | -------------------------------------------------------------------------------- /preprocess/04_context_data_split.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Context data split" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this notebook, we generate different permutations of the context data from the source files created in the previous notebook. Each output file's suffix encodes which context fields it contains: h = hint, gae = the three demographic fields, d = diagnoses, p = procedures, m = medications, t = microbiology results, l = lab results." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import re\n", 24 | "import string\n", 25 | "import os" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def context_data_split(filename, dataset_split):\n", 35 | "\n", 36 | "    directory = \"../data/preprocessed/other_contexts/\"\n", 37 | "    \n", 38 | "    f_h = open(directory + \"src-\" + dataset_split + \"-h.txt\",\"w+\")\n", 39 | "    f_h_gae = open(directory + \"src-\" + dataset_split + \"-h-gae.txt\",\"w+\")\n", 40 | "    f_h_gae_d = open(directory + \"src-\" + dataset_split + \"-h-gae-d.txt\",\"w+\")\n", 41 | "    f_h_gae_p = open(directory + \"src-\" + dataset_split + \"-h-gae-p.txt\",\"w+\")\n", 42 | "    f_h_gae_d_p = open(directory + \"src-\" + dataset_split + \"-h-gae-d-p.txt\",\"w+\")\n", 43 | "    f_h_gae_d_p_m = open(directory + \"src-\" + dataset_split + \"-h-gae-d-p-m.txt\",\"w+\")\n", 44 | "    f_h_gae_d_p_m_t = open(directory + \"src-\" + dataset_split + \"-h-gae-d-p-m-t.txt\",\"w+\")\n", 45 | "    f_h_gae_d_p_m_l = open(directory + \"src-\" + dataset_split + \"-h-gae-d-p-m-l.txt\",\"w+\")\n", 46 | "    \n", 47 | "    f = open(filename,'r')\n", 48 | "    line = f.readline()\n", 49 | "\n", 50 | "    while line:\n", 51 | "        context = re.split(r' | | | | | 

| | | ', line)\n", 52 | "\n", 53 | " hint = context[0] + \" \"\n", 54 | " demographic = context[1] + \" \" + context[2] + \" \" + context[3] + \" \"\n", 55 | " diagnosis_list = context[4] + \" \"\n", 56 | " procedure_list = context[5] + \"

\"\n", 57 | "        med_list = context[6] + \" \"\n", 58 | "        microbio_list = context[7] + \" \"\n", 59 | "        lab_list = context[8] + \" \"\n", 60 | "\n", 61 | "        H = hint + \"\\n\"\n", 62 | "        H_GAE = hint + demographic + \"\\n\"\n", 63 | "        H_GAE_D = hint + demographic + diagnosis_list + \"\\n\"\n", 64 | "        H_GAE_P = hint + demographic + procedure_list + \"\\n\"  # procedures without the diagnosis list\n", 65 | "        H_GAE_D_P = hint + demographic + diagnosis_list + procedure_list + \"\\n\"\n", 66 | "        H_GAE_D_P_M = hint + demographic + diagnosis_list + procedure_list + med_list + \"\\n\"\n", 67 | "        H_GAE_D_P_M_T = hint + demographic + diagnosis_list + procedure_list + med_list + microbio_list + \"\\n\"\n", 68 | "        H_GAE_D_P_M_L = hint + demographic + diagnosis_list + procedure_list + med_list + lab_list + \"\\n\"\n", 69 | "\n", 70 | "        f_h.write(H)\n", 71 | "        f_h_gae.write(H_GAE)\n", 72 | "        f_h_gae_d.write(H_GAE_D)\n", 73 | "        f_h_gae_p.write(H_GAE_P)\n", 74 | "        f_h_gae_d_p.write(H_GAE_D_P)\n", 75 | "        f_h_gae_d_p_m.write(H_GAE_D_P_M)\n", 76 | "        f_h_gae_d_p_m_t.write(H_GAE_D_P_M_T)\n", 77 | "        f_h_gae_d_p_m_l.write(H_GAE_D_P_M_L)\n", 78 | "\n", 79 | "        line = f.readline()\n", 80 | "    \n", 81 | "    f_h.close()\n", 82 | "    f_h_gae.close()\n", 83 | "    f_h_gae_d.close()\n", 84 | "    f_h_gae_p.close()\n", 85 | "    f_h_gae_d_p.close()\n", 86 | "    f_h_gae_d_p_m.close()\n", 87 | "    f_h_gae_d_p_m_t.close()\n", 88 | "    f_h_gae_d_p_m_l.close()\n", 89 | "\n", 90 | "    f.close()\n", 91 | "    \n", 92 | "    return" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "splitting training data...\n", 105 | "splitting evaluation data...\n", 106 | "splitting test data...\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "print (\"splitting training data...\")\n", 112 | "context_data_split(\"../data/preprocessed/src-train.txt\",\"train\")\n", 113 | "\n", 114 | "print (\"splitting evaluation data...\")\n", 115 | "context_data_split(\"../data/preprocessed/src-val.txt\",\"val\")\n", 116 | "\n", 117 | "print (\"splitting test data...\")\n", 118 | "context_data_split(\"../data/preprocessed/src-test.txt\",\"test\")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "Verify that all files have the same number of lines and that the word/character counts are in line with expectation:" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 4, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "/mimic/data/preprocessed/other_contexts\n", 138 | " 5727 2571310 12500237 src-test-h-gae-d-p-m-l.txt\n", 139 | " 5727 878941 5527487 src-test-h-gae-d-p-m-t.txt\n", 140 | " 5727 841031 5339036 src-test-h-gae-d-p-m.txt\n", 141 | " 5727 613073 4146137 src-test-h-gae-d-p.txt\n", 142 | " 5727 553588 3724383 src-test-h-gae-d.txt\n", 143 | " 5727 613073 4146137 src-test-h-gae-p.txt\n", 144 | " 5727 97359 442100 src-test-h-gae.txt\n", 145 | " 5727 62997 309842 src-test-h.txt\n", 146 | " 44230 19997807 97595528 src-train-h-gae-d-p-m-l.txt\n", 147 | " 44230 6958046 43839444 src-train-h-gae-d-p-m-t.txt\n", 148 | " 44230 6689483 42521849 src-train-h-gae-d-p-m.txt\n", 149 | " 44230 4827069 32745113 src-train-h-gae-d-p.txt\n", 150 | " 44230 3920392 26260356 src-train-h-gae-d.txt\n", 151 | " 44230 4827069 32745113 src-train-h-gae-p.txt\n", 152 | " 44230 751910 3427563 src-train-h-gae.txt\n", 153 | " 44230 486530 
2405680 src-train-h.txt\n", 154 | " 5447 2468944 12034600 src-val-h-gae-d-p-m-l.txt\n", 155 | " 5447 851598 5367720 src-val-h-gae-d-p-m-t.txt\n", 156 | " 5447 817634 5200961 src-val-h-gae-d-p-m.txt\n", 157 | " 5447 590960 4008708 src-val-h-gae-d-p.txt\n", 158 | " 5447 480552 3218051 src-val-h-gae-d.txt\n", 159 | " 5447 590960 4008708 src-val-h-gae-p.txt\n", 160 | " 5447 92599 422324 src-val-h-gae.txt\n", 161 | " 5447 59917 296473 src-val-h.txt\n", 162 | " 443232 60642842 352233550 total\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "%cd ../data/preprocessed/other_contexts/\n", 168 | "!wc -mlw *" 169 | ] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 3", 175 | "language": "python", 176 | "name": "python3" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.7.3" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 2 193 | } 194 | -------------------------------------------------------------------------------- /transformer-xl/01_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Preprocess" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import os\n", 19 | "from pathlib import Path" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "low_resource=True # change as appropriate\n", 29 | "\n", 30 | "if low_resource == True:\n", 31 | " DATA = Path('../data/preprocessed/low_resource/')\n", 32 | " MODEL = Path('../data/transformer-xl/low_resource')\n", 33 | "else:\n", 34 | " DATA = Path('../data/preprocessed/')\n", 35 | " MODEL = Path('../data/transformer-xl')\n", 36 | " \n", 37 | "MODEL.mkdir(parents=True, exist_ok=True)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Training set" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "with open(DATA/'src-train.txt', 'r') as f:\n", 54 | " train_src = f.readlines()\n", 55 | "train_src=pd.DataFrame({'text':train_src})" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "with open(DATA/'tgt-train.txt', 'r') as f:\n", 65 | " train_tgt = f.readlines()\n", 66 | "train_tgt=pd.DataFrame({'text':train_tgt})" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 5, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "for i, row in train_src.iterrows():\n", 76 | " src = row['text'][:-1]\n", 77 | " src = src.split()[:512]\n", 78 | " src_len = len(src)\n", 79 | " tgt_len = 1024 - src_len # cap sequence at 1024 tokens in length\n", 80 | " tgt = train_tgt['text'][i]\n", 81 | " tgt = tgt.split()[:tgt_len]\n", 82 | " combined = \"= discharge summary = \" + '\\n' + \" \".join(src) + '\\n' + \"= = note = =\" + '\\n' + \" \".join(tgt)\n", 83 | " row['text'] = combined" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 6, 89 | 
"metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "'= discharge summary = \\nadmission date : [ 2134/9/1 ] [ month / day M 47 black acute respiratory failure | acute diastolic heart failure | atrial flutter | obesity hypoventilation syndrome | obstructive sleep apnea (adult)(pediatric) | unspecified schizophrenia, unspecified | bipolar disorder, unspecified | obesity, unspecified | pure hypercholesterolemia | diabetes mellitus without mention of complication, type ii or unspecified type, not stated as uncontrolled | congestive heart failure, unspecified | atrial fibrillation | polycythemia vera

furosemide , 40mg Tablet | docusate sodium , 100mg Capsule | furosemide , 100mg/10mL Vial Creatine Kinase, MB Isoenzyme , 1 , ng/mL | Urea Nitrogen , 17 , mg/dL | Chloride , 92 , mEq/L , abnormal | Calcium, Total , 8.5 , mg/dL | Bicarbonate , 40 , mEq/L , abnormal | Anion Gap , 12 , mEq/L | Phosphate , 3.6 , mg/dL | Potassium , 3.9 , mEq/L | Sodium , 140 , mEq/L | Hematocrit , 46.6 , % | Creatinine , 0.7 , mg/dL | Hemoglobin , 14.6 , g/dL | MCH , 27.8 , pg | MCHC , 31.4 , % | MCV , 88 , fL | Platelet Count , 198 , K/uL | RDW , 14.8 , % | Red Blood Cells , 5.27 , m/uL | Creatine Kinase (CK) , 44 , IU/L , abnormal | White Blood Cells , 7.7 , K/uL | Glucose , 115 , mg/dL , abnormal | Magnesium , 1.8 , mg/dL | Urea Nitrogen , 15 , mg/dL | Calcium, Total , 8.4 , mg/dL | Chloride , 92 , mEq/L , abnormal | Creatinine , 0.9 , mg/dL | Glucose , 140 , mg/dL , abnormal | Magnesium , 2.1 , mg/dL | Phosphate , 3.7 , mg/dL | Potassium , 4.2 , mEq/L | Sodium , 139 , mEq/L | Hematocrit , 45.5 , % | Bicarbonate , 42 , mEq/L , abnormal | Hemoglobin , 14.3 , g/dL | MCH , 28.9 , pg | Red Blood Cells , 4.97 , m/uL | White Blood Cells , 7.5 , K/uL | MCHC , 31.5 , % | MCV , 92 , fL | Platelet Count , 225 , K/uL | RDW , 15.3 , % | Anion Gap , 9 , mEq/L \\n= = note = =\\nadmission date : [ 2134/9/1 ] [ month / day / year ] date : [ 2134/9/6 ] date of birth : [ 2087/8/23 ] sex : m service : medicine allergies : no known allergies / adverse drug reactions attending : [ first name3 ( lf ) 11040 ] chief complaint : shortness of breath and hypoxia major surgical or invasive procedure : none history of present illness : mr [ known lastname 732 ] is a 47 year old male with h / o schizophrenia / bipolar disorder , obesity , hypercholesterolemia , and severe complex sleep disordered who comes in to from reab facility with increase in sob and hypoxia pt has been on rehab , [ location ( un ) 669 ] west since [ month ( only ) 958 ] of this year for his osa since he lived in a group home and was unable to care for himself as per nursing staff at the facility , he has been having increase in sob and hypoxia for the last 2 weeks for which he was tx with lasix with some improvement he had a recent 3 lbs wt gain in the last few days he is also non - compliment with his diet and tends to drink \" a lot of soda \" he is on lasix 80 mg [ hospital1 ] , however there is no listing of him having chf he denies missing any of his meds he states that he had increse in sob today and his o2sat was 70s% on ra and ems was called he denies having any fever , chills , sweats , or cough , or le edema he sleep with \" pillows \" he was then brought to the ed for further eval in the ed , initial vs were : hr 80 , 140/90 , 25 , 88 - 93 % on bipap probnp : 2129 , wbc was 10.2 ( n : 74 band : 0 l : 10 m : 15 e : 1 bas : 0 ) , trop - t : < 0.01 ua negative he had cxray which as per ed report states continued evidence of left - sided pneumonia as well as diffuse interstitial edema he was given vanco , cefepime and azithromycin for the concern of pneumonia he also had a cta which did not show a pe or other acute lung process he was then given 80 mg of iv lasix which he responded well with 900cc of urine output in the ed and 300cc on arrival to the floor on arrival to the floor pt appears comfortable was sating at 100 % on a non - rebreather he denies having any complains of pain or sob he states to feel \" much better \" than before as per his [ hospital1 1501 ] , he did not have any recent pneumonia , infection or complains of chest pain he was planning 
to see a cardiologist for concern of chf past medical history : psychiatric history - reports history of bipolar disorder / schizophrenia diagnosed when he was 17 - two hosptialization at [ first name5 ( namepattern1 ) 745 ] [ last name ( namepattern1 ) ] and [ last name ( lf ) 42339 ] , [ first name3 ( lf ) ] mother , last [ name2 ( ni ) 103301 ] was 5 years ago sees [ name2 ( ni ) 2447 ]'" 95 | ] 96 | }, 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "train_src['text'][10]" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 7, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "np.savetxt(MODEL/'input-text.txt', train_src, fmt='%s', newline=os.linesep)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "### Validation set" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 8, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "with open(DATA/'src-val.txt', 'r') as f:\n", 129 | " val_src = f.readlines()\n", 130 | "val_src=pd.DataFrame({'text':val_src})" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "with open(DATA/'tgt-val.txt', 'r') as f:\n", 140 | " val_tgt = f.readlines()\n", 141 | "val_tgt=pd.DataFrame({'text':val_tgt})" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 10, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "for i, row in val_src.iterrows():\n", 151 | " src = row['text'][:-1]\n", 152 | " src = src.split()[:512]\n", 153 | " src_len = len(src)\n", 154 | " tgt_len = 1024 - src_len\n", 155 | " tgt = val_tgt['text'][i]\n", 156 | " tgt = tgt.split()[:tgt_len]\n", 157 | " combined = \"= discharge summary = \" + '\\n' + \" \".join(src) + '\\n' + \"= = note = =\" + '\\n' + \" \".join(tgt)\n", 158 | " row['text'] = combined" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 11, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "np.savetxt(MODEL/'val-input-text.txt', val_src, fmt='%s', newline=os.linesep)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "### Test set" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 12, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "with open(DATA/'src-test.txt', 'r') as f:\n", 184 | " test_src = f.readlines()\n", 185 | "test_src=pd.DataFrame({'text':test_src})" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 13, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "for i, row in test_src.iterrows():\n", 195 | " src = row['text'][:-1]\n", 196 | " src = src.split()[:512]\n", 197 | " src_len = len(src)\n", 198 | " combined = \"= discharge summary = \" + '\\n' + \" \".join(src) + '\\n' + \"= = note = =\"\n", 199 | " row['text'] = combined" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 14, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "np.savetxt(MODEL/'test-input-text.txt', test_src, fmt='%s', newline=os.linesep)" 209 | ] 210 | } 211 | ], 212 | "metadata": { 213 | "kernelspec": { 214 | "display_name": "Python [conda env:tf]", 215 | "language": "python", 216 | "name": "conda-env-tf-py" 217 | }, 218 | "language_info": { 219 | "codemirror_mode": { 220 | "name": 
"ipython", 221 | "version": 3 222 | }, 223 | "file_extension": ".py", 224 | "mimetype": "text/x-python", 225 | "name": "python", 226 | "nbconvert_exporter": "python", 227 | "pygments_lexer": "ipython3", 228 | "version": "3.7.3" 229 | } 230 | }, 231 | "nbformat": 4, 232 | "nbformat_minor": 2 233 | } 234 | -------------------------------------------------------------------------------- /transformer_moe/01_generate_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Setup" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Before we can start using the `tensor2tensor` models, we first have to get our data into a format that `tensor2tensor` can digest. This means defining a custom `Problem` as follows:" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "#@title Run this only once - Sets up TF Eager execution.\n", 24 | "\n", 25 | "import tensorflow as tf\n", 26 | "\n", 27 | "# Enable Eager execution - useful for seeing the generated data.\n", 28 | "tf.enable_eager_execution()" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stderr", 38 | "output_type": "stream", 39 | "text": [ 40 | "WARNING: Logging before flag parsing goes to stderr.\n", 41 | "W0802 18:59:01.531362 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/expert_utils.py:68: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n", 42 | "\n", 43 | "W0802 18:59:02.064932 139639908783936 lazy_loader.py:50] \n", 44 | "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", 45 | "For more information, please see:\n", 46 | " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", 47 | " * https://github.com/tensorflow/addons\n", 48 | " * https://github.com/tensorflow/io (for I/O related ops)\n", 49 | "If you depend on functionality not listed there, please file an issue.\n", 50 | "\n", 51 | "W0802 18:59:02.987431 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/metrics_hook.py:28: The name tf.train.SessionRunHook is deprecated. Please use tf.estimator.SessionRunHook instead.\n", 52 | "\n", 53 | "W0802 18:59:02.991913 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/adafactor.py:27: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", 54 | "\n", 55 | "W0802 18:59:02.992884 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/multistep_optimizer.py:32: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n", 56 | "\n", 57 | "W0802 18:59:03.003361 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/trainer_lib.py:109: The name tf.OptimizerOptions is deprecated. 
Please use tf.compat.v1.OptimizerOptions instead.\n", 58 | "\n", 59 | "W0802 18:59:03.006350 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/trainer_lib.py:780: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.\n", 60 | "\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "#@title Setting a random seed.\n", 66 | "\n", 67 | "from tensor2tensor.utils import trainer_lib\n", 68 | "\n", 69 | "# Set a seed so that we have deterministic outputs.\n", 70 | "RANDOM_SEED = 301\n", 71 | "trainer_lib.set_random_seed(RANDOM_SEED)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "#@title Run for setting up directories.\n", 81 | "\n", 82 | "import os\n", 83 | "\n", 84 | "# Setup and create directories.\n", 85 | "DATA_DIR = os.path.expanduser(\"../data/t2t_experiments/transformer_moe/full_context/data\")\n", 86 | "OUTPUT_DIR = os.path.expanduser(\"../data/t2t_experiments/transformer_moe/full_context/output\")\n", 87 | "TMP_DIR = os.path.expanduser(\"/mnt/\")\n", 88 | "\n", 89 | "# Create them.\n", 90 | "tf.gfile.MakeDirs(DATA_DIR)\n", 91 | "tf.gfile.MakeDirs(OUTPUT_DIR)\n", 92 | "tf.gfile.MakeDirs(TMP_DIR)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "from tensor2tensor.data_generators import problem\n", 102 | "from tensor2tensor.data_generators import text_problems\n", 103 | "from tensor2tensor.utils import registry" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "# Define the problem" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "@registry.register_problem\n", 120 | "\n", 121 | "class MimicDischargeSummaries(text_problems.Text2TextProblem):\n", 122 | " \n", 123 | " @property\n", 124 | " def is_generate_per_split(self):\n", 125 | " # our data already has pre-existing splits so we return true\n", 126 | " return True\n", 127 | "\n", 128 | " def generate_samples(self, data_dir, tmp_dir, dataset_split):\n", 129 | " \n", 130 | " del tmp_dir\n", 131 | " \n", 132 | " _train = (dataset_split == problem.DatasetSplit.TRAIN)\n", 133 | " _eval = (dataset_split == problem.DatasetSplit.EVAL)\n", 134 | " \n", 135 | " dataset = \"train\" if _train else \"val\" if _eval else \"test\"\n", 136 | " \n", 137 | " full_context = \"full_context\" in str(data_dir) # returns a boolean\n", 138 | " directory = \"../data/preprocessed/\"\n", 139 | " tgt = directory + \"tgt-\" + dataset + \".txt\"\n", 140 | "\n", 141 | " if full_context == True:\n", 142 | " src = directory + \"src-\" + dataset + \".txt\"\n", 143 | " else:\n", 144 | " directory += \"other_contexts/\" \n", 145 | " context = str(data_dir)[39:-5] # this index needs to be changed if file paths are changed\n", 146 | " src = directory + \"src-\" + dataset + \"-\" + context + \".txt\"\n", 147 | " \n", 148 | " f_src = open(src,'r')\n", 149 | " f_tgt = open(tgt,'r')\n", 150 | " \n", 151 | " context_data = f_src.readline()\n", 152 | " discharge_summary = f_tgt.readline()\n", 153 | "\n", 154 | " while context_data:\n", 155 | " yield {\n", 156 | " \"inputs\" : context_data,\n", 157 | " \"targets\" : discharge_summary,\n", 158 | " }\n", 159 | " \n", 160 | " context_data = f_src.readline()\n", 161 | " discharge_summary = 
f_tgt.readline()\n", 162 | " \n", 163 | " f_src.close()\n", 164 | " f_tgt.close()\n", 165 | "\n", 166 | " @property\n", 167 | " def vocab_type(self):\n", 168 | " # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good\n", 169 | " # tradeoff between CHARACTER and TOKEN.\n", 170 | " return text_problems.VocabType.SUBWORD\n", 171 | "\n", 172 | " @property\n", 173 | " def approx_vocab_size(self):\n", 174 | " # Approximate vocab size to generate. Only for VocabType.SUBWORD.\n", 175 | " return 2**15 # ~32k - this is the default setting\n", 176 | "\n", 177 | " @property\n", 178 | " def dataset_splits(self):\n", 179 | " return [{\n", 180 | " \"split\": problem.DatasetSplit.TRAIN,\n", 181 | " \"shards\": 80\n", 182 | " }, {\n", 183 | " \"split\": problem.DatasetSplit.EVAL,\n", 184 | " \"shards\": 10\n", 185 | " }, {\n", 186 | " \"split\": problem.DatasetSplit.TEST,\n", 187 | " \"shards\": 10\n", 188 | " }]" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "# Generate the data" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "First, we instantiate the problem and run it for the full context data." 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 6, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "mimic_problem = MimicDischargeSummaries()\n", 212 | "#mimic_problem.generate_data(DATA_DIR, TMP_DIR)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "Now, we run it in a loop instead for each individual context type." 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 7, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "name": "stderr", 229 | "output_type": "stream", 230 | "text": [ 231 | "W0802 19:05:31.212889 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:343: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.\n", 232 | "\n", 233 | "W0802 19:05:31.213634 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:349: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", 234 | "\n", 235 | "W0802 19:08:39.032947 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:355: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.\n", 236 | "\n", 237 | "W0802 19:08:39.034780 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:944: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n", 238 | "\n", 239 | "W0802 19:08:39.085365 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:164: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", 240 | "\n", 241 | "W0802 19:11:46.960594 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:183: The name tf.gfile.Rename is deprecated. 
Please use tf.io.gfile.rename instead.\n", 242 | "\n", 243 | "W0802 19:12:35.876786 139639908783936 deprecation.py:323] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:469: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.\n", 244 | "Instructions for updating:\n", 245 | "Use eager execution and: \n", 246 | "`tf.data.TFRecordDataset(path)`\n", 247 | "W0802 19:12:35.887185 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:513: The name tf.gfile.Remove is deprecated. Please use tf.io.gfile.remove instead.\n", 248 | "\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "context_list = ['h','h-gae','h-gae-d','h-gae-p','h-gae-d-p','h-gae-d-p-m','h-gae-d-p-m-t','h-gae-d-p-m-l']\n", 254 | "\n", 255 | "for context in context_list:\n", 256 | " # Setup and create directories.\n", 257 | " DATA_DIR = os.path.expanduser(\"../data/t2t_experiments/other_contexts/\"+context+\"/data\")\n", 258 | " OUTPUT_DIR = os.path.expanduser(\"../data/t2t_experiments/other_contexts/\"+context+\"/output\")\n", 259 | " TMP_DIR = os.path.expanduser(\"/mnt/\")\n", 260 | "\n", 261 | " # Create them.\n", 262 | " tf.gfile.MakeDirs(DATA_DIR)\n", 263 | " tf.gfile.MakeDirs(OUTPUT_DIR)\n", 264 | " tf.gfile.MakeDirs(TMP_DIR)\n", 265 | " \n", 266 | " mimic_problem.generate_data(DATA_DIR, TMP_DIR)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "# View the generated data" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 7, 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stderr", 283 | "output_type": "stream", 284 | "text": [ 285 | "W0729 17:59:46.342158 140696631490368 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_problems.py:394: The name tf.VarLenFeature is deprecated. Please use tf.io.VarLenFeature instead.\n", 286 | "\n", 287 | "W0729 17:59:46.343035 140696631490368 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/problem.py:705: The name tf.FixedLenFeature is deprecated. 
Please use tf.io.FixedLenFeature instead.\n", 288 | "\n" 289 | ] 290 | }, 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "Tensor Input: tf.Tensor(\n", 296 | "[ 80 72 54 1451 6 42 6 121 18 55 72 398\n", 297 | " 368 3 228 4 365 3 269 4 364 3 4213 4\n", 298 | " 366 3 2468 3414 1737 1534 2069 7 9304 60 95 7\n", 299 | " 659 825 329 60 518 77 661 7 6604 14 5907 15\n", 300 | " 101 2002 7 4855 832 4317 7 3060 60 95 60 132\n", 301 | " 438 15 3903 3060 7 95 732 234 7 101 661 179\n", 302 | " 3527 7 469 15 101 180 131 4 358 308 367 3\n", 303 | " 36 9 45 846 208 291 5 124 275 289 7 349\n", 304 | " 291 5 97 48 2648 281 7 369 2174 5 21 136\n", 305 | " 30 124 275 889 289 4 228 3 720 11 85 4\n", 306 | " 337 3 237 243 5 10137 5 116 6 83 5 29\n", 307 | " 7 214 5 1005 5 48 6 38 7 252 106 112\n", 308 | " 5 27 9 1342 5 114 6 83 5 29 7 266\n", 309 | " 5 134 9 41 123 29 7 486 5 306 9 41\n", 310 | " 5 223 7 221 5 107 9 42 5 223 5 29\n", 311 | " 7 263 5 887 5 264 7 260 5 174 9 45\n", 312 | " 109 261 5 181 9 39 5 255 5 29 7 459\n", 313 | " 404 221 156 16 9 21 5 85 5 29 7 238\n", 314 | " 5 45 9 41 5 136 6 40 5 29 7 202\n", 315 | " 5 89 9 36 123 29 7 226 227 5 41 5\n", 316 | " 20 6 40 7 258 106 112 5 39 9 32 5\n", 317 | " 116 6 83 7 195 5 33 9 36 5 48 6\n", 318 | " 38 7 280 5 21 9 16 5 20 6 40 7\n", 319 | " 197 5 800 5 20 6 40 7 220 5 16 9\n", 320 | " 36 5 20 6 40 7 225 5 1184 5 48 6\n", 321 | " 38 7 283 60 178 5 39 9 32 5 20 6\n", 322 | " 40 7 324 5 27 9 32 5 20 6 40 7\n", 323 | " 250 249 5 64 5 48 6 38 7 241 5 115\n", 324 | " 5 48 6 38 7 241 5 172 5 48 6 38\n", 325 | " 7 225 5 1039 5 48 6 38 7 283 60 178\n", 326 | " 5 39 9 27 5 20 6 40 5 29 7 258\n", 327 | " 106 112 5 45 9 32 5 116 6 83 7 250\n", 328 | " 249 5 52 5 48 6 38 7 197 5 2597 5\n", 329 | " 20 6 40 5 29 7 252 106 112 5 27 9\n", 330 | " 1299 5 114 6 83 5 29 7 220 5 16 9\n", 331 | " 36 5 20 6 40 7 214 5 867 5 48 6\n", 332 | " 38 7 280 5 21 9 36 5 20 6 40 7\n", 333 | " 324 5 27 9 32 5 20 6 40 7 195 5\n", 334 | " 27 9 42 5 48 6 38 7 226 227 5 41\n", 335 | " 5 20 6 40 7 202 5 174 9 16 123 29\n", 336 | " 7 238 5 52 9 36 5 136 6 40 5 29\n", 337 | " 7 261 5 181 9 45 5 255 5 29 7 260\n", 338 | " 5 253 9 16 109 263 5 887 5 264 7 266\n", 339 | " 5 134 9 21 123 29 7 237 243 5 9035 5\n", 340 | " 116 6 83 5 29 4 38 370 1], shape=(537,), dtype=int64)\n", 341 | "Tensor Target: tf.Tensor([ 80 72 54 ... 2 606 1], shape=(3837,), dtype=int64)\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "tfe = tf.contrib.eager\n", 347 | "\n", 348 | "Modes = tf.estimator.ModeKeys\n", 349 | "\n", 350 | "# We can iterate over our examples by making an iterator and calling next on it.\n", 351 | "eager_iterator = tfe.Iterator(mimic_problem.dataset(Modes.EVAL, DATA_DIR))\n", 352 | "example = eager_iterator.next()\n", 353 | "\n", 354 | "input_tensor = example[\"inputs\"]\n", 355 | "target_tensor = example[\"targets\"]\n", 356 | "\n", 357 | "# The tensors are actually encoded using the generated vocabulary file -- you\n", 358 | "# can inspect the actual vocab file in DATA_DIR.\n", 359 | "print(\"Tensor Input: \" + str(input_tensor))\n", 360 | "print(\"Tensor Target: \" + str(target_tensor))" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": {}, 366 | "source": [ 367 | "Below cell is not executed in order to protect patient privacy. 
Executing it will show the decoded context data and discharge summary" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "# We use the encoders to decode the tensors to the actual input text.\n", 377 | "input_encoder = mimic_problem.get_feature_encoders(\n", 378 | " data_dir=DATA_DIR)[\"inputs\"]\n", 379 | "target_encoder = mimic_problem.get_feature_encoders(\n", 380 | " data_dir=DATA_DIR)[\"targets\"]\n", 381 | "\n", 382 | "input_decoded = input_encoder.decode(input_tensor.numpy())\n", 383 | "target_decoded = target_encoder.decode(target_tensor.numpy())\n", 384 | "\n", 385 | "print(\"Decoded Input: \" + input_decoded)\n", 386 | "print(\"Decoded Target: \" + target_decoded)" 387 | ] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "Python 3", 393 | "language": "python", 394 | "name": "python3" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.7.3" 407 | } 408 | }, 409 | "nbformat": 4, 410 | "nbformat_minor": 2 411 | } 412 | -------------------------------------------------------------------------------- /transformer/01_generate_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Setup" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Before we can start using the `tensor2tensor` models, we first have to get our data into a format that `tensor2tensor` can digest. This means defining a custom `Problem` as follows:" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "#@title Run this only once - Sets up TF Eager execution.\n", 24 | "\n", 25 | "import tensorflow as tf\n", 26 | "\n", 27 | "# Enable Eager execution - useful for seeing the generated data.\n", 28 | "tf.enable_eager_execution()" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stderr", 38 | "output_type": "stream", 39 | "text": [ 40 | "WARNING: Logging before flag parsing goes to stderr.\n", 41 | "W0904 18:18:29.019985 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/expert_utils.py:68: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n", 42 | "\n", 43 | "W0904 18:18:31.300210 140425123104576 lazy_loader.py:50] \n", 44 | "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", 45 | "For more information, please see:\n", 46 | " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", 47 | " * https://github.com/tensorflow/addons\n", 48 | " * https://github.com/tensorflow/io (for I/O related ops)\n", 49 | "If you depend on functionality not listed there, please file an issue.\n", 50 | "\n", 51 | "W0904 18:18:32.220480 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/metrics_hook.py:28: The name tf.train.SessionRunHook is deprecated. 
Please use tf.estimator.SessionRunHook instead.\n", 52 | "\n", 53 | "W0904 18:18:32.226704 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/adafactor.py:27: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", 54 | "\n", 55 | "W0904 18:18:32.228066 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/multistep_optimizer.py:32: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n", 56 | "\n", 57 | "W0904 18:18:32.244130 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/trainer_lib.py:109: The name tf.OptimizerOptions is deprecated. Please use tf.compat.v1.OptimizerOptions instead.\n", 58 | "\n", 59 | "W0904 18:18:32.245685 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/trainer_lib.py:780: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.\n", 60 | "\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "#@title Setting a random seed.\n", 66 | "\n", 67 | "from tensor2tensor.utils import trainer_lib\n", 68 | "\n", 69 | "# Set a seed so that we have deterministic outputs.\n", 70 | "RANDOM_SEED = 301\n", 71 | "trainer_lib.set_random_seed(RANDOM_SEED)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 5, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "#@title Run for setting up directories.\n", 81 | "\n", 82 | "import os\n", 83 | "\n", 84 | "# Setup and create directories.\n", 85 | "DATA_DIR = os.path.expanduser(\"../data/t2t_experiments/transformer/low_resource/full_context/data\")\n", 86 | "OUTPUT_DIR = os.path.expanduser(\"../data/t2t_experiments/transformer/low_resource/full_context/output\")\n", 87 | "TMP_DIR = os.path.expanduser(\"/mnt/\")\n", 88 | "\n", 89 | "# Create them.\n", 90 | "tf.gfile.MakeDirs(DATA_DIR)\n", 91 | "tf.gfile.MakeDirs(OUTPUT_DIR)\n", 92 | "tf.gfile.MakeDirs(TMP_DIR)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "from tensor2tensor.data_generators import problem\n", 102 | "from tensor2tensor.data_generators import text_problems\n", 103 | "from tensor2tensor.utils import registry" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "# Define the problem" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 7, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "@registry.register_problem\n", 120 | "\n", 121 | "class MimicDischargeSummaries(text_problems.Text2TextProblem):\n", 122 | " \n", 123 | " @property\n", 124 | " def is_generate_per_split(self):\n", 125 | " # our data already has pre-existing splits so we return true\n", 126 | " return True\n", 127 | "\n", 128 | " def generate_samples(self, data_dir, tmp_dir, dataset_split):\n", 129 | " \n", 130 | " del tmp_dir\n", 131 | " \n", 132 | " _train = (dataset_split == problem.DatasetSplit.TRAIN)\n", 133 | " _eval = (dataset_split == problem.DatasetSplit.EVAL)\n", 134 | " \n", 135 | " dataset = \"train\" if _train else \"val\" if _eval else \"test\"\n", 136 | " \n", 137 | " full_context = \"full_context\" in str(data_dir) # returns a boolean\n", 138 | " directory = 
\"../data/preprocessed/low_resource/\"\n", 139 | " tgt = directory + \"tgt-\" + dataset + \".txt\"\n", 140 | "\n", 141 | " if full_context == True:\n", 142 | " src = directory + \"src-\" + dataset + \".txt\"\n", 143 | " else:\n", 144 | " directory += \"other_contexts/\" \n", 145 | " context = str(data_dir)[39:-5] # this index needs to be changed if file paths are changed\n", 146 | " src = directory + \"src-\" + dataset + \"-\" + context + \".txt\"\n", 147 | " \n", 148 | " f_src = open(src,'r')\n", 149 | " f_tgt = open(tgt,'r')\n", 150 | " \n", 151 | " context_data = f_src.readline()\n", 152 | " discharge_summary = f_tgt.readline()\n", 153 | "\n", 154 | " while context_data:\n", 155 | " yield {\n", 156 | " \"inputs\" : context_data,\n", 157 | " \"targets\" : discharge_summary,\n", 158 | " }\n", 159 | " \n", 160 | " context_data = f_src.readline()\n", 161 | " discharge_summary = f_tgt.readline()\n", 162 | " \n", 163 | " f_src.close()\n", 164 | " f_tgt.close()\n", 165 | "\n", 166 | " @property\n", 167 | " def vocab_type(self):\n", 168 | " # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good\n", 169 | " # tradeoff between CHARACTER and TOKEN.\n", 170 | " return text_problems.VocabType.SUBWORD\n", 171 | "\n", 172 | " @property\n", 173 | " def approx_vocab_size(self):\n", 174 | " # Approximate vocab size to generate. Only for VocabType.SUBWORD.\n", 175 | " return 2**15 # ~32k - this is the default setting\n", 176 | "\n", 177 | " @property\n", 178 | " def dataset_splits(self):\n", 179 | " return [{\n", 180 | " \"split\": problem.DatasetSplit.TRAIN,\n", 181 | " \"shards\": 80\n", 182 | " }, {\n", 183 | " \"split\": problem.DatasetSplit.EVAL,\n", 184 | " \"shards\": 10\n", 185 | " }, {\n", 186 | " \"split\": problem.DatasetSplit.TEST,\n", 187 | " \"shards\": 10\n", 188 | " }]" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "# Generate the data" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "First, we instantiate the problem and run it for the full context data." 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stderr", 212 | "output_type": "stream", 213 | "text": [ 214 | "W0904 18:19:54.778185 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:343: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.\n", 215 | "\n", 216 | "W0904 18:19:54.779488 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:349: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", 217 | "\n", 218 | "W0904 18:20:43.752292 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:355: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.\n", 219 | "\n", 220 | "W0904 18:20:43.754397 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:944: The name tf.gfile.Open is deprecated. 
Please use tf.io.gfile.GFile instead.\n", 221 | "\n", 222 | "W0904 18:20:43.809778 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:164: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", 223 | "\n", 224 | "W0904 18:21:04.638908 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:183: The name tf.gfile.Rename is deprecated. Please use tf.io.gfile.rename instead.\n", 225 | "\n", 226 | "W0904 18:21:43.317981 140425123104576 deprecation.py:323] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:469: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.\n", 227 | "Instructions for updating:\n", 228 | "Use eager execution and: \n", 229 | "`tf.data.TFRecordDataset(path)`\n", 230 | "W0904 18:21:43.325965 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:513: The name tf.gfile.Remove is deprecated. Please use tf.io.gfile.remove instead.\n", 231 | "\n" 232 | ] 233 | } 234 | ], 235 | "source": [ 236 | "mimic_problem = MimicDischargeSummaries()\n", 237 | "mimic_problem.generate_data(DATA_DIR, TMP_DIR)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "Now, we run it in a loop instead for each individual context type." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 7, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "name": "stderr", 254 | "output_type": "stream", 255 | "text": [ 256 | "W0802 19:05:31.212889 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:343: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.\n", 257 | "\n", 258 | "W0802 19:05:31.213634 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:349: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", 259 | "\n", 260 | "W0802 19:08:39.032947 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:355: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.\n", 261 | "\n", 262 | "W0802 19:08:39.034780 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:944: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n", 263 | "\n", 264 | "W0802 19:08:39.085365 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:164: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", 265 | "\n", 266 | "W0802 19:11:46.960594 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:183: The name tf.gfile.Rename is deprecated. 
Please use tf.io.gfile.rename instead.\n", 267 | "\n", 268 | "W0802 19:12:35.876786 139639908783936 deprecation.py:323] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:469: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.\n", 269 | "Instructions for updating:\n", 270 | "Use eager execution and: \n", 271 | "`tf.data.TFRecordDataset(path)`\n", 272 | "W0802 19:12:35.887185 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:513: The name tf.gfile.Remove is deprecated. Please use tf.io.gfile.remove instead.\n", 273 | "\n" 274 | ] 275 | } 276 | ], 277 | "source": [ 278 | "context_list = ['h','h-gae','h-gae-d','h-gae-p','h-gae-d-p','h-gae-d-p-m','h-gae-d-p-m-t','h-gae-d-p-m-l']\n", 279 | "\n", 280 | "for context in context_list:\n", 281 | " # Setup and create directories.\n", 282 | " DATA_DIR = os.path.expanduser(\"../data/t2t_experiments/other_contexts/\"+context+\"/data\")\n", 283 | " OUTPUT_DIR = os.path.expanduser(\"../data/t2t_experiments/other_contexts/\"+context+\"/output\")\n", 284 | " TMP_DIR = os.path.expanduser(\"/mnt/\")\n", 285 | "\n", 286 | " # Create them.\n", 287 | " tf.gfile.MakeDirs(DATA_DIR)\n", 288 | " tf.gfile.MakeDirs(OUTPUT_DIR)\n", 289 | " tf.gfile.MakeDirs(TMP_DIR)\n", 290 | " \n", 291 | " mimic_problem.generate_data(DATA_DIR, TMP_DIR)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "# View the generated data" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 9, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stderr", 308 | "output_type": "stream", 309 | "text": [ 310 | "W0904 18:22:27.944880 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_problems.py:394: The name tf.VarLenFeature is deprecated. Please use tf.io.VarLenFeature instead.\n", 311 | "\n", 312 | "W0904 18:22:27.945790 140425123104576 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/problem.py:705: The name tf.FixedLenFeature is deprecated. 
Please use tf.io.FixedLenFeature instead.\n", 313 | "\n" 314 | ] 315 | }, 316 | { 317 | "name": "stdout", 318 | "output_type": "stream", 319 | "text": [ 320 | "Tensor Input: tf.Tensor(\n", 321 | "[ 79 71 52 1899 6 43 6 33 19 57 71 429\n", 322 | " 367 3 770 4 364 3 1149 4 365 3 325 4\n", 323 | " 366 3 339 60 2054 94 7 949 139 2915 24 176\n", 324 | " 136 891 1307 7 582 146 149 60 94 7 136 298\n", 325 | " 149 7 94 822 245 7 705 303 741 7 102 591\n", 326 | " 124 15 325 54 820 7 2749 268 741 60 630 444\n", 327 | " 761 176 55 828 891 94 4 362 277 368 3 1156\n", 328 | " 1345 1564 5 21 9 7064 2925 7 1156 1345 3868 5\n", 329 | " 105 9 11939 12298 11326 7 674 5 1825 968 7 1433\n", 330 | " 4825 6989 5 22037 11326 4 235 3 812 11 92 4\n", 331 | " 343 3 227 232 5 11345 5 111 6 81 5 30\n", 332 | " 7 191 5 1098 5 18 6 37 5 30 7 225\n", 333 | " 5 170 5 48 6 38 5 30 7 230 229 5\n", 334 | " 113 5 48 6 38 7 215 5 113 9 26 5\n", 335 | " 223 7 480 5 171 9 35 5 223 7 252 5\n", 336 | " 67 9 45 106 238 104 107 5 33 9 216 5\n", 337 | " 112 6 81 7 273 60 177 5 40 9 45 5\n", 338 | " 18 6 37 7 212 5 941 5 48 6 38 7\n", 339 | " 209 5 16 9 26 5 18 6 37 5 30 7\n", 340 | " 269 5 21 9 16 5 18 6 37 7 250 5\n", 341 | " 431 5 248 7 300 5 26 9 16 5 18 6\n", 342 | " 37 7 187 5 33 9 32 5 48 6 38 7\n", 343 | " 204 5 1851 5 48 6 38 7 213 214 5 135\n", 344 | " 5 18 6 37 7 202 5 1051 9 32 106 234\n", 345 | " 5 113 9 42 5 138 6 37 7 447 414 215\n", 346 | " 162 16 9 21 5 92 7 251 5 246 9 26\n", 347 | " 5 239 5 30 7 247 5 246 9 40 106 243\n", 348 | " 104 107 5 50 9 26 5 111 6 81 4 38\n", 349 | " 369 1], shape=(338,), dtype=int64)\n", 350 | "Tensor Target: tf.Tensor([ 79 71 52 ... 2 599 1], shape=(3113,), dtype=int64)\n" 351 | ] 352 | } 353 | ], 354 | "source": [ 355 | "tfe = tf.contrib.eager\n", 356 | "\n", 357 | "Modes = tf.estimator.ModeKeys\n", 358 | "\n", 359 | "# We can iterate over our examples by making an iterator and calling next on it.\n", 360 | "eager_iterator = tfe.Iterator(mimic_problem.dataset(Modes.EVAL, DATA_DIR))\n", 361 | "example = eager_iterator.next()\n", 362 | "\n", 363 | "input_tensor = example[\"inputs\"]\n", 364 | "target_tensor = example[\"targets\"]\n", 365 | "\n", 366 | "# The tensors are actually encoded using the generated vocabulary file -- you\n", 367 | "# can inspect the actual vocab file in DATA_DIR.\n", 368 | "print(\"Tensor Input: \" + str(input_tensor))\n", 369 | "print(\"Tensor Target: \" + str(target_tensor))" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": {}, 375 | "source": [ 376 | "Below cell is not executed in order to protect patient privacy. 
Executing it will show the decoded context data and discharge summary" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 11, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "# We use the encoders to decode the tensors to the actual input text.\n", 386 | "input_encoder = mimic_problem.get_feature_encoders(\n", 387 | " data_dir=DATA_DIR)[\"inputs\"]\n", 388 | "target_encoder = mimic_problem.get_feature_encoders(\n", 389 | " data_dir=DATA_DIR)[\"targets\"]\n", 390 | "\n", 391 | "input_decoded = input_encoder.decode(input_tensor.numpy())\n", 392 | "target_decoded = target_encoder.decode(target_tensor.numpy())\n", 393 | "\n", 394 | "print(\"Decoded Input: \" + input_decoded)\n", 395 | "print(\"Decoded Target: \" + target_decoded)" 396 | ] 397 | } 398 | ], 399 | "metadata": { 400 | "kernelspec": { 401 | "display_name": "Python [conda env:tf]", 402 | "language": "python", 403 | "name": "conda-env-tf-py" 404 | }, 405 | "language_info": { 406 | "codemirror_mode": { 407 | "name": "ipython", 408 | "version": 3 409 | }, 410 | "file_extension": ".py", 411 | "mimetype": "text/x-python", 412 | "name": "python", 413 | "nbconvert_exporter": "python", 414 | "pygments_lexer": "ipython3", 415 | "version": "3.7.3" 416 | } 417 | }, 418 | "nbformat": 4, 419 | "nbformat_minor": 2 420 | } 421 | -------------------------------------------------------------------------------- /transformer-xl/run_lm_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on WikiText-2 (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 
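Note: BERT and RoBERTa must be run with the --mlm flag (main() raises a ValueError otherwise), while the causal models -- GPT, GPT-2 and the 'transfo-xl-wt103' entry registered in MODEL_CLASSES below -- are fine-tuned without it. Example invocation (a minimal sketch; the paths are placeholders): python run_lm_finetuning.py --train_data_file=/path/to/train.txt --output_dir=/path/to/output --model_type=transfo-xl-wt103 --model_name_or_path=transfo-xl-wt103 --do_train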
20 | """ 21 | 22 | from __future__ import absolute_import, division, print_function 23 | 24 | import argparse 25 | import glob 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | 31 | import numpy as np 32 | import torch 33 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler 34 | from torch.utils.data.distributed import DistributedSampler 35 | from tensorboardX import SummaryWriter 36 | from tqdm import tqdm, trange 37 | 38 | from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule, 39 | BertConfig, BertForMaskedLM, BertTokenizer, 40 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, 41 | OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, 42 | TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLTokenizer) 43 | 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | MODEL_CLASSES = { 49 | 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), 50 | 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 51 | 'bert': (BertConfig, BertForMaskedLM, BertTokenizer), 52 | 'transfo-xl-wt103':(TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLTokenizer) 53 | } 54 | 55 | 56 | class TextDataset(Dataset): 57 | def __init__(self, tokenizer, file_path='train', block_size=512): 58 | assert os.path.isfile(file_path) 59 | directory, filename = os.path.split(file_path) 60 | cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}') 61 | 62 | if os.path.exists(cached_features_file): 63 | logger.info("Loading features from cached file %s", cached_features_file) 64 | with open(cached_features_file, 'rb') as handle: 65 | self.examples = pickle.load(handle) 66 | else: 67 | logger.info("Creating features from dataset file at %s", directory) 68 | 69 | self.examples = [] 70 | with open(file_path, encoding="utf-8") as f: 71 | text = f.read() 72 | 73 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 74 | 75 | while len(tokenized_text) >= block_size: # Truncate in block of block_size 76 | self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size])) 77 | tokenized_text = tokenized_text[block_size:] 78 | # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) 79 | # If your dataset is small, first you should loook for a bigger one :-) and second you 80 | # can change this behavior by adding (model specific) padding. 81 | 82 | logger.info("Saving features into cached file %s", cached_features_file) 83 | with open(cached_features_file, 'wb') as handle: 84 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 85 | 86 | def __len__(self): 87 | return len(self.examples) 88 | 89 | def __getitem__(self, item): 90 | return torch.tensor(self.examples[item]) 91 | 92 | 93 | def load_and_cache_examples(args, tokenizer, evaluate=False): 94 | dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size) 95 | return dataset 96 | 97 | 98 | def set_seed(args): 99 | random.seed(args.seed) 100 | np.random.seed(args.seed) 101 | torch.manual_seed(args.seed) 102 | if args.n_gpu > 0: 103 | torch.cuda.manual_seed_all(args.seed) 104 | 105 | 106 | def mask_tokens(inputs, tokenizer, args): 107 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
""" 108 | labels = inputs.clone() 109 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 110 | masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool() 111 | labels[~masked_indices] = -1 # We only compute loss on masked tokens 112 | 113 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 114 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 115 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 116 | 117 | # 10% of the time, we replace masked input tokens with random word 118 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 119 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) 120 | inputs[indices_random] = random_words[indices_random] 121 | 122 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 123 | return inputs, labels 124 | 125 | 126 | def train(args, train_dataset, model, tokenizer): 127 | """ Train the model """ 128 | if args.local_rank in [-1, 0]: 129 | tb_writer = SummaryWriter() 130 | 131 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 132 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 133 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 134 | 135 | if args.max_steps > 0: 136 | t_total = args.max_steps 137 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 138 | else: 139 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 140 | 141 | # Prepare optimizer and schedule (linear warmup and decay) 142 | no_decay = ['bias', 'LayerNorm.weight'] 143 | optimizer_grouped_parameters = [ 144 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 145 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 146 | ] 147 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 148 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 149 | if args.fp16: 150 | try: 151 | from apex import amp 152 | except ImportError: 153 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 154 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 155 | 156 | # multi-gpu training (should be after apex fp16 initialization) 157 | if args.n_gpu > 1: 158 | model = torch.nn.DataParallel(model) 159 | 160 | # Distributed training (should be after apex fp16 initialization) 161 | if args.local_rank != -1: 162 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 163 | output_device=args.local_rank, 164 | find_unused_parameters=True) 165 | 166 | # Train! 167 | logger.info("***** Running training *****") 168 | logger.info(" Num examples = %d", len(train_dataset)) 169 | logger.info(" Num Epochs = %d", args.num_train_epochs) 170 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 171 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 172 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 173 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 174 | logger.info(" Total optimization steps = %d", t_total) 175 | 176 | global_step = 0 177 | tr_loss, logging_loss = 0.0, 0.0 178 | model.zero_grad() 179 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 180 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 181 | for _ in train_iterator: 182 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 183 | for step, batch in enumerate(epoch_iterator): 184 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 185 | inputs = inputs.to(args.device) 186 | labels = labels.to(args.device) 187 | model.train() 188 | outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) 189 | loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) 190 | 191 | if args.n_gpu > 1: 192 | loss = loss.mean() # mean() to average on multi-gpu parallel training 193 | if args.gradient_accumulation_steps > 1: 194 | loss = loss / args.gradient_accumulation_steps 195 | 196 | if args.fp16: 197 | with amp.scale_loss(loss, optimizer) as scaled_loss: 198 | scaled_loss.backward() 199 | else: 200 | loss.backward() 201 | 202 | tr_loss += loss.item() 203 | if (step + 1) % args.gradient_accumulation_steps == 0: 204 | if args.fp16: 205 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 206 | else: 207 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 208 | optimizer.step() 209 | scheduler.step() # Update learning rate schedule 210 | model.zero_grad() 211 | global_step += 1 212 | 213 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 214 | # Log metrics 215 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 216 | results = evaluate(args, model, tokenizer) 217 | for key, value in results.items(): 218 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 219 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 220 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 221 | logging_loss = tr_loss 222 | 223 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 224 | # Save model checkpoint 225 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 226 | if not os.path.exists(output_dir): 227 | os.makedirs(output_dir) 228 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 229 | model_to_save.save_pretrained(output_dir) 230 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 231 | logger.info("Saving model checkpoint to %s", output_dir) 232 | 233 | if args.max_steps > 0 and global_step > args.max_steps: 234 | epoch_iterator.close() 235 | break 236 | if args.max_steps > 0 and global_step > args.max_steps: 237 | train_iterator.close() 238 | break 239 | 240 | if args.local_rank in [-1, 0]: 241 | tb_writer.close() 242 | 243 | return global_step, tr_loss / global_step 244 | 245 | 246 | def evaluate(args, model, 
tokenizer, prefix=""): 247 | # Loop to handle MNLI double evaluation (matched, mis-matched) 248 | eval_output_dir = args.output_dir 249 | 250 | results = {} 251 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) 252 | 253 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 254 | os.makedirs(eval_output_dir) 255 | 256 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 257 | # Note that DistributedSampler samples randomly 258 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 259 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 260 | 261 | # Eval! 262 | logger.info("***** Running evaluation {} *****".format(prefix)) 263 | logger.info(" Num examples = %d", len(eval_dataset)) 264 | logger.info(" Batch size = %d", args.eval_batch_size) 265 | eval_loss = 0.0 266 | nb_eval_steps = 0 267 | model.eval() 268 | 269 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 270 | batch = batch.to(args.device) 271 | 272 | with torch.no_grad(): 273 | outputs = model(batch, masked_lm_labels=batch) if args.mlm else model(batch, labels=batch) 274 | lm_loss = outputs[0] 275 | eval_loss += lm_loss.mean().item() 276 | nb_eval_steps += 1 277 | 278 | eval_loss = eval_loss / nb_eval_steps 279 | perplexity = torch.exp(torch.tensor(eval_loss)) 280 | 281 | result = { 282 | "perplexity": perplexity 283 | } 284 | 285 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 286 | with open(output_eval_file, "w") as writer: 287 | logger.info("***** Eval results {} *****".format(prefix)) 288 | for key in sorted(result.keys()): 289 | logger.info(" %s = %s", key, str(result[key])) 290 | writer.write("%s = %s\n" % (key, str(result[key]))) 291 | 292 | return results 293 | 294 | 295 | def main(): 296 | parser = argparse.ArgumentParser() 297 | 298 | ## Required parameters 299 | parser.add_argument("--train_data_file", default=None, type=str, required=True, 300 | help="The input training data file (a text file).") 301 | parser.add_argument("--output_dir", default=None, type=str, required=True, 302 | help="The output directory where the model predictions and checkpoints will be written.") 303 | 304 | ## Other parameters 305 | parser.add_argument("--eval_data_file", default=None, type=str, 306 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 307 | 308 | parser.add_argument("--model_type", default="bert", type=str, 309 | help="The model architecture to be fine-tuned.") 310 | parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str, 311 | help="The model checkpoint for weights initialization.") 312 | 313 | parser.add_argument("--mlm", action='store_true', 314 | help="Train with masked-language modeling loss instead of language modeling.") 315 | parser.add_argument("--mlm_probability", type=float, default=0.15, 316 | help="Ratio of tokens to mask for masked language modeling loss") 317 | 318 | parser.add_argument("--config_name", default="", type=str, 319 | help="Optional pretrained config name or path if not the same as model_name_or_path") 320 | parser.add_argument("--tokenizer_name", default="", type=str, 321 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 322 | parser.add_argument("--cache_dir", default="", type=str, 323 | help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)") 
324 | parser.add_argument("--block_size", default=-1, type=int, 325 | help="Optional input sequence length after tokenization." 326 | "The training dataset will be truncated into blocks of this size for training." 327 | "Defaults to the model max input length for single sentence inputs (taking into account special tokens).") 328 | parser.add_argument("--do_train", action='store_true', 329 | help="Whether to run training.") 330 | parser.add_argument("--do_eval", action='store_true', 331 | help="Whether to run eval on the dev set.") 332 | parser.add_argument("--evaluate_during_training", action='store_true', 333 | help="Run evaluation during training at each logging step.") 334 | parser.add_argument("--do_lower_case", action='store_true', 335 | help="Set this flag if you are using an uncased model.") 336 | 337 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 338 | help="Batch size per GPU/CPU for training.") 339 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, 340 | help="Batch size per GPU/CPU for evaluation.") 341 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 342 | help="Number of update steps to accumulate before performing a backward/update pass.") 343 | parser.add_argument("--learning_rate", default=5e-5, type=float, 344 | help="The initial learning rate for Adam.") 345 | parser.add_argument("--weight_decay", default=0.0, type=float, 346 | help="Weight decay to apply, if any.") 347 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 348 | help="Epsilon for Adam optimizer.") 349 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 350 | help="Max gradient norm.") 351 | parser.add_argument("--num_train_epochs", default=1.0, type=float, 352 | help="Total number of training epochs to perform.") 353 | parser.add_argument("--max_steps", default=-1, type=int, 354 | help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.") 355 | parser.add_argument("--warmup_steps", default=0, type=int, 356 | help="Linear warmup over warmup_steps.") 357 | 358 | parser.add_argument('--logging_steps', type=int, default=50, 359 | help="Log every X update steps.") 360 | parser.add_argument('--save_steps', type=int, default=50, 361 | help="Save a checkpoint every X update steps.") 362 | parser.add_argument("--eval_all_checkpoints", action='store_true', 363 | help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with a step number") 364 | parser.add_argument("--no_cuda", action='store_true', 365 | help="Avoid using CUDA when available") 366 | parser.add_argument('--overwrite_output_dir', action='store_true', 367 | help="Overwrite the content of the output directory") 368 | parser.add_argument('--overwrite_cache', action='store_true', 369 | help="Overwrite the cached training and evaluation sets") 370 | parser.add_argument('--seed', type=int, default=42, 371 | help="Random seed for initialization") 372 | 373 | parser.add_argument('--fp16', action='store_true', 374 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 375 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 376 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
377 | "See details at https://nvidia.github.io/apex/amp.html") 378 | parser.add_argument("--local_rank", type=int, default=-1, 379 | help="For distributed training: local_rank") 380 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") 381 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") 382 | args = parser.parse_args() 383 | 384 | if args.model_type in ["bert", "roberta"] and not args.mlm: 385 | raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm " 386 | "flag (masked language modeling).") 387 | if args.eval_data_file is None and args.do_eval: 388 | raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 389 | "or remove the --do_eval argument.") 390 | 391 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 392 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) 393 | 394 | # Setup distant debugging if needed 395 | if args.server_ip and args.server_port: 396 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 397 | import ptvsd 398 | print("Waiting for debugger attach") 399 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 400 | ptvsd.wait_for_attach() 401 | 402 | # Setup CUDA, GPU & distributed training 403 | if args.local_rank == -1 or args.no_cuda: 404 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 405 | args.n_gpu = torch.cuda.device_count() 406 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 407 | torch.cuda.set_device(args.local_rank) 408 | device = torch.device("cuda", args.local_rank) 409 | torch.distributed.init_process_group(backend='nccl') 410 | args.n_gpu = 1 411 | args.device = device 412 | 413 | # Setup logging 414 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 415 | datefmt = '%m/%d/%Y %H:%M:%S', 416 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 417 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 418 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 419 | 420 | # Set seed 421 | set_seed(args) 422 | 423 | # Load pretrained model and tokenizer 424 | if args.local_rank not in [-1, 0]: 425 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab 426 | 427 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 428 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) 429 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) 430 | if args.block_size <= 0: 431 | args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model 432 | args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) 433 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) 434 | model.to(args.device) 435 | 436 | if args.local_rank == 0: 437 | 
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab 438 | 439 | logger.info("Training/evaluation parameters %s", args) 440 | 441 | # Training 442 | if args.do_train: 443 | if args.local_rank not in [-1, 0]: 444 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache 445 | 446 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False) 447 | 448 | if args.local_rank == 0: 449 | torch.distributed.barrier() 450 | 451 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 452 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 453 | 454 | 455 | # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained() 456 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 457 | # Create output directory if needed 458 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 459 | os.makedirs(args.output_dir) 460 | 461 | logger.info("Saving model checkpoint to %s", args.output_dir) 462 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 463 | # They can then be reloaded using `from_pretrained()` 464 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 465 | model_to_save.save_pretrained(args.output_dir) 466 | tokenizer.save_pretrained(args.output_dir) 467 | 468 | # Good practice: save your training arguments together with the trained model 469 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 470 | 471 | # Load a trained model and vocabulary that you have fine-tuned 472 | model = model_class.from_pretrained(args.output_dir) 473 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 474 | model.to(args.device) 475 | 476 | 477 | # Evaluation 478 | results = {} 479 | if args.do_eval and args.local_rank in [-1, 0]: 480 | checkpoints = [args.output_dir] 481 | if args.eval_all_checkpoints: 482 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 483 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 484 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 485 | for checkpoint in checkpoints: 486 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 487 | model = model_class.from_pretrained(checkpoint) 488 | model.to(args.device) 489 | result = evaluate(args, model, tokenizer, prefix=global_step) 490 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) 491 | results.update(result) 492 | 493 | return results 494 | 495 | 496 | if __name__ == "__main__": 497 | main() -------------------------------------------------------------------------------- /transformer-xl/03_decode.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from pytorch_transformers import TransfoXLTokenizer, TransfoXLModel, TransfoXLLMHeadModel\n", 11 | "import subprocess\n", 12 | "\n", 13 | "import logging\n", 14 | "logging.basicConfig(level=logging.INFO)" 15 | ] 16 | }, 17 | { 18 | 
"cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "b'admission date : [ 2138/7/6 ] discharge date : [ F 69 black candidal esophagitis | pneumonia, organism unspecified | intestinal infection due to clostridium difficile | chronic obstructive asthma, unspecified | congestive heart failure, unspecified | chronic kidney disease, unspecified | hypertensive chronic kidney disease, unspecified, with chronic kidney disease stage i through stage iv, or unspecified | poisoning by insulins and antidiabetic agents | accidental poisoning by hormones and synthetic substitutes | diabetes with neurological manifestations, type ii or unspecified type, not stated as uncontrolled | gastroparesis | diabetes with other specified manifestations, type ii or unspecified type, not stated as uncontrolled | long-term (current) use of insulin | personal history of venous thrombosis and embolism | unspecified gastritis and gastroduodenitis, without mention of hemorrhage | anemia, unspecified | pure hypercholesterolemia | fall from other slipping, tripping, or stumbling | osteoarthrosis, unspecified whether generalized or localized, lower leg | home accidents | personal history of tobacco use | other late effects of cerebrovascular disease, facial weakness | family history of asthma | family history of diabetes mellitus | family history of malignant neoplasm of gastrointestinal tract | family history of malignant neoplasm of other respiratory and intrathoracic organs

desipramine hcl , 10 mg Tab | diltiazem extended-release , 120 mg ER Cap blood culture : None | catheter tip-iv : None | sputum : pseudomonas aeruginosa | urine : pseudomonas aeruginosa Red Blood Cells , 3.43 , m/uL , abnormal | Phosphate , 2.8 , mg/dL | RDW , 16.9 , % , abnormal | Anion Gap , 13 , mEq/L | Bicarbonate , 27 , mEq/L | Calcium, Total , 9.4 , mg/dL | Chloride , 102 , mEq/L | Creatinine , 1.6 , mg/dL , abnormal | Glucose , 94 , mg/dL | Magnesium , 2.3 , mg/dL | White Blood Cells , 9.8 , K/uL | Potassium , 4.2 , mEq/L | MCH , 29.3 , pg | Sodium , 138 , mEq/L | MCV , 90 , fL | MCHC , 32.7 , % | Platelet Count , 359 , K/uL | Hemoglobin , 10.1 , g/dL , abnormal | Hematocrit , 30.7 , % , abnormal | Urea Nitrogen , 19 , mg/dL | MCV , 88 , fL | Phosphate , 2.8 , mg/dL | MCHC , 33.8 , % | MCH , 29.9 , pg | Hemoglobin , 10.0 , g/dL , abnormal | Hematocrit , 29.4 , % , abnormal | Urea Nitrogen , 20 , mg/dL | Sodium , 140 , mEq/L | Potassium , 4.0 , mEq/L | Calcium, Total , 9.2 , mg/dL | Magnesium , 2.2 , mg/dL | Glucose , 95 , mg/dL | Creatinine , 1.6 , mg/dL , abnormal | Chloride , 104 , mEq/L | Bicarbonate , 24 , mEq/L | Anion Gap , 16 , mEq/L | Platelet Count , 304 , K/uL | RDW , 17.0 , % , abnormal | White Blood Cells , 9.5 , K/uL | Red Blood Cells , 3.33 = \\n'\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "proc = subprocess.Popen([\"head -1 ../data/transformer-xl/low_resource/test-input-text.txt\"], stdout=subprocess.PIPE, shell=True)\n", 32 | "(out, err) = proc.communicate()\n", 33 | "print(out)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "INFO:pytorch_transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin from cache at /home/aa5118/.cache/torch/pytorch_transformers/b24cb708726fd43cbf1a382da9ed3908263e4fb8a156f9e0a4f45b7540c69caa.a6a9c41b856e5c31c9f125dd6a7ed4b833fbcefda148b627871d4171b25cffd1\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 10, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "out=\"I'm now a Professor of Applied Quantitative Analysis in the Analysis, Engineering, Simulation and Optimization of Performance (AESOP) group in the Department of Computing at Imperial. My broad area of research interest is the application of mathematical modelling techniques to real life systems. 
Specific areas of interest include, but are not limited to, modelling and optimisation in parallel queueing systems (especially split-merge and fork-join systems), modelling of storage systems, stochastic modelling of sport, stochastic modelling of healthcare systems, resource allocation and control in cloud-computing environments, numerical solution of (semi-)Markov models and specification techniques for SLA specification, compliance prediction and monitoring.\"\n", 60 | "\n", 61 | "tokenized_text_1 = tokenizer.tokenize(str(out))\n", 62 | "\n", 63 | "# Convert token to vocabulary indices\n", 64 | "indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)\n", 65 | "\n", 66 | "# Convert inputs to PyTorch tensors\n", 67 | "tokens_tensor_1 = torch.tensor([indexed_tokens_1])\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stderr", 77 | "output_type": "stream", 78 | "text": [ 79 | "INFO:pytorch_transformers.modeling_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json from cache at /home/aa5118/.cache/torch/pytorch_transformers/a6dfd6a3896b3ae4c1a3c5f26ff1f1827c26c15b679de9212a04060eaf1237df.aef76fb1064c932cd6a2a2be3f23ebbfa5f9b6e29e8e87b571c45b4a5d5d1b90\n", 80 | "INFO:pytorch_transformers.modeling_utils:Model config {\n", 81 | " \"adaptive\": true,\n", 82 | " \"attn_type\": 0,\n", 83 | " \"clamp_len\": 1000,\n", 84 | " \"cutoffs\": [\n", 85 | " 20000,\n", 86 | " 40000,\n", 87 | " 200000\n", 88 | " ],\n", 89 | " \"d_embed\": 1024,\n", 90 | " \"d_head\": 64,\n", 91 | " \"d_inner\": 4096,\n", 92 | " \"d_model\": 1024,\n", 93 | " \"div_val\": 4,\n", 94 | " \"dropatt\": 0.0,\n", 95 | " \"dropout\": 0.1,\n", 96 | " \"ext_len\": 0,\n", 97 | " \"finetuning_task\": null,\n", 98 | " \"init\": \"normal\",\n", 99 | " \"init_range\": 0.01,\n", 100 | " \"init_std\": 0.02,\n", 101 | " \"mem_len\": 1600,\n", 102 | " \"n_head\": 16,\n", 103 | " \"n_layer\": 18,\n", 104 | " \"n_token\": 267735,\n", 105 | " \"num_labels\": 2,\n", 106 | " \"output_attentions\": false,\n", 107 | " \"output_hidden_states\": false,\n", 108 | " \"pre_lnorm\": false,\n", 109 | " \"proj_init_std\": 0.01,\n", 110 | " \"pruned_heads\": {},\n", 111 | " \"same_length\": true,\n", 112 | " \"sample_softmax\": -1,\n", 113 | " \"tgt_len\": 128,\n", 114 | " \"tie_projs\": [\n", 115 | " false,\n", 116 | " true,\n", 117 | " true,\n", 118 | " true\n", 119 | " ],\n", 120 | " \"tie_weight\": true,\n", 121 | " \"torchscript\": false,\n", 122 | " \"untie_r\": true\n", 123 | "}\n", 124 | "\n", 125 | "INFO:pytorch_transformers.modeling_utils:loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin from cache at /home/aa5118/.cache/torch/pytorch_transformers/12642ff7d0279757d8356bfd86a729d9697018a0c93ad042de1d0d2cc17fd57b.e9704971f27275ec067a00a67e6a5f0b05b4306b3f714a96e9f763d8fb612671\n" 126 | ] 127 | }, 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "TransfoXLLMHeadModel(\n", 132 | " (transformer): TransfoXLModel(\n", 133 | " (word_emb): AdaptiveEmbedding(\n", 134 | " (emb_layers): ModuleList(\n", 135 | " (0): Embedding(20000, 1024)\n", 136 | " (1): Embedding(20000, 256)\n", 137 | " (2): Embedding(160000, 64)\n", 138 | " (3): Embedding(67735, 16)\n", 139 | " )\n", 140 | " (emb_projs): ParameterList(\n", 141 | " (0): Parameter containing: [torch.FloatTensor of size 1024x1024]\n", 142 | " (1): Parameter containing: [torch.FloatTensor of 
size 1024x256]\n", 143 | " (2): Parameter containing: [torch.FloatTensor of size 1024x64]\n", 144 | " (3): Parameter containing: [torch.FloatTensor of size 1024x16]\n", 145 | " )\n", 146 | " )\n", 147 | " (drop): Dropout(p=0.1, inplace=False)\n", 148 | " (layers): ModuleList(\n", 149 | " (0): RelPartialLearnableDecoderLayer(\n", 150 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 151 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 152 | " (drop): Dropout(p=0.1, inplace=False)\n", 153 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 154 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 155 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 156 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 157 | " )\n", 158 | " (pos_ff): PositionwiseFF(\n", 159 | " (CoreNet): Sequential(\n", 160 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 161 | " (1): ReLU(inplace=True)\n", 162 | " (2): Dropout(p=0.1, inplace=False)\n", 163 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 164 | " (4): Dropout(p=0.1, inplace=False)\n", 165 | " )\n", 166 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 167 | " )\n", 168 | " )\n", 169 | " (1): RelPartialLearnableDecoderLayer(\n", 170 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 171 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 172 | " (drop): Dropout(p=0.1, inplace=False)\n", 173 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 174 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 175 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 176 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 177 | " )\n", 178 | " (pos_ff): PositionwiseFF(\n", 179 | " (CoreNet): Sequential(\n", 180 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 181 | " (1): ReLU(inplace=True)\n", 182 | " (2): Dropout(p=0.1, inplace=False)\n", 183 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 184 | " (4): Dropout(p=0.1, inplace=False)\n", 185 | " )\n", 186 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 187 | " )\n", 188 | " )\n", 189 | " (2): RelPartialLearnableDecoderLayer(\n", 190 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 191 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 192 | " (drop): Dropout(p=0.1, inplace=False)\n", 193 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 194 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 195 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 196 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 197 | " )\n", 198 | " (pos_ff): PositionwiseFF(\n", 199 | " (CoreNet): Sequential(\n", 200 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 201 | " (1): ReLU(inplace=True)\n", 202 | " (2): Dropout(p=0.1, inplace=False)\n", 203 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 204 | " (4): Dropout(p=0.1, inplace=False)\n", 205 | " )\n", 206 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 207 | " )\n", 208 | " )\n", 209 | " (3): RelPartialLearnableDecoderLayer(\n", 210 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 211 | " (qkv_net): 
Linear(in_features=1024, out_features=3072, bias=False)\n", 212 | " (drop): Dropout(p=0.1, inplace=False)\n", 213 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 214 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 215 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 216 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 217 | " )\n", 218 | " (pos_ff): PositionwiseFF(\n", 219 | " (CoreNet): Sequential(\n", 220 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 221 | " (1): ReLU(inplace=True)\n", 222 | " (2): Dropout(p=0.1, inplace=False)\n", 223 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 224 | " (4): Dropout(p=0.1, inplace=False)\n", 225 | " )\n", 226 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 227 | " )\n", 228 | " )\n", 229 | " (4): RelPartialLearnableDecoderLayer(\n", 230 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 231 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 232 | " (drop): Dropout(p=0.1, inplace=False)\n", 233 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 234 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 235 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 236 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 237 | " )\n", 238 | " (pos_ff): PositionwiseFF(\n", 239 | " (CoreNet): Sequential(\n", 240 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 241 | " (1): ReLU(inplace=True)\n", 242 | " (2): Dropout(p=0.1, inplace=False)\n", 243 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 244 | " (4): Dropout(p=0.1, inplace=False)\n", 245 | " )\n", 246 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 247 | " )\n", 248 | " )\n", 249 | " (5): RelPartialLearnableDecoderLayer(\n", 250 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 251 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 252 | " (drop): Dropout(p=0.1, inplace=False)\n", 253 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 254 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 255 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 256 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 257 | " )\n", 258 | " (pos_ff): PositionwiseFF(\n", 259 | " (CoreNet): Sequential(\n", 260 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 261 | " (1): ReLU(inplace=True)\n", 262 | " (2): Dropout(p=0.1, inplace=False)\n", 263 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 264 | " (4): Dropout(p=0.1, inplace=False)\n", 265 | " )\n", 266 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 267 | " )\n", 268 | " )\n", 269 | " (6): RelPartialLearnableDecoderLayer(\n", 270 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 271 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 272 | " (drop): Dropout(p=0.1, inplace=False)\n", 273 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 274 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 275 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 276 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 277 | " )\n", 278 
| " (pos_ff): PositionwiseFF(\n", 279 | " (CoreNet): Sequential(\n", 280 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 281 | " (1): ReLU(inplace=True)\n", 282 | " (2): Dropout(p=0.1, inplace=False)\n", 283 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 284 | " (4): Dropout(p=0.1, inplace=False)\n", 285 | " )\n", 286 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 287 | " )\n", 288 | " )\n", 289 | " (7): RelPartialLearnableDecoderLayer(\n", 290 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 291 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 292 | " (drop): Dropout(p=0.1, inplace=False)\n", 293 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 294 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 295 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 296 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 297 | " )\n", 298 | " (pos_ff): PositionwiseFF(\n", 299 | " (CoreNet): Sequential(\n", 300 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 301 | " (1): ReLU(inplace=True)\n", 302 | " (2): Dropout(p=0.1, inplace=False)\n", 303 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 304 | " (4): Dropout(p=0.1, inplace=False)\n", 305 | " )\n", 306 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 307 | " )\n", 308 | " )\n", 309 | " (8): RelPartialLearnableDecoderLayer(\n", 310 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 311 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 312 | " (drop): Dropout(p=0.1, inplace=False)\n", 313 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 314 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 315 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 316 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 317 | " )\n", 318 | " (pos_ff): PositionwiseFF(\n", 319 | " (CoreNet): Sequential(\n", 320 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 321 | " (1): ReLU(inplace=True)\n", 322 | " (2): Dropout(p=0.1, inplace=False)\n", 323 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 324 | " (4): Dropout(p=0.1, inplace=False)\n", 325 | " )\n", 326 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 327 | " )\n", 328 | " )\n", 329 | " (9): RelPartialLearnableDecoderLayer(\n", 330 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 331 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 332 | " (drop): Dropout(p=0.1, inplace=False)\n", 333 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 334 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 335 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 336 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 337 | " )\n", 338 | " (pos_ff): PositionwiseFF(\n", 339 | " (CoreNet): Sequential(\n", 340 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 341 | " (1): ReLU(inplace=True)\n", 342 | " (2): Dropout(p=0.1, inplace=False)\n", 343 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 344 | " (4): Dropout(p=0.1, inplace=False)\n", 345 | " )\n", 346 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, 
elementwise_affine=True)\n", 347 | " )\n", 348 | " )\n", 349 | " (10): RelPartialLearnableDecoderLayer(\n", 350 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 351 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 352 | " (drop): Dropout(p=0.1, inplace=False)\n", 353 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 354 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 355 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 356 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 357 | " )\n", 358 | " (pos_ff): PositionwiseFF(\n", 359 | " (CoreNet): Sequential(\n", 360 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 361 | " (1): ReLU(inplace=True)\n", 362 | " (2): Dropout(p=0.1, inplace=False)\n", 363 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 364 | " (4): Dropout(p=0.1, inplace=False)\n", 365 | " )\n", 366 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 367 | " )\n", 368 | " )\n", 369 | " (11): RelPartialLearnableDecoderLayer(\n", 370 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 371 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 372 | " (drop): Dropout(p=0.1, inplace=False)\n", 373 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 374 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 375 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 376 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 377 | " )\n", 378 | " (pos_ff): PositionwiseFF(\n", 379 | " (CoreNet): Sequential(\n", 380 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 381 | " (1): ReLU(inplace=True)\n", 382 | " (2): Dropout(p=0.1, inplace=False)\n", 383 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 384 | " (4): Dropout(p=0.1, inplace=False)\n", 385 | " )\n", 386 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 387 | " )\n", 388 | " )\n", 389 | " (12): RelPartialLearnableDecoderLayer(\n", 390 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 391 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 392 | " (drop): Dropout(p=0.1, inplace=False)\n", 393 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 394 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 395 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 396 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 397 | " )\n", 398 | " (pos_ff): PositionwiseFF(\n", 399 | " (CoreNet): Sequential(\n", 400 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 401 | " (1): ReLU(inplace=True)\n", 402 | " (2): Dropout(p=0.1, inplace=False)\n", 403 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 404 | " (4): Dropout(p=0.1, inplace=False)\n", 405 | " )\n", 406 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 407 | " )\n", 408 | " )\n", 409 | " (13): RelPartialLearnableDecoderLayer(\n", 410 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 411 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 412 | " (drop): Dropout(p=0.1, inplace=False)\n", 413 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 414 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 415 
| " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 416 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 417 | " )\n", 418 | " (pos_ff): PositionwiseFF(\n", 419 | " (CoreNet): Sequential(\n", 420 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 421 | " (1): ReLU(inplace=True)\n", 422 | " (2): Dropout(p=0.1, inplace=False)\n", 423 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 424 | " (4): Dropout(p=0.1, inplace=False)\n", 425 | " )\n", 426 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 427 | " )\n", 428 | " )\n", 429 | " (14): RelPartialLearnableDecoderLayer(\n", 430 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 431 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 432 | " (drop): Dropout(p=0.1, inplace=False)\n", 433 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 434 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 435 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 436 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 437 | " )\n", 438 | " (pos_ff): PositionwiseFF(\n", 439 | " (CoreNet): Sequential(\n", 440 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 441 | " (1): ReLU(inplace=True)\n", 442 | " (2): Dropout(p=0.1, inplace=False)\n", 443 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 444 | " (4): Dropout(p=0.1, inplace=False)\n", 445 | " )\n", 446 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 447 | " )\n", 448 | " )\n", 449 | " (15): RelPartialLearnableDecoderLayer(\n", 450 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 451 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 452 | " (drop): Dropout(p=0.1, inplace=False)\n", 453 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 454 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 455 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 456 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 457 | " )\n", 458 | " (pos_ff): PositionwiseFF(\n", 459 | " (CoreNet): Sequential(\n", 460 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 461 | " (1): ReLU(inplace=True)\n", 462 | " (2): Dropout(p=0.1, inplace=False)\n", 463 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 464 | " (4): Dropout(p=0.1, inplace=False)\n", 465 | " )\n", 466 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 467 | " )\n", 468 | " )\n", 469 | " (16): RelPartialLearnableDecoderLayer(\n", 470 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 471 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 472 | " (drop): Dropout(p=0.1, inplace=False)\n", 473 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 474 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 475 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 476 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 477 | " )\n", 478 | " (pos_ff): PositionwiseFF(\n", 479 | " (CoreNet): Sequential(\n", 480 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 481 | " (1): ReLU(inplace=True)\n", 482 | " (2): Dropout(p=0.1, inplace=False)\n", 483 | " (3): 
Linear(in_features=4096, out_features=1024, bias=True)\n", 484 | " (4): Dropout(p=0.1, inplace=False)\n", 485 | " )\n", 486 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 487 | " )\n", 488 | " )\n", 489 | " (17): RelPartialLearnableDecoderLayer(\n", 490 | " (dec_attn): RelPartialLearnableMultiHeadAttn(\n", 491 | " (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)\n", 492 | " (drop): Dropout(p=0.1, inplace=False)\n", 493 | " (dropatt): Dropout(p=0.0, inplace=False)\n", 494 | " (o_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 495 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 496 | " (r_net): Linear(in_features=1024, out_features=1024, bias=False)\n", 497 | " )\n", 498 | " (pos_ff): PositionwiseFF(\n", 499 | " (CoreNet): Sequential(\n", 500 | " (0): Linear(in_features=1024, out_features=4096, bias=True)\n", 501 | " (1): ReLU(inplace=True)\n", 502 | " (2): Dropout(p=0.1, inplace=False)\n", 503 | " (3): Linear(in_features=4096, out_features=1024, bias=True)\n", 504 | " (4): Dropout(p=0.1, inplace=False)\n", 505 | " )\n", 506 | " (layer_norm): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)\n", 507 | " )\n", 508 | " )\n", 509 | " )\n", 510 | " (pos_emb): PositionalEmbedding()\n", 511 | " )\n", 512 | " (crit): ProjectedAdaptiveLogSoftmax(\n", 513 | " (out_layers): ModuleList(\n", 514 | " (0): Linear(in_features=1024, out_features=20000, bias=True)\n", 515 | " (1): Linear(in_features=256, out_features=20000, bias=True)\n", 516 | " (2): Linear(in_features=64, out_features=160000, bias=True)\n", 517 | " (3): Linear(in_features=16, out_features=67735, bias=True)\n", 518 | " )\n", 519 | " (out_projs): ParameterList(\n", 520 | " (0): Parameter containing: [torch.FloatTensor of size 1024x1024]\n", 521 | " (1): Parameter containing: [torch.FloatTensor of size 1024x256]\n", 522 | " (2): Parameter containing: [torch.FloatTensor of size 1024x64]\n", 523 | " (3): Parameter containing: [torch.FloatTensor of size 1024x16]\n", 524 | " )\n", 525 | " )\n", 526 | ")" 527 | ] 528 | }, 529 | "execution_count": 11, 530 | "metadata": {}, 531 | "output_type": "execute_result" 532 | } 533 | ], 534 | "source": [ 535 | "model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')\n", 536 | "model.eval()" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 12, 542 | "metadata": {}, 543 | "outputs": [ 544 | { 545 | "name": "stdout", 546 | "output_type": "stream", 547 | "text": [ 548 | "of\n", 549 | "the\n", 550 | "software\n", 551 | ".\n", 552 | "The\n", 553 | "use\n", 554 | "and\n", 555 | "application\n", 556 | "is\n", 557 | "also\n", 558 | "the\n", 559 | "use\n", 560 | "and\n", 561 | "use\n", 562 | "for\n", 563 | "practical\n", 564 | "purposes\n", 565 | "of\n", 566 | "\n", 567 | ",\n", 568 | "a\n", 569 | "mathematical\n", 570 | "simulation\n", 571 | "and\n", 572 | "simulation\n", 573 | "that\n", 574 | "uses\n", 575 | "mathematics\n", 576 | "modeling\n", 577 | "and\n", 578 | "simulation\n", 579 | "to\n", 580 | "create\n", 581 | "new\n", 582 | "software\n", 583 | "and\n", 584 | "services\n", 585 | ".\n", 586 | "The\n", 587 | "application\n", 588 | "and\n", 589 | "use\n", 590 | "is\n", 591 | "not\n", 592 | "restricted\n", 593 | "\n", 594 | "and\n", 595 | "is\n", 596 | "only\n", 597 | "used\n" 598 | ] 599 | } 600 | ], 601 | "source": [ 602 | "max_predictions = 50\n", 603 | "mems = None\n", 604 | "for i in range(max_predictions):\n", 605 | " 
predictions, mems = model(tokens_tensor_1, mems=mems)  # re-feeds the whole growing sequence while also carrying mems (see the decoding note at the end of this dump)\n", 606 | "    predicted_index_tensor = torch.topk(predictions[0, -1, :],5)[1][1]  # [1] = indices of the top 5; the trailing [1] always picks the 2nd-ranked token\n", 607 | "    predicted_index = predicted_index_tensor.item()\n", 608 | "    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]\n", 609 | "    print(predicted_token)\n", 610 | "    \n", 611 | "    tokens_tensor_1 = torch.cat((tokens_tensor_1, predicted_index_tensor.reshape(1, 1)), dim=1)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": {}, 618 | "outputs": [], 619 | "source": [] 620 | } 621 | ], 622 | "metadata": { 623 | "kernelspec": { 624 | "display_name": "Python [conda env:transfo-xl] *", 625 | "language": "python", 626 | "name": "conda-env-transfo-xl-py" 627 | }, 628 | "language_info": { 629 | "codemirror_mode": { 630 | "name": "ipython", 631 | "version": 3 632 | }, 633 | "file_extension": ".py", 634 | "mimetype": "text/x-python", 635 | "name": "python", 636 | "nbconvert_exporter": "python", 637 | "pygments_lexer": "ipython3", 638 | "version": "3.7.3" 639 | } 640 | }, 641 | "nbformat": 4, 642 | "nbformat_minor": 2 643 | } 644 | -------------------------------------------------------------------------------- /analysis/discharge_summary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Investigating the discharge summary" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 21, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import os\n", 19 | "import psycopg2\n", 20 | "import sqlalchemy\n", 21 | "from sqlalchemy import create_engine\n", 22 | "import string\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# connect to the mimic database and set the search path to the 'mimiciii' schema\n", 33 | "\n", 34 | "dbschema='mimiciii'\n", 35 | "cnx = create_engine('postgresql+psycopg2://aa5118:mimic@localhost:5432/mimic',\n", 36 | "                   connect_args={'options': '-csearch_path={}'.format(dbschema)})\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "(59652, 1)" 48 | ] 49 | }, 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "output_type": "execute_result" 53 | } 54 | ], 55 | "source": [ 56 | "cat = \"'Discharge summary'\"\n", 57 | "df_temp = pd.read_sql_query('''\n", 58 | "    SELECT hadm_id FROM noteevents WHERE category = ''' + cat + '''\n", 59 | "''', cnx)\n", 60 | "df_temp.shape" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "59652 Discharge summaries from ~2m notes" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "(59652, 1)" 79 | ] 80 | }, 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "df_temp = pd.read_sql_query('''\n", 88 | "    SELECT hadm_id FROM noteevents WHERE category = ''' + cat + ''' AND hadm_id IS NOT NULL\n", 89 | "    ORDER BY hadm_id\n", 90 | "''', cnx)\n", 91 | "df_temp.shape" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "These all contain a
hospital admission ID (`hadm_id`) - there are no NULLs. According to https://github.com/MIT-LCP/mimic-code/issues/237, only patients who are admitted to the ICU during their hospital admission will have an `hadm_id` in this table. This means the discharge summary marks discharge from the ICU, not from the hospital." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/html": [

\n", 110 | "\n", 123 | "\n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | "
count
052726
\n", 137 | "
" 138 | ], 139 | "text/plain": [ 140 | " count\n", 141 | "0 52726" 142 | ] 143 | }, 144 | "execution_count": 5, 145 | "metadata": {}, 146 | "output_type": "execute_result" 147 | } 148 | ], 149 | "source": [ 150 | "df_temp = pd.read_sql_query('''\n", 151 | " SELECT COUNT(DISTINCT hadm_id) FROM noteevents WHERE category = ''' + cat + ''' AND hadm_id IS NOT NULL\n", 152 | " --ORDER BY hadm_id\n", 153 | "''', cnx)\n", 154 | "df_temp" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "However, there are only 52726 unique `hadm_id`s in the table. This shows that patients can have multiple distinct ICU stays during a single hospital admission. Indeed there are ~7k duplicate `hadm_id`s. And for each ICU stay, they will get a discharge summary. This clears the picture a litte with deciding how we can utilise the predictive power of the discharge summaries" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 11, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "text/html": [ 172 | "
\n", 173 | "\n", 186 | "\n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | "
categorycounttext_avg_charstext_avg_wordstime_provided
0Case Management9671120162967
1Consult98604076198
2Discharge summary59652962014360
3ECG209051210300
4Echo4579423203260
5General830115602168260
6Nursing2235561790264222172
7Nursing/other822497800132822497
8Nutrition941824303219411
9Pharmacy1032580357102
10Physician1416247140858141048
11Radiology5222791740209522279
12Rehab Services543131204365429
13Respiratory31739136015531703
14Social Work267021603332648
\n", 320 | "
" 321 | ], 322 | "text/plain": [ 323 | " category count text_avg_chars text_avg_words time_provided\n", 324 | "0 Case Management 967 1120 162 967\n", 325 | "1 Consult 98 6040 761 98\n", 326 | "2 Discharge summary 59652 9620 1436 0\n", 327 | "3 ECG 209051 210 30 0\n", 328 | "4 Echo 45794 2320 326 0\n", 329 | "5 General 8301 1560 216 8260\n", 330 | "6 Nursing 223556 1790 264 222172\n", 331 | "7 Nursing/other 822497 800 132 822497\n", 332 | "8 Nutrition 9418 2430 321 9411\n", 333 | "9 Pharmacy 103 2580 357 102\n", 334 | "10 Physician 141624 7140 858 141048\n", 335 | "11 Radiology 522279 1740 209 522279\n", 336 | "12 Rehab Services 5431 3120 436 5429\n", 337 | "13 Respiratory 31739 1360 155 31703\n", 338 | "14 Social Work 2670 2160 333 2648" 339 | ] 340 | }, 341 | "execution_count": 11, 342 | "metadata": {}, 343 | "output_type": "execute_result" 344 | } 345 | ], 346 | "source": [ 347 | "# breakdown of note categories showing the number of notes, average number of characters\n", 348 | "# and the number of notes in each category where the note time was provided\n", 349 | "\n", 350 | "df_summary = pd.read_sql_query('''\n", 351 | " SELECT \n", 352 | " category,\n", 353 | " COUNT(category),\n", 354 | " ROUND(AVG(LENGTH(text)),-1)::integer AS text_avg_chars,\n", 355 | " ROUND(AVG(array_length(regexp_split_to_array(text, '\\s+'), 1)))::integer AS text_avg_words,\n", 356 | " COUNT(charttime) AS time_provided\n", 357 | " FROM noteevents\n", 358 | " GROUP BY category\n", 359 | "''', cnx)\n", 360 | "df_summary" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 14, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "category Case Management ConsultDischarge summaryECGEch...\n", 373 | "count 2083180\n", 374 | "text_avg_chars 43990\n", 375 | "text_avg_words 5996\n", 376 | "time_provided 1766614\n", 377 | "dtype: object\n", 378 | "\n", 379 | "\n", 380 | "count 138878.666667\n", 381 | "text_avg_chars 2932.666667\n", 382 | "text_avg_words 399.733333\n", 383 | "time_provided 117774.266667\n", 384 | "dtype: float64\n" 385 | ] 386 | } 387 | ], 388 | "source": [ 389 | "print(df_summary.sum(axis = 0, skipna = True))\n", 390 | "print (\"\\n\")\n", 391 | "print(df_summary.mean(axis = 0, skipna = True))" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 18, 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "name": "stdout", 401 | "output_type": "stream", 402 | "text": [ 403 | "Char wgtd avg count: 1826.9048521971217\n", 404 | "Word wgtd avg count: 248.84047705911155\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "print(\"Char wgtd avg count:\", sum(df_summary['text_avg_chars'] * df_summary['count'])/sum(df_summary['count']))\n", 410 | "print(\"Word wgtd avg count:\", sum(df_summary['text_avg_words'] * df_summary['count'])/sum(df_summary['count']))" 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": {}, 416 | "source": [ 417 | "Discharge summaries are the longest note category in the `NOTEEVENTS` table with almost 10k characters." 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 23, 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "data": { 427 | "text/html": [ 428 | "
\n", 429 | "\n", 442 | "\n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | "
charttime
\n", 452 | "
" 453 | ], 454 | "text/plain": [ 455 | "Empty DataFrame\n", 456 | "Columns: [charttime]\n", 457 | "Index: []" 458 | ] 459 | }, 460 | "execution_count": 23, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "# confirming that the dataframe output should have 0 rows\n", 467 | "\n", 468 | "df_temp = pd.read_sql_query('''\n", 469 | " SELECT charttime FROM noteevents WHERE category = ''' + cat + ''' AND charttime IS NOT NULL\n", 470 | "''', cnx)\n", 471 | "df_temp" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "Furthermore, they are one of the few note categories which do not provide a timestamp" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "#### Cause of death" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "metadata": {}, 491 | "source": [ 492 | "Below query matches the string \"cause of death\" in the text field and returns the substring occuring after it until the end of the note" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 55, 498 | "metadata": {}, 499 | "outputs": [ 500 | { 501 | "data": { 502 | "text/html": [ 503 | "
\n", 504 | "\n", 517 | "\n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | "
substring
0cause of death was hyperkalemia from acute ren...
1cause of death was likely cardiac arrest.\\nSec...
2cause of death was cardiac arrest secondary to...
3cause of death being cardiopulmonary\\nfailure ...
4Admission Date: [**2131-9-2**] D...
\n", 547 | "
" 548 | ], 549 | "text/plain": [ 550 | " substring\n", 551 | "0 cause of death was hyperkalemia from acute ren...\n", 552 | "1 cause of death was likely cardiac arrest.\\nSec...\n", 553 | "2 cause of death was cardiac arrest secondary to...\n", 554 | "3 cause of death being cardiopulmonary\\nfailure ...\n", 555 | "4 Admission Date: [**2131-9-2**] D..." 556 | ] 557 | }, 558 | "execution_count": 55, 559 | "metadata": {}, 560 | "output_type": "execute_result" 561 | } 562 | ], 563 | "source": [ 564 | "sql = \"\"\"\n", 565 | "SELECT \n", 566 | " SUBSTRING (n.text FROM \n", 567 | " POSITION('cause of death' IN n.text) FOR (LENGTH(n.text) - POSITION('cause of death' IN n.text)))\n", 568 | "FROM noteevents n\n", 569 | "INNER JOIN admissions a\n", 570 | "ON n.hadm_id = a.hadm_id\n", 571 | "WHERE a.hospital_expire_flag = 1\n", 572 | "AND lower(n.text) LIKE '%cause of death%'\n", 573 | "AND category LIKE 'Discharge summary%';\n", 574 | "\"\"\"\n", 575 | "\n", 576 | "df_temp = pd.read_sql_query(sqlalchemy.text(sql), cnx)\n", 577 | "df_temp.head()" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": 56, 583 | "metadata": {}, 584 | "outputs": [ 585 | { 586 | "name": "stdout", 587 | "output_type": "stream", 588 | "text": [ 589 | "['cause of death was likely cardiac arrest.\\nSecondary cause sepsis.\\n\\n\\n\\n\\n [**Name6 (MD) **] [**Name8 (MD) **], M.D. [**MD Number(1) 968**]\\n\\nDictated By:[**Last Name (NamePattern1) 2584**]\\n\\nMEDQUIST36\\n\\nD: [**2140-6-9**] 04:23\\nT: [**2140-6-12**] 16:16\\nJOB#: [**Job Number 2585**]']\n" 590 | ] 591 | } 592 | ], 593 | "source": [ 594 | "print([word for word in df_temp.iloc[1]])" 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": {}, 600 | "source": [ 601 | "Maybe we can use the final discharge summary from an ICU visit to predict whether the patient will die in the ICU at their next visit." 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 6, 607 | "metadata": {}, 608 | "outputs": [ 609 | { 610 | "data": { 611 | "text/html": [ 612 | "
\n", 613 | "\n", 626 | "\n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | "
gendercountageunique_subjects
0F2429181.51039516238
1M3111370.14778421162
\n", 653 | "
" 654 | ], 655 | "text/plain": [ 656 | " gender count age unique_subjects\n", 657 | "0 F 24291 81.510395 16238\n", 658 | "1 M 31113 70.147784 21162" 659 | ] 660 | }, 661 | "execution_count": 6, 662 | "metadata": {}, 663 | "output_type": "execute_result" 664 | } 665 | ], 666 | "source": [ 667 | "# breakdown by gender\n", 668 | "\n", 669 | "sql = \"\"\"\n", 670 | " SELECT p.gender, \n", 671 | " COUNT(p.gender), \n", 672 | " AVG(ROUND((cast(n.chartdate as date) - cast(p.dob as date)) / 365.242,0)) AS age,\n", 673 | " COUNT(DISTINCT(n.subject_id)) AS unique_subjects\n", 674 | " FROM patients p \n", 675 | " INNER JOIN noteevents n \n", 676 | " ON p.subject_id = n.subject_id\n", 677 | " WHERE n.category = 'Discharge summary'\n", 678 | " AND ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0) > 14\n", 679 | " GROUP BY p.gender\n", 680 | "\"\"\"\n", 681 | "\n", 682 | "df_summary = pd.read_sql_query(sqlalchemy.text(sql), cnx)\n", 683 | "\n", 684 | "df_summary" 685 | ] 686 | }, 687 | { 688 | "cell_type": "markdown", 689 | "metadata": {}, 690 | "source": [ 691 | "There seems to be a reasonably large difference in the number of men and women as well as their average ages. So we should stratify when we split our dataset to ensure that the training, validation and test sets all have approximately the same average age and gender balance." 692 | ] 693 | }, 694 | { 695 | "cell_type": "markdown", 696 | "metadata": {}, 697 | "source": [ 698 | "### Number of admissions per subject" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 44, 704 | "metadata": {}, 705 | "outputs": [ 706 | { 707 | "data": { 708 | "text/html": [ 709 | "
\n", 710 | "\n", 723 | "\n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | "
subject_idnum
01303347
110934
21186132
3506029
41131828
\n", 759 | "
" 760 | ], 761 | "text/plain": [ 762 | " subject_id num\n", 763 | "0 13033 47\n", 764 | "1 109 34\n", 765 | "2 11861 32\n", 766 | "3 5060 29\n", 767 | "4 11318 28" 768 | ] 769 | }, 770 | "execution_count": 44, 771 | "metadata": {}, 772 | "output_type": "execute_result" 773 | } 774 | ], 775 | "source": [ 776 | "sql = \"\"\"\n", 777 | " SELECT \n", 778 | " n.subject_id,\n", 779 | " COUNT(n.subject_id) AS num\n", 780 | " FROM patients p \n", 781 | " INNER JOIN noteevents n \n", 782 | " ON p.subject_id = n.subject_id\n", 783 | " WHERE n.category = 'Discharge summary'\n", 784 | " AND ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0) > 14\n", 785 | " GROUP BY n.subject_id\n", 786 | " ORDER BY num DESC\n", 787 | "\"\"\"\n", 788 | "\n", 789 | "df_summary = pd.read_sql_query(sqlalchemy.text(sql), cnx)\n", 790 | "\n", 791 | "df_summary.head()" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 56, 797 | "metadata": {}, 798 | "outputs": [ 799 | { 800 | "data": { 801 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEGCAYAAACKB4k+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAZe0lEQVR4nO3de7hcdX3v8feHUMRyiWLQngPEDSaiKSLF3dCjaBGtJzSkIAcEvFQpJaBgi1ZroLZSrcd4qfJ4xIMRMOgD2FREyEkUKAaRPlxyAQRENNJYUygBgSCWez7nj7X2MGz2nqydvdeszOzP63nm2bN+M7PWd83ee77zu6zfT7aJiIgA2KbpACIiYuuRpBARES1JChER0ZKkEBERLUkKERHRsm3TAYzHtGnTPDAw0HQYERE9ZfXq1ffb3nWkx3o6KQwMDLBq1aqmw4iI6CmSfjHaY2k+ioiIliSFiIhoSVKIiIiWJIWIiGhJUoiIiJYkhYiIaOnJpCBpnqRFGzdubDqUiIi+0pNJwfZS2/OnTp3adCgREX2lpy9eG4+BBcvG9Px1C+fWFElExNajJ2sKERFRjySFiIhoSVKIiIiWJIWIiGhJUoiIiJYkhYiIaElSiIiIliSFiIho6cmkkGkuIiLq0ZNXNNteCiwdHBw8oVvHHOsV0JCroCOi9/RkTSEiIuqRpBARES1JChER0ZKkEBERLUkKERHRkqQQEREtSQoREdGSpBARES1JChER0ZKkEBERLUkKERHRkqQQEREtSQoREdGSpBARES1bVVKQtIOk1ZIObTqWiIjJqNakIOk8SRsk3TasfI6kOyWtlbSg7aGPAEvqjCkiIkZXd01hMTCnvUDSFOAs4BBgFnCspFmS3gz8GLi35pgiImIUta68ZvsaSQPDimcDa23fBSDpm8BhwI7ADhSJ4lFJy21vGr5PSfOB+QDTp0+vL/iIiEmoieU4dwN+2ba9HjjA9ikAkt4D3D9SQgCwvQhYBDA4OOh6Q42ImFyaSAoaoaz14W57cfdCiYiIdk2MPloP7NG2vTtw91h2IGmepEUbN26c0MAiIia7JpLCSmCmpD0lbQccA1w2lh3YXmp7/tSpU2sJMCJisqp7SOpFwHXA3pLWSzre9lPAKcDlwB3AEtu31xlHRERUU/foo2NHKV8OLN/S/UqaB8ybMWPGlu4iIiJGsFVd0VxVmo8iIurRk0khIiLqsdmkIOkUSTuX978i6UZJb6o/tIiI6LYqNYX5th+W9BaKC8/eC3ym3rA6y5DUiIh6VEkKQxeWHQJ8zfbqiq+rTfoUIiLqUeXD/RZJy4F5wHcl7UjbFcgREdE/qgxJPQ54DcUkdv8laRpwfL1hRUREE6rUFC4AXgg8CGD7fts31RrVZqRPISKiHlWSwmLgz4CfSvoHSY1fMZY+hYiIemw2Kdj+nu2jKdZB+E9ghaRrJL1LUhOzrEZERE0qjSKS9ELg7cC7gB8BXwFeC3yvvtAiIqLbNvtNX9IS4FXAhcD/sr2+fOgCSY32LURExMSq0vxzDnCl7ecMQ7X9exMfUv8YWLBsTM9ft3BuTZFERFSz2aRg+wpJr5A0C9i+rfzCWiPrILOkRkTUo8rcRx+lWBP5bIqrms8Ejqw5ro4y+igioh5VOpqPBt4I3GP7XcCraWZt54iIqFmVpPCo7aeBpyTtRDEsda96w4qIiCZU+cZ/k6QXAOcBq4CHgTW1RhUREY2o0tF8Ynn3LEmXAzvbTlKIiOhDVTqarxi6b3ut7TXtZU3I3EcREfUYNSlI2q5cce0lknaStHN52x2Y3r0QnyujjyIi6tGp+ehk4IPAi4HbAZXlD1MMT42IiD4zalKw/QXgC5JOtX1mF2OKiIiGVBmS+u/lUFQkLZC0RNJ+NccVERENqJIUzrD9a0mvpViS859I81FERF+qkhSeLn8eCnzZ9sXA8+oLKSIimlLl4rV7JJ0FzAEGJW1HxXUYIiKit1T5cH8b8ANgru0HgWnAglqj2oxcpxARUY8qy3E+YnuJ7Z+U23fb/m79oXWMKdcpRETUIM1AERHRkqQQEREtSQoREdEy6ugjSQ8C7esyG7gfWAGcVnY6R0REH+lUU5gG7Np2ezFwIPBzcvFaRERf6jT30dMjFN8HfFZS1lOIiOhDY+5TkLQtMKWGWCIiomGd+hT+ZITiFwLHAN+pLaKIiGhMp2kujhq2beBXwNm2L60vpIiIaEqnPoV3dTOQsZA0D5g3Y8aMpkOJiOgrnZqPPgWss/2VYeUfAKbZ/pu6gxuN7aXA0sHBwROaiqEOAwuWjen56xbOrSmSiJisOnU0HwZ8dYTyL5aPRUREn+mUFDbZ3jS8sByqqhGeHxERPa5TUnhM0suGF5Zlj9UXUkRENKXT6KOPAcslfQJYXZYNAn8DfKjuwCIiovs6jT5aJmk98NfAh8vi24BjbN/cjeAiIqK7Oi7HafsW4B1diiUiIhrWaUjqJTx7ltRnsX1ELRFFRERjOtUUvtS1KCIi
YqvQqU/hqm4GEhERzcvKaxER0ZKkEBERLaMmBUmLy5+ndC2aiIhoVKeawmxJuwEnSNpJ0s7tt24FGBER3dNp9NE5wNXAdOB2nj3fkcvyiIjoI6PWFGx/3vZM4Ou2p9veo+024QlB0islnS3pW5LeO9H7j4iIzdtsR7PtEyTtI+mk8jar6s4lnSdpg6TbhpXPkXSnpLWSFpTHucP2ScDbKOZYioiILttsUpB0MrCEorloOvDPkt5Xcf+LgTnD9jcFOAs4BJgFHDuUaMp1oa8Fco1EREQDqgxJPRGYbft026cDBwAnVdm57WuAB4YVzwbW2r7L9hPANykX7bF9me3XkvmWIiIa0XFCvJKAJ9u2n2R8i+zsBvyybXs9cICkg4AjgOcBy0cNRpoPzAeYPj193RERE6lKUvgGcL2ki8vttwLnj+OYIyUU276aYrRTR7YXAYsABgcHR52wLyIixm6zScH2ZyStAF5P8YF+ku2V4zjmemCPtu3dgbvHsgNJ84B5M2bMGEcYERExXJWaAmUSGE8iaLcSmClpT+A/gGOAt49lB7aXAksHBwdPmKCYIiKCiklhS0m6CDgImFau4vYx2+eWU2dcDkwBzrN9e51x9KuBBcvG9Px1C+fWFElE9Itak4LtY0cpX06HzuSIiGhGxyGpkqZIurxbwVQlaZ6kRRs3bmw6lIiIvtIxKdh+Gnhia5sAz/ZS2/OnTp3adCgREX2lSvPRI8Atkq4AfjNUaPuDtUUVERGNqJIU/qW8RUREn6tyncK5krYDptte24WYNivXKURE1KPKhHhzgVuBK8vt/SRdUndgnaRPISKiHlUmxPs4xSR4DwHYvhnIV/SIiD5UJSk8afuhYWWZcygiog9VSQp3SHobsI2kPSWdCVxfc1wd5TqFiIh6VEkKpwCvATYBlwCPA6fWGdTmpE8hIqIeVUYf/Qb4iKS/Lzb9aP1hRUREE6qMPtpf0k3AT4GfSVotaf/6Q4uIiG6r0nz0NeCDtne3vTvwV2VZRET0mSpJ4Te2VwxtlCukPVJbRBWkozkioh6jJgVJ+0raF7hB0lmSDpT0OklfBFaM9rpuSEdzREQ9OnU0nzVse9+2+7lOISKiD42aFGy/vpuBRERE8zY7JLVcS+GdwED78zN1dkRE/6kydfZyYA3FpHib6g0nIiKaVCUp/Lbtv6g9kjHI1NkREfWoMiT1QknHSdpV0s5Dt9oj6yCjjyIi6lF1Oc4zgU/wzKgjA9PrCirqMbBg2Ziev27h3JoiiYitVZWk8GFgpu0NdQcTERHNqtJ89GPg4boDiYiI5lWpKTwB3CTp+xTTZgMZkhoR0Y+qDkldXncgERHRvCrrKZzbjUAiIqJ5Va5o/hkjzHVk++W1RFRBrlPojoxWiph8qjQfHdh2f3vgKKDRCwRsLwWWDg4OntBkHBER/aZK89G9w4o+J+namuKJiIgGVWk+ap8yextgkIZrChERUY8qzUft6yo8BawDjq4lmoiIaFSV5qOsqxARMUlUaT7aDjic566n8L/rCysiIppQpfnoEuAxYDXwdL3hREREk6okhZfa3qf2SCIionFVJsS7XtKs2iOJiIjGVakpHEAxId5aignxBNj2/rVGFhERXVclKRxeexQREbFVqDIk9efdCGQsMvdRREQ9qvQpbHWyRnNERD16MilEREQ9khQiIqJl1D4FSQ8ywjoKPDP6aJfaooqIiEZ06mie1rUoIiJiqzBqUrD9rCktJO1CscjOkLvrCioiIpqx2T4FSXMl/RRYD9xQ/vx+3YFFRET3Velo/iTwOuBO23sA/xO4us6gIiKiGVWuaH7K9n2StpEk21dK+mTtkUXPGViwbMyvWbdwbg2RRMSWqpIUNkraAbgW+LqkDcCmesOKiIgmVGk+OpxiPYVTKZqN/gM4tMaYIiKiIVWSwmm2n7b9pO1zbX8e+GDdgUVERPdVSQpzRihLQ3BERB/qdEXzicBJwMslrWl7aCdgVd2BRURE93XqaF4CXAV8CljQVv5r2xtqjSoiIhoxavOR7Qdtr7V9FPB84I/K2651BSPpcElflXSppLfUdZyIiBhZlSuaT6aoNUwvb0skva/qASSdJ2mDpNuGlc+RdKektZIWANj+ju0TgPcAR4/hPCIiYgJU6Wg+EZht+3Tbp1Os2XzSGI6xmGGd1ZKmAGcBhwCzgGMlzWp7ykfLxyMioouqJAUBT7ZtP1mWVWL7GuCBYcWzgbW277L9BPBN4DAVPg181/aa4fsCkDRf0ipJq+67776qYURERAWdRh9ta/sp4BvA9ZIuLh96K3D+OI+7G/DLtu31FDWQ9wNvBqZKmmH77OEvtL0IWAQwODg40noPERGxhTqNProR2N/2ZyStAF5PUUM4yfbKcR53pJqGbX8R+OI49x0REVuoU1JofXCXSWC8iaDdemCPtu3dGcP6DJLmAfNmzJgxgSFFRESnpLCrpFGnsyinu9hSK4GZkvakmEvpGODtVV9seymwdHBw8IRxxBAREcN06mieAuxIcQXzSLdKJF0EXAfsLWm9pOPLvopTgMuBO4Altm/fslOIiIiJ0qmmcI/tj4/3ALaPHaV8ObB8S/aZ5qOIiHp0qilUHnbabbaX2p4/derUpkOJiOgrnZLCm7oWRUREbBU6zX00/IKziIjoc1WuaN7qSJonadHGjRubDiUioq/0ZFJIn0JERD06jT6KqN3AgmVjev66hVn0L6JOPVlTiIiIevRkUkifQkREPXoyKaRPISKiHulTiJ4yGfsgJuM5R3N6sqYQERH1SFKIiIiWnkwK6WiOiKhHTyaFdDRHRNQjHc0REROslwcH9GRNISIi6pGkEBERLWk+ir7Wy9X4iCb0ZE0ho48iIurRk0kho48iIurRk0khIiLqkaQQEREtSQoREdGS0UcRMSYZ0dXfUlOIiIiW1BQiJrmxfvOP/taTNYVcpxARUY+eTAq5TiEioh49mRQiIqIe6VOI6LK04cfWLDWFiIhoSVKIiIiWJIWIiGhJn0LEOKR/IPpNkkJEm3zIx2SX5qOIiGhJUoiIiJYkhYiIaOnJpJC5jyIi6tGTSSFzH0VE1KMnk0JERNQjSSEiIlqSFCIioiVJISIiWmS76Ri2mKT7gF9s4cunAfdPYDi9IOc8OeScJ4fxnPNLbe860gM9nRTGQ9Iq24NNx9FNOefJIec8OdR1zmk+ioiIliSFiIhomcxJYVHTATQg5zw55Jwnh1rOedL2KURExHNN5ppCREQMk6QQEREtkzIpSJoj6U5JayUtaDqeOkg6T9IGSbe1le0i6UpJPyt/vrDJGCeSpD0krZB0h6TbJf1lWd7P57y9pBsl3VKe89+X5XtKuqE853+StF3TsU40SVMk3STp/5XbfX3OktZJulXSzZJWlWW1/G1PuqQgaQpwFnAIMAs4VtKsZqOqxWJgzrCyBcBVtmcCV5Xb/eIp4K9svxL4A+Dk8vfaz+f8OHCw7VcD+wFzJP0B8GngC+U5Pwgc32CMdflL4I627clwzm+0vV/btQm1/G1PuqQAzAbW2r7L9hPAN4H
DGo5pwtm+BnhgWPFhwPnl/fOBw7saVI1s32N7TXn/1xQfGLvR3+ds24+Um79V3gwcDHyrLO+rcwaQtDswFzin3BZ9fs6jqOVvezImhd2AX7Ztry/LJoOX2L4Hig9R4MUNx1MLSQPA7wE30OfnXDaj3AxsAK4Efg48ZPup8in9+Pd9JvDXwKZy+0X0/zkbuELSaknzy7Ja/ra3nYid9BiNUJZxuX1C0o7AxcCpth8uvkT2L9tPA/tJegFwCfDKkZ7W3ajqI+lQYIPt1ZIOGioe4al9c86l19m+W9KLgSsl/aSuA03GmsJ6YI+27d2BuxuKpdvulfTfAMqfGxqOZ0JJ+i2KhHCB7W+XxX19zkNsPwRcTdGf8gJJQ1/4+u3v+3XAn0haR9H0ezBFzaGfzxnbd5c/N1Ak/9nU9Lc9GZPCSmBmOVphO+AY4LKGY+qWy4B3l/ffDVzaYCwTqmxXPhe4w/bn2x7q53PetawhIOn5wJsp+lJWAEeWT+urc7Z9mu3dbQ9Q/O9+3/Y76ONzlrSDpJ2G7gNvAW6jpr/tSXlFs6Q/pvh2MQU4z/YnGw5pwkm6CDiIYnrde4GPAd8BlgDTgX8HjrI9vDO6J0k6EPghcCvPtDWfTtGv0K/nvC9FB+MUii94S2x/XNJeFN+idwFuAt5p+/HmIq1H2Xz0IduH9vM5l+d2Sbm5LXCh7U9KehE1/G1PyqQQEREjm4zNRxERMYokhYiIaElSiIiIliSFiIhoSVKIiIiWJIUYlSRL+se27Q9JOmOC9r1Y0pGbf+a4j3NUOXPqimHlA8NmkJ0t6Zpy9tyfSDpH0m9LOkPSh4a9dp2kaRWP/x5J/31izmbrIumgoVlKR3hs+dA1FFuwz9eOP7rYUkkK0cnjwBFVPwC7pZzptqrjgffZfmOH/b0E+GfgI7b3ppgq4nvATuMKtPAeYKtMCirU8hlg+4/Lq6zH6iAgSaFBSQrRyVMU68B+YPgDw7/pS3qk/HmQpB9IWiLpp5IWSnpHOe//rZJe1rabN0v6Yfm8Q8vXT5H0WUkrJf1I0olt+10h6UKKC9SGx3Nsuf/bJH26LPs74EDgbEmf7XCeJwPn274OWrOPfsv2vVXfqDLuxeXxb5X0gfL9GQQuUDEP/vMl/V15brdJWlR+ML9M0pq2fc2UtLq8v1DSj8v34nMjHPcMSd+Q9H0V8+qf0PbYh9vex6G1FgbKmtOXgTU8e8qXEY832u+6tLOkS8rXnD2UZNprU5LeWf7+b5b0laGkrmJdkzUq1oO4SsVEhicBHyif+/qq739MnMk4IV6MzVnAjyR9ZgyveTXFt+0HgLuAc2zPVrHwzfuBU8vnDQB/CLwMWCFpBvCnwEbbvy/pecC/SrqifP5sYB/b/9Z+sLJ55tPAayjm0r9C0uHl1b0HU1z1uqpDvPvwzBTEW2o/YDfb+5QxvcD2Q5JOaT++pC/Z/nh5/xvAobaXStooaT/bNwPHAYsl7QK8FXiFbXdojtmXYs6jHYCbJC0rz2kmxXsm4DJJb6C48nVv4Djb72vfyRiO1242xbokv6CoXR3BM1NYI+mVwNEUE7o9WSajd0j6LvBV4A22/03SLrYfkHQ28Ijt5yTA6I7UFKIj2w8DXwf+YgwvW1mub/A4xVTOQx/qt1IkgiFLbG+y/TOK5PEKinld/lTFdNA3UEyLPLN8/o3DE0Lp94Grbd9XTp98AfCGMcTbyWiX/A8vvwvYS9L/kTQHeHiU171RxQpht1JM5va7Zfk5wHHlt+ijgQvLfTwGnCPpCOC/RtnnpbYftX0/xRxAsynex7dQTPmwhuK9HXoff2H7+hH2U/V47W4s1yZ5GriIombW7k0UyXpl+Tt9E7AXRRK7Zuj32S9Tj/SDJIWo4kyKtvkd2sqeovz7kSSgffnD9jlnNrVtb+LZtdPhH6ym+Fb7/nKFqf1s72l7KKn8ZpT4xjs/9u0UH1wj+RUwfJnDnYBntZfbfpCihnQ1RXPUOc8JUtoe+DJwpO1XUXxT3r58+GKK1QAPBVbb/lWZ4GaXjx1O8U18JKO9j59qex9n2D63fHzE97HD8Tr9rkc69rNOm6JpbiiOvW2fUZZnjp2tUJJCbFb5LW4Jz17icB3PfJAeRrHq11gdJWmbsp9hL+BO4HLgvSqmwUbSy1XMDNnJDcAfSppWftM+FvjBGOL4EvBuSQcMFZTt4L8DXEMxVfPQLJVHALeU34xpe/40YBvbFwN/C+xfPvRrnumwHkoA96tY96HVTm/7sfLc/y/wtXKfOwJTbS+naHLbb5T4D1OxXvOLKDpqV5b7+rNyH0jaTcVc/KPqcLx1jP67nq1ixuFtKGo41w7b7VXAkUPHVrGu8EuB6yh+Z3sOlZfPb3+/ogHpU4iq/hE4pW37q8Clkm6k+Mcf7Vt8J3dSfHi/BDjJ9mOSzqFoYlpTfiu9j80sM2j7HkmnUTSdCFhuu/I0wrbvlXQM8Lnyw2sTRTL4tu3/lPQl4FpJppiz/s9H2M1uwNf0zGie08qfiyk6uh8F/gfF+3YrxQftymH7uICiTX6oZrQTxXu8fXlez+nwL90ILKOYLfMT5dz7d5ft+dcVbyOPAO8Enh5lH52O1+l3fR2wEHgVxXt2Sdtjtv1jSR+l6OfZBngSONn29SpWEPt2Wb4B+CNgKfAtSYdR1Bh/2CHeqEFmSY3YSqi4HmKq7b8dw2vOYCvrmC1raxuA37H9ZNPxxNikphCxFZB0CcUorIObjmUC3E4x4iwJoQelphARES3paI6IiJYkhYiIaElSiIiIliSFiIhoSVKIiIiW/w9FYw5d92CbrAAAAABJRU5ErkJggg==\n", 802 | "text/plain": [ 803 | "
" 804 | ] 805 | }, 806 | "metadata": { 807 | "needs_background": "light" 808 | }, 809 | "output_type": "display_data" 810 | } 811 | ], 812 | "source": [ 813 | "plt.hist(df_summary['num'], bins=25, range=(0, 50), weights=df_summary['num'])\n", 814 | "plt.ylabel('Total number of ICU stays')\n", 815 | "plt.xlabel('Number of ICU stays per subject')\n", 816 | "plt.yscale('log')\n", 817 | "plt.savefig('admissions.pdf')" 818 | ] 819 | } 820 | ], 821 | "metadata": { 822 | "kernelspec": { 823 | "display_name": "Python 3", 824 | "language": "python", 825 | "name": "python3" 826 | }, 827 | "language_info": { 828 | "codemirror_mode": { 829 | "name": "ipython", 830 | "version": 3 831 | }, 832 | "file_extension": ".py", 833 | "mimetype": "text/x-python", 834 | "name": "python", 835 | "nbconvert_exporter": "python", 836 | "pygments_lexer": "ipython3", 837 | "version": "3.7.3" 838 | } 839 | }, 840 | "nbformat": 4, 841 | "nbformat_minor": 2 842 | } 843 | --------------------------------------------------------------------------------