├── __init__.py ├── preprocess ├── __init__.py ├── constants.py ├── split_dataset.py ├── utils.py ├── README.md ├── create_dataset_statistics.ipynb └── regular_expressions.py ├── notebooks ├── README.md ├── check_cleaned_and_improved.ipynb ├── determine_token_length.ipynb ├── evaluate_gpt_4_performance.ipynb ├── evalute_gpt_4_prompt.ipynb ├── average_test_runs.ipynb └── build_annotation_datasets.ipynb ├── labeling ├── Instructions-Annotation.docx ├── select_dataset_subset_with_chars.py ├── convert_jsonl_medtator.py ├── remove_json_key_from_bioc.py ├── README.md ├── create_revised_dataset.py └── analyse_labelings.ipynb ├── scripts ├── README.md ├── parameter_tuning_led_wrapper.sh ├── parameter_tuning_llama_wrapper.sh ├── parameter_tuning_led.sh └── parameter_tuning_llama.sh ├── hallucination_detection ├── README.md ├── utils.py ├── create_medcat_entities_and_sapbert_embeddings.ipynb └── convert_bioc_to_json_datasets.ipynb ├── LICENSE ├── gpt-4 ├── README.md ├── run_all.sh ├── run_summarization.py ├── create_hallucination_icl.ipynb └── summarization.ipynb ├── .gitignore ├── README.md └── summarization └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | 3 | This folder contains Jupyter notebooks for different experiments, helper tasks, and analyses. -------------------------------------------------------------------------------- /labeling/Instructions-Annotation.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanhgm/patient_summaries_with_llms/HEAD/labeling/Instructions-Annotation.docx -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Parameter Tuning Scripts 2 | 3 | The scripts in this folder were used for the parameter tuning of the LED and Llama 2 models for the summarization task. -------------------------------------------------------------------------------- /hallucination_detection/README.md: -------------------------------------------------------------------------------- 1 | # Automatic Hallucination Detection 2 | 3 | This folder contains the code to run automatic hallucination detection based on medical entities extracted with MedCAT and SapBERT embeddings. 4 | 5 | ## Notebooks 6 | 7 | * [convert_bioc_to_json_datasets.ipynb](convert_bioc_to_json_datasets.ipynb): Convert the BIOC format exported with MedTator to JSON dataset. 8 | * [create_medcat_entities_and_sapbert_embeddings.ipynb](create_medcat_entities_and_sapbert_embeddings.ipynb): Create MedCAT entities and SapBERT embeddings for the hallucination detection. 9 | * [evaluate_hallucination_detection_entities_embeddings.ipynb](evaluate_hallucination_detection_entities_embeddings.ipynb): Evaluate the hallucination detection based on the entities and embeddings. 
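For intuition only, the core idea can be sketched as follows. This is a simplified, hypothetical example and not the exact code in the notebooks above; the MedCAT model-pack path and the 0.8 threshold are placeholders. Entities are extracted from the source text and the summary with MedCAT, their mentions are embedded with SapBERT, and a summary entity is flagged as potentially unsupported when its best cosine similarity to any source-text entity stays below the threshold.

```python
import torch
from medcat.cat import CAT
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

cat = CAT.load_model_pack("path/to/medcat_model_pack.zip")  # placeholder model pack
name = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
tokenizer = AutoTokenizer.from_pretrained(name)
encoder = AutoModel.from_pretrained(name)

def entity_mentions(text):
    # Surface forms of the medical entities MedCAT detects in the text.
    return [e["source_value"] for e in cat.get_entities(text)["entities"].values()]

def embed(mentions):
    # [CLS] embeddings of the entity mentions (the usual SapBERT recipe).
    batch = tokenizer(mentions, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        return encoder(**batch).last_hidden_state[:, 0, :].numpy()

def unsupported_entities(source, summary, threshold=0.8):  # threshold is illustrative
    src, summ = entity_mentions(source), entity_mentions(summary)
    if not src or not summ:
        return summ
    # One row per summary entity, one column per source entity.
    similarities = cosine_similarity(embed(summ), embed(src))
    return [m for m, row in zip(summ, similarities) if row.max() < threshold]
```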
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Stefan Hegselmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /labeling/select_dataset_subset_with_chars.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | import pandas as pd 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--input_file", 10 | type=str, 11 | default=None, 12 | ) 13 | parser.add_argument( 14 | "--output_dir", 15 | type=str, 16 | default=None, 17 | ) 18 | parser.add_argument( 19 | "--text_max_chars", 20 | type=int, 21 | default=99999, 22 | ) 23 | parser.add_argument( 24 | "--summary_min_chars", 25 | type=int, 26 | default=0, 27 | ) 28 | args = parser.parse_args() 29 | return args 30 | 31 | def main(): 32 | # Read input jsonl file 33 | args = parse_args() 34 | with open(args.input_file, "r") as f: 35 | data = f.readlines() 36 | data = [json.loads(d) for d in data] 37 | 38 | # Filter out texts with less than text_max_chars characters via list comprehension 39 | data = [d for d in data if len(d['text']) <= args.text_max_chars] 40 | data = [d for d in data if len(d['summary']) >= args.summary_min_chars] 41 | 42 | # Write output as jsonl file using pandas 43 | file_name = Path(args.input_file).name 44 | file_name = file_name[:file_name.rfind('.')] 45 | output_file = Path(args.output_dir) / (file_name + f"_{args.text_max_chars}_{args.summary_min_chars}_chars.json") 46 | print(f"Writing {len(data)} text-summary pairs to {output_file}") 47 | df = pd.DataFrame(data) 48 | df.to_json(output_file, orient='records', lines=True) 49 | 50 | if __name__ == "__main__": 51 | main() -------------------------------------------------------------------------------- /scripts/parameter_tuning_led_wrapper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script wraps the parameter tuning script and submits it to the cluster. 
3 | # * Only iterate over all tuning parameters 4 | # * Only create experiment folder if not run already 5 | # -> Rest is done in the main script 6 | 7 | # General 8 | model_name_dir="led-large-16384" # led-base-16384 led-large-16384 9 | run_dir="mimic-iv-note-di-bhc_led-large-16384_4000_600_chars_100_valid" # 4000_600_chars_100_valid long_data_100_valid 10 | 11 | # Cluster 12 | project="/home/s/s_hegs02/scratch/mimic-iv-note-di-bhc" 13 | code="/home/s/s_hegs02/patient_summaries_with_llms" 14 | # Local 15 | # project="/home/s_hegs02/mimic-iv-note-di-bhc" 16 | # code="/home/s_hegs02/patient_summaries_with_llms" 17 | data_path="${project}/dataset" 18 | output_path="${project}/models/${model_name_dir}/${run_dir}" 19 | 20 | # Parameters 21 | for dropout in 0.05 0.1 0.2; do # 0.05 0.1 0.2 22 | for learning_rate in "5e-4" "1e-5" "5e-5" "1e-6" "5e-6"; do # "5e-4" "1e-5" "5e-5" "1e-6" "5e-6" 23 | # Define run folder 24 | # folder_name="debug" 25 | folder_name="dropout_${dropout}_learning_rate_${learning_rate}" 26 | experiment_path="${output_path}/${folder_name}" 27 | 28 | if [ ! -d "$experiment_path" ]; then 29 | echo "Starting experiment: $experiment_path" 30 | mkdir -p "$experiment_path" 31 | # Cluster 32 | sbatch ${code}/scripts/parameter_tuning_led.sh ${experiment_path} ${dropout} ${learning_rate} 33 | # Local 34 | # bash ${code}/scripts/parameter_tuning_led.sh ${experiment_path} ${dropout} ${learning_rate} 35 | else 36 | echo "X Experiment already exists: $experiment_path" 37 | fi 38 | done done 39 | -------------------------------------------------------------------------------- /preprocess/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Special character replacement 4 | SPECIAL_CHARS_MAPPING_TO_ASCII = { 5 | u'\u0091': '\'', 6 | u'\u0092': '\'', 7 | u'\u0093': '\"', 8 | u'\u0094': '-', 9 | u'\u0096': '-', 10 | u'\u0097': '-', 11 | '·': '-', 12 | '¨': '-', 13 | u'\u0095': '\n', 14 | } 15 | 16 | ENCODE_STRINGS_DURING_PREPROCESSING = { 17 | # Use this to encode Dr. as =D= to prevent it from being split into Dr and . 
18 | 'Dr.': '@D@' 19 | } 20 | 21 | # Service mappings 22 | # Map all services to uppercase long forms 23 | SERVICE_MAPPING = { 24 | 'MED': 'MEDICINE', 25 | 'VSU': 'SURGERY', 26 | 'OBS': 'OBSTETRICS/GYNECOLOGY', 27 | 'ORT': 'ORTHOPAEDICS', 28 | 'General Surgery': 'SURGERY', 29 | 'Biologic': 'BIOLOGIC', 30 | 'Biologic Service': 'BIOLOGIC', 31 | 'GYN': 'OBSTETRICS/GYNECOLOGY', 32 | 'Biologics': 'BIOLOGIC', 33 | 'Neurology': 'NEUROLOGY', 34 | 'ACS': 'SURGERY', 35 | 'Biologics Service': 'BIOLOGIC', 36 | 'NEURO': 'NEUROLOGY', 37 | 'PSU': 'SURGERY', 38 | 'TRA': 'SURGERY', 39 | 'OP': 'SURGERY', 40 | 'Neuromedicine': 'NEUROLOGY', 41 | 'ENT': 'OTOLARYNGOLOGY', 42 | 'OBSTERTRIC/GYNECOLOGY': 'OBSTETRICS/GYNECOLOGY', 43 | 'OB service': 'OBSTETRICS/GYNECOLOGY', 44 | 'Vascular Service': 'SURGERY', 45 | 'OB-GYN': 'OBSTETRICS/GYNECOLOGY', 46 | 'Vascular': 'SURGERY', 47 | 'Surgical': 'SURGERY', 48 | 'Ob-GYN': 'OBSTETRICS/GYNECOLOGY', 49 | 'General surgery': 'SURGERY', 50 | 'TRANSPLANT ': 'SURGERY', 51 | 'ACS Service': 'SURGERY', 52 | 'Thoracic Surgery Service': 'SURGERY', 53 | 'Otolaryngology': 'OTOLARYNGOLOGY', 54 | 'GU': 'UROLOGY', 55 | 'CSU': 'SURGERY', 56 | 'NME': 'NEUROLOGY', 57 | 'BIOLOGICS': 'BIOLOGIC', 58 | 'GENERAL SURGERY': 'SURGERY', 59 | 'SURGICAL ONCOLOGY': 'SURGERY', 60 | 'Surgical Oncology': 'SURGERY', 61 | '': 'UNKNOWN' 62 | } 63 | 64 | -------------------------------------------------------------------------------- /hallucination_detection/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from bioc import biocxml 4 | import re 5 | 6 | # Label defintions 7 | labels = { 8 | 'c': 'condition_unsupported', 9 | 'p': 'procedure_unsupported', 10 | 'm': 'medication_unsupported', 11 | 't': 'time_unsupported', 12 | 'l': 'location_unsupported', 13 | 'n': 'number_unsupported', 14 | 'na': 'name_unsupported', 15 | 'w': 'word_unsupported', 16 | 'o': 'other_unsupported', 17 | 'co': 'contradicted_fact', 18 | 'i': 'incorrect_fact' 19 | } 20 | 21 | # Read dataset 22 | def read_bioc(path): 23 | with open(path, 'r') as fp: 24 | return biocxml.load(fp) 25 | 26 | # Create dict of document ids and their annotations 27 | def extract_id(document_name): 28 | if re.search(r'\d+_(qualitative|hallucination)', document_name): 29 | return int(document_name.split('_')[0]) 30 | elif re.search(r'train_4000_600', document_name): 31 | return int(document_name.split('_')[-1].split('.')[0]) 32 | else: 33 | raise ValueError('Document name does not contain id') 34 | 35 | def parse_label(annotation): 36 | # Create a dict of start index, end index, length, label, text 37 | start = annotation.locations[0].offset 38 | end = start + annotation.locations[0].length 39 | length = annotation.locations[0].length 40 | # Get all character before digit of annotation id 41 | label_prefix = str(re.findall(r'[^\d]+', annotation.id)[0]) 42 | label = labels[label_prefix.lower()] 43 | text = annotation.text 44 | return {'start': start, 'end': end, 'length': length, 'label': label, 'text': text} 45 | 46 | # Sort lists of dict by dict key start 47 | def sort_by_start(l): 48 | return sorted(l, key=lambda k: k['start']) 49 | 50 | def parse_text_labels(labeling): 51 | result = {} 52 | for document in labeling.documents: 53 | id = extract_id(document.id) 54 | labels = sort_by_start([parse_label(a) for a in document.passages[0].annotations]) 55 | text = document.passages[0].text 56 | result[id] = {'labels': labels, 'text': text} 57 | return result 58 | 
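# Illustrative usage of the helpers above (editor's sketch, not part of the original
# module; the BioC file name is a placeholder):
#
#   labeling = read_bioc('agreed_annotations.xml')
#   annotated = parse_text_labels(labeling)
#   for doc_id, item in annotated.items():
#       spans = [(label['label'], label['text']) for label in item['labels']]
#       print(doc_id, len(item['text']), spans)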
-------------------------------------------------------------------------------- /notebooks/check_cleaned_and_improved.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Verify annotations of cleaned and improved data" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 13, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import re\n", 17 | "import json" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 16, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Read file\n", 27 | "path = '/home/s_hegs02/mimic-iv-note-di-bhc/dataset/hallucination_summaries_cleaned_improved_test_100.json'\n", 28 | "# path = '/home/s_hegs02/mimic-iv-note-di-bhc/dataset/hallucination_summaries_cleaned_improved_valid_10.json'\n", 29 | "# path = '/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_3_in-context.json'\n", 30 | "# path = '/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_3_test.json'\n", 31 | "# path = '/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_6_in-context.json'\n", 32 | "\n", 33 | "# Read jsonl file, each line is a dict with 'text' and 'summary'\n", 34 | "with open(path, 'r') as f:\n", 35 | " data = [json.loads(line) for line in f.readlines()]\n", 36 | " " 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 17, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Check data\n", 46 | "\n", 47 | "# Check that no summary contains '__'\n", 48 | "for i, d in enumerate(data):\n", 49 | " if '__' in d['summary']:\n", 50 | " print(i, d['summary'])\n", 51 | " " 52 | ] 53 | } 54 | ], 55 | "metadata": { 56 | "kernelspec": { 57 | "display_name": "avs_gen", 58 | "language": "python", 59 | "name": "python3" 60 | }, 61 | "language_info": { 62 | "codemirror_mode": { 63 | "name": "ipython", 64 | "version": 3 65 | }, 66 | "file_extension": ".py", 67 | "mimetype": "text/x-python", 68 | "name": "python", 69 | "nbconvert_exporter": "python", 70 | "pygments_lexer": "ipython3", 71 | "version": "3.9.18" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 2 76 | } 77 | -------------------------------------------------------------------------------- /scripts/parameter_tuning_llama_wrapper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script wraps the parameter tuning script and submits it to the cluster. 
3 | # * Only iterate over all tuning parameters 4 | # * Only create experiment folder if not run already 5 | # -> Rest is done in the main script 6 | 7 | # General 8 | model_name_dir="Llama-2-70b-hf" # Llama-2-7b-hf Llama-7-70b-hf 9 | run_dir="mimic-iv-note-di-bhc_Llama-2-70b-hf_4000_600_chars_100_valid" # 4000_600_chars_100_valid long_data_100_valid 10 | 11 | # Cluster 12 | project="/home/s/s_hegs02/scratch/mimic-iv-note-di-bhc" 13 | code="/home/s/s_hegs02/patient_summaries_with_llms" 14 | # Local 15 | # project="/home/s_hegs02/mimic-iv-note-di-bhc" 16 | # code="/home/s_hegs02/patient_summaries_with_llms" 17 | data_path="${project}/dataset" 18 | output_path="${project}/models/${model_name_dir}/${run_dir}" 19 | 20 | # Parameters 21 | for lora_rank in 8 32; do # 8 32, if possible might also check 4, 16 22 | for lora_alpha in 8 32; do # 8 32 23 | for lora_dropout in 0.05 0.1; do # 0.05 0.1 24 | for num_target_modules in 2 4; do # 2 4 25 | for learning_rate in "2e-5" "2e-4"; do # First round only "2e-5" "2e-4", full: "5e-6" "2e-5" "5e-4" "2e-4" "5e-5" 26 | # Define run folder 27 | # folder_name="debug" 28 | folder_name="lora_rank_${lora_rank}_lora_alpha_${lora_alpha}_lora_dropout_${lora_dropout}_num_target_modules_${num_target_modules}_learning_rate_${learning_rate}" 29 | experiment_path="${output_path}/${folder_name}" 30 | 31 | if [ ! -d "$experiment_path" ]; then 32 | echo "Starting experiment: $experiment_path" 33 | mkdir -p "$experiment_path" 34 | # Cluster 35 | sbatch ${code}/scripts/parameter_tuning_llama.sh ${experiment_path} ${lora_rank} ${lora_alpha} ${lora_dropout} ${num_target_modules} ${learning_rate} 36 | # Local 37 | # bash ${code}/scripts/parameter_tuning_llama.sh ${experiment_path} ${lora_rank} ${lora_alpha} ${lora_dropout} ${num_target_modules} ${learning_rate} 38 | else 39 | echo "X Experiment already exists: $experiment_path" 40 | fi 41 | done done done done done 42 | -------------------------------------------------------------------------------- /gpt-4/README.md: -------------------------------------------------------------------------------- 1 | # Run summarization with GPT-4 2 | 3 | We used the `1106-Preview` version of GPT-4. 
4 | 5 | ## Setup 6 | 7 | ```bash 8 | pip install openai==0.27.0 guidance==0.0.64 9 | ``` 10 | 11 | ## Run Summarization 12 | 13 | ```bash 14 | python run_summarization.py --task_id 4 --model_name gpt-4 --n_shot 3 --verbose 15 | 16 | # or 17 | bash run_all.sh 18 | ``` 19 | 20 | ## Preparing Data for GPT-4 Experiments 21 | 22 | Following commands were used to generate the data: 23 | 24 | ```bash 25 | tail -n 10 train.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/prompt_train.json 26 | tail -n 10 valid.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/prompt_valid.json 27 | 28 | tail -n 10 train.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_1_in-context.json 29 | tail -n 100 test.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_1_test.json 30 | tail -n 10 train_4000_600_chars.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_2_in-context.json 31 | tail -n 100 test_4000_600_chars.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_2_test.json 32 | 33 | # Use dataset of 10 cleaned and improved (validation of annotation) examples for exp_3_in-context.json 34 | cp hallucination_summaries_cleaned_improved_valid_10.json ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_3_in-context.json 35 | # Used dataset of 100 cleaned and improved examples for exp_3_test.json 36 | cp hallucination_summaries_cleaned_improved_test_100.json ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_3_test.json 37 | 38 | # Also use the 10 validation examples here for in-context examples 39 | cp hallucination_summaries_original_valid_10.json ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_4_in-context.json 40 | cp hallucination_summaries_cleaned_valid_10.json ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_5_in-context.json 41 | cp hallucination_summaries_cleaned_improved_valid_10.json ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_6_in-context.json 42 | tail -n 50 test_4000_600_chars.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_4_test.json 43 | tail -n 50 test_4000_600_chars.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_5_test.json 44 | tail -n 50 test_4000_600_chars.json > ~/patient_summaries_with_llms/gpt-4/summarization_data/exp_6_test.json 45 | ``` -------------------------------------------------------------------------------- /preprocess/split_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import pandas as pd 4 | import pickle 5 | import math 6 | from pathlib import Path 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--input_file", 13 | type=str, 14 | default=None, 15 | ) 16 | parser.add_argument( 17 | "--output_dir", 18 | type=str, 19 | default=None, 20 | ) 21 | parser.add_argument( 22 | "--prefix", 23 | type=str, 24 | default='', 25 | ) 26 | parser.add_argument( 27 | "--hospital_course_column", 28 | type=str, 29 | ) 30 | parser.add_argument( 31 | "--summary_column", 32 | type=str, 33 | ) 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | def main(): 39 | args = parse_args() 40 | try: 41 | mimic_df = pd.read_pickle(args.input_file) 42 | except pickle.UnpicklingError: 43 | mimic_df = pd.read_csv(args.input_file) 44 | except: 45 | raise ValueError("Could not read input file. 
Please provide a valid pickle or csv file.") 46 | print(f"Found total of {len(mimic_df)} texts") 47 | 48 | # Rename columns and shuffle 49 | mimic_df = mimic_df[[args.hospital_course_column, args.summary_column]] 50 | mimic_df.rename(columns={args.hospital_course_column: 'text', args.summary_column: 'summary'}, inplace = True) 51 | mimic_df = mimic_df.sample(frac=1, random_state=42).reset_index(drop=True) 52 | 53 | # Split into train, valid, test 54 | num_train = math.floor(len(mimic_df) * 0.8) 55 | num_valid = math.floor(len(mimic_df) * 0.1) 56 | num_test = math.floor(len(mimic_df) * 0.1) 57 | 58 | all_out = Path(args.output_dir) / (args.prefix + 'all.json') 59 | train_out = Path(args.output_dir) / (args.prefix + 'train.json') 60 | valid_out = Path(args.output_dir) / (args.prefix + 'valid.json') 61 | test_out = Path(args.output_dir) / (args.prefix + 'test.json') 62 | 63 | mimic_df.to_json(all_out, orient='records', lines=True) 64 | mimic_df.iloc[0:num_train].to_json(train_out, orient='records', lines=True) 65 | mimic_df.iloc[num_train:num_train+num_valid].to_json(valid_out, orient='records', lines=True) 66 | mimic_df.iloc[num_train+num_valid:].to_json(test_out, orient='records', lines=True) 67 | 68 | print(f" Wrote {num_train} train, {num_valid} valid, and {num_test} test examples to {args.output_dir}") 69 | 70 | 71 | if __name__ == '__main__': 72 | main() -------------------------------------------------------------------------------- /scripts/parameter_tuning_led.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=parameter_tuning 4 | #SBATCH --output=/home/s/s_hegs02/logs/parameter_tuning-%J.log 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=1 7 | 8 | # GPU - LED-base 9 | # --time=0-13:00:00 10 | # --partition=gpuv100,gpu3090 11 | # --cpus-per-task=2 12 | # --gres=gpu:1 13 | # --mem=50GB 14 | 15 | # GPU - LED-large 16 | #SBATCH --time=0-40:00:00 17 | #SBATCH --partition=gpu3090,gputitanrtx,gpua100 18 | #SBATCH --cpus-per-task=2 19 | #SBATCH --gres=gpu:1 20 | #SBATCH --mem=30GB 21 | 22 | echo "------------------------------------------------------------" 23 | echo "SLURM JOB ID: $SLURM_JOBID" 24 | echo "Running on nodes: $SLURM_NODELIST" 25 | echo "------------------------------------------------------------" 26 | 27 | # Load the JupyterLab module 28 | ml palma/2022a 29 | ml GCCcore/11.3.0 30 | # ml CUDA/11.7.0 31 | 32 | # Load conda 33 | source /home/s/s_hegs02/.bashrc-slurm 34 | 35 | # Load environment 36 | conda activate ps_llm 37 | 38 | # Change path 39 | # cd /home/s/s_hegs02/patient_summaries_with_llms 40 | 41 | # Run the application 42 | echo "Running script" 43 | echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" 44 | # Set device with CUDA_VISIBLE_DEVICES 45 | device="cuda" 46 | echo "Device: $device" 47 | 48 | # General 49 | model="allenai/led-large-16384" 50 | # Cluster 51 | project="/home/s/s_hegs02/scratch/mimic-iv-note-di-bhc" 52 | # Local 53 | # project="/home/s_hegs02/mimic-iv-note-di-bhc" 54 | data_path="${project}/dataset" 55 | output_path=$1 56 | 57 | # Experiment 58 | max_steps="200000" 59 | save_and_logging_steps="20000" 60 | 61 | # General 62 | batch_size="1" 63 | 64 | # Parameters 65 | # Default parameters 66 | # dropout="0.1" 67 | # learning_rate="5e-5" 68 | dropout=$2 69 | learning_rate=$3 70 | 71 | python summarization/run_summarization_large_long.py \ 72 | --model_name_or_path ${model} \ 73 | --do_train --do_eval --do_predict \ 74 | --train_file ${data_path}/train.json \ 75 | 
--validation_file ${data_path}/valid_last_100.json \ 76 | --test_file ${data_path}/valid_last_100.json \ 77 | --output_dir ${output_path} \ 78 | --max_steps ${max_steps} \ 79 | --evaluation_strategy steps \ 80 | --eval_steps ${save_and_logging_steps} \ 81 | --save_steps ${save_and_logging_steps} \ 82 | --load_best_model_at_end \ 83 | --per_device_train_batch_size=${batch_size} \ 84 | --per_device_eval_batch_size=${batch_size} \ 85 | --dropout ${dropout} \ 86 | --learning_rate ${learning_rate} \ 87 | --predict_with_generate \ 88 | --max_source_length 4096 \ 89 | --max_target_length 350 -------------------------------------------------------------------------------- /gpt-4/run_all.sh: -------------------------------------------------------------------------------- 1 | # Initial Debugging round 2 | # python run_summarization.py --task_id 4 --model_name gpt-4 --n_shot 3 --verbose --debug 3 | # python run_summarization.py --task_id 5 --model_name gpt-4 --n_shot 3 --verbose --debug 4 | # python run_summarization.py --task_id 6 --model_name gpt-4 --n_shot 3 --verbose --debug 5 | 6 | # Generating intermediate results for picking the best prompt -- see the appendix table 7 | # python run_summarization.py --task_id 4 --prompt_id 1 --model_name gpt-4 --n_shot 1 --verbose --debug 8 | # python run_summarization.py --task_id 4 --prompt_id 1 --model_name gpt-4 --n_shot 3 --verbose --debug 9 | # python run_summarization.py --task_id 4 --prompt_id 1 --model_name gpt-4 --n_shot 5 --verbose --debug 10 | # python run_summarization.py --task_id 4 --prompt_id 2 --model_name gpt-4 --n_shot 1 --verbose --debug 11 | # python run_summarization.py --task_id 4 --prompt_id 2 --model_name gpt-4 --n_shot 3 --verbose --debug 12 | # python run_summarization.py --task_id 4 --prompt_id 2 --model_name gpt-4 --n_shot 5 --verbose --debug 13 | # python run_summarization.py --task_id 4 --prompt_id 3 --model_name gpt-4 --n_shot 1 --verbose --debug 14 | # python run_summarization.py --task_id 4 --prompt_id 3 --model_name gpt-4 --n_shot 3 --verbose --debug 15 | # python run_summarization.py --task_id 4 --prompt_id 3 --model_name gpt-4 --n_shot 5 --verbose --debug 16 | 17 | # Run all tasks for annotation 18 | python run_summarization.py --task_id 4 --prompt_id 3.1 --model_name gpt-4 --n_shot 0 --verbose 19 | python run_summarization.py --task_id 4 --prompt_id 3 --model_name gpt-4 --n_shot 5 --verbose 20 | python run_summarization.py --task_id 5 --prompt_id 3 --model_name gpt-4 --n_shot 5 --verbose 21 | python run_summarization.py --task_id 6 --prompt_id 3 --model_name gpt-4 --n_shot 5 --verbose 22 | 23 | # Run all tasks for exp 1 and 2 24 | python run_summarization.py --task_id 1 --prompt_id 3.1 --model_name gpt-4 --n_shot 0 --verbose 25 | python run_summarization.py --task_id 1 --prompt_id 3 --model_name gpt-4 --n_shot 5 --verbose 26 | python run_summarization.py --task_id 2 --prompt_id 3.1 --model_name gpt-4 --n_shot 0 --verbose 27 | python run_summarization.py --task_id 2 --prompt_id 3 --model_name gpt-4 --n_shot 5 --verbose 28 | 29 | # Run to generate examples for the paper 30 | # cd gpt-4/summarization_data 31 | # touch paper_6_test.json # Based on Stefan's example in slack 32 | # ln -s $(realpath exp_6_in-context.json) paper_6_in-context.json 33 | 34 | python run_summarization.py --task_id 6 --prompt_id 3.1 --model_name gpt-4 --n_shot 0 --what_for paper --verbose 35 | python run_summarization.py --task_id 6 --prompt_id 3 --model_name gpt-4 --n_shot 5 --what_for paper --verbose 
-------------------------------------------------------------------------------- /scripts/parameter_tuning_llama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=parameter_tuning 4 | #SBATCH --output=/home/s/s_hegs02/logs/parameter_tuning-%J.log 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=1 7 | 8 | # GPU - LLAMA 7B 9 | # --time=0-12:00:00 10 | # --partition=gpuhgx 11 | # --cpus-per-task=4 12 | # --gres=gpu:1 13 | # --mem=50GB 14 | 15 | # GPU - LLAMA 70B 16 | #SBATCH --time=0-24:00:00 17 | #SBATCH --partition=gpuhgx 18 | #SBATCH --cpus-per-task=8 19 | #SBATCH --gres=gpu:2 20 | #SBATCH --mem=80GB 21 | 22 | echo "------------------------------------------------------------" 23 | echo "SLURM JOB ID: $SLURM_JOBID" 24 | echo "Running on nodes: $SLURM_NODELIST" 25 | echo "------------------------------------------------------------" 26 | 27 | # Load the JupyterLab module 28 | ml palma/2022a 29 | ml GCCcore/11.3.0 30 | # ml CUDA/11.7.0 31 | 32 | # Load conda 33 | source /home/s/s_hegs02/.bashrc-slurm 34 | 35 | # Load environment 36 | conda activate ps_llm 37 | echo "Conda environment: $CONDA_DEFAULT_ENV" 38 | 39 | # Change path 40 | # cd /home/s/s_hegs02/patient_summaries_with_llms 41 | 42 | # Run the application 43 | echo "Running script" 44 | echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" 45 | # echo ${SLURM_STEP_GPUS:-$SLURM_JOB_GPUS} - if SLURM_STEP_GPUS is not set, use SLURM_JOB_GPUS 46 | echo "SLURM GPUs: $SLURM_JOB_GPUS" 47 | # echo "Set CUDA_VISIBLE_DEVICES: $SLURM_JOB_GPUS" 48 | 49 | # Set device with CUDA_VISIBLE_DEVICES 50 | device="cuda" 51 | echo "Device: $device" 52 | 53 | # General 54 | model="meta-llama/Llama-2-70b-hf" 55 | # Cluster 56 | project="/home/s/s_hegs02/scratch/mimic-iv-note-di-bhc" 57 | # Local 58 | # project="/home/s_hegs02/mimic-iv-note-di-bhc" 59 | data_path="${project}/dataset" 60 | output_path=$1 61 | 62 | # Experiment 63 | num_train_examples="100" 64 | num_val_examples="100" 65 | num_test_examples="100" 66 | max_steps="100" 67 | save_and_logging_steps="10" 68 | 69 | # General 70 | batch_size="1" 71 | gradient_accumulation_steps="16" 72 | 73 | # Parameters 74 | # Default parameters 75 | # lora_rank="8" 76 | # lora_alpha="8" 77 | # lora_dropout="0.1" 78 | # num_target_modules="4" 79 | # learning_rate="5e-4" 80 | lora_rank=$2 81 | lora_alpha=$3 82 | lora_dropout=$4 83 | num_target_modules=$5 84 | learning_rate=$6 85 | 86 | python summarization/fine_tune_llama.py \ 87 | --model_name_or_path ${model} \ 88 | --data_path ${data_path} \ 89 | --output_path ${output_path} \ 90 | --device ${device} \ 91 | --max_steps ${max_steps} \ 92 | --save_and_logging_steps ${save_and_logging_steps} \ 93 | --batch_size ${batch_size} \ 94 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 95 | --lora_rank ${lora_rank} \ 96 | --lora_alpha ${lora_alpha} \ 97 | --lora_dropout ${lora_dropout} \ 98 | --num_target_modules ${num_target_modules} \ 99 | --learning_rate ${learning_rate} \ 100 | --num_train_examples ${num_train_examples} \ 101 | --num_val_examples ${num_val_examples} \ 102 | --num_test_examples ${num_test_examples} \ 103 | -------------------------------------------------------------------------------- /labeling/convert_jsonl_medtator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import glob 4 | import re 5 | import pandas as pd 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | 
parser.add_argument( 11 | "--mode", 12 | type=str, 13 | default=None, 14 | help='The mode to run the script in: "jsonl_to_txt_files", "txt_files_to_jsonl".', 15 | ) 16 | parser.add_argument( 17 | "--input", 18 | type=str, 19 | default=None, 20 | help='The input file pattern. Must be either "*.jsonl" or a prefix of txt files.', 21 | ) 22 | parser.add_argument( 23 | "--output", 24 | type=str, 25 | default=None, 26 | help='The output file path. Must be either "*.jsonl" or a prefix of txt files.', 27 | ) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def main(): 33 | args = parse_args() 34 | 35 | # Convert the file. 36 | if args.mode == 'jsonl_to_txt_files': 37 | jsonl_to_txt_files(args.input, args.output) 38 | elif args.mode == 'txt_files_to_jsonl': 39 | txt_files_to_jsonl(args.input, args.output) 40 | else: 41 | print('Invalid mode. Must be either "jsonl_to_txt_files" or "txt_files_to_jsonl".') 42 | 43 | def jsonl_to_txt_files(input, output): 44 | # Read the JSONL file. 45 | with open(input, 'r') as f: 46 | jsonl = f.read().splitlines() 47 | 48 | # Convert the JSONL to text files. 49 | for i, json_str in enumerate(jsonl): 50 | # Convert the JSON string to a dictionary. 51 | json_dict = json.loads(json_str) 52 | 53 | # Get the text from the dictionary. 54 | text = '' 55 | for key, value in json_dict.items(): 56 | # Add key with separot and value to text. 57 | text += f'### JSON Key: {key}\n{value}\n\n' 58 | 59 | # Write the text to a text file index by i filled with leading zeros. 60 | num_leading_zeros = len(str(len(jsonl))) 61 | with open(f'{output}_{str(i).zfill(num_leading_zeros)}.txt', 'w') as f: 62 | f.write(text) 63 | 64 | 65 | def txt_files_to_jsonl(input, output): 66 | # Read the text files starting with input and ending with .txt using glob. 67 | text_files = [] 68 | for text_file in glob.glob(f'{input}*.txt'): 69 | text_files.append(text_file) 70 | text_files.sort() 71 | 72 | # Convert the text files to a JSONL file. 73 | json_dicts = [] 74 | for text_file in text_files: 75 | # Read the text file. 76 | with open(text_file, 'r') as f: 77 | text = f.read() 78 | 79 | # Convert the text to JSON key and values. 80 | # Use regex to split the text into key and value pairs. 81 | json_dict = {} 82 | for key, value in re.findall(r'### JSON Key: (.*)\n(.*)\n\n', text): 83 | json_dict[key] = value 84 | 85 | json_dicts.append(json_dict) 86 | 87 | # Write the JSONL file. 88 | dataframe = pd.DataFrame(json_dicts) 89 | dataframe.to_json(output, orient='records', lines=True) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | summarization_results* 2 | config.yaml 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # wandb 135 | wandb 136 | 137 | # vim 138 | *.lock 139 | 140 | # Custom files 141 | .vscode 142 | publish_instructions.md 143 | revised_examples/revised_examples_hallucination_free.jsonl 144 | revised_examples/revised_examples_orig.jsonl 145 | revised_examples/revised_examples.txt 146 | 147 | summarization_data/ 148 | summarization_data.zip 149 | 150 | hallucination_detection_data.zip 151 | data/hallucination_evaluation/ 152 | data/qualitative_evaluation/ 153 | prompt_tuning/ 154 | performance_results/ 155 | 156 | data/ 157 | 158 | gpt-4/hallucination_detection_data/ 159 | gpt-4/summarization_data/ 160 | gpt-4/summarization_data/README.md 161 | 162 | revised_examples/ -------------------------------------------------------------------------------- /labeling/remove_json_key_from_bioc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | from bioc import biocxml 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--input_file", 10 | type=str, 11 | # default="/home/s_hegs02/MedTator/13_agreed_label_silver_validation_examples/hallucinations_10_valid_mimic_agreed_old_key.xml", 12 | default="/home/s_hegs02/MedTator/12_agreed_label_silver_examples/hallucinations_100_mimic_agreed_old_key.xml", 13 | ) 14 | parser.add_argument( 15 | "--output_file", 16 | type=str, 17 | # default="/home/s_hegs02/MedTator/13_agreed_label_silver_validation_examples/hallucinations_10_valid_mimic_agreed.xml", 18 | default="/home/s_hegs02/MedTator/12_agreed_label_silver_examples/hallucinations_100_mimic_agreed.xml", 19 | ) 20 | args = parser.parse_args() 21 | return args 22 | 23 | def main(): 24 | args 
= parse_args() 25 | 26 | # Read dataset 27 | def read_bioc(path): 28 | with open(path, 'r') as fp: 29 | return biocxml.load(fp) 30 | 31 | input = read_bioc(args.input_file) 32 | 33 | # Replace keys 34 | source_text = '### JSON Key: text\n' 35 | target_text = 'Text:\n' 36 | source_summary = '### JSON Key: summary\n' 37 | target_summary = 'Summary:\n' 38 | count_text = 0 39 | count_summary = 0 40 | 41 | for document in input.documents: 42 | for passage in document.passages: 43 | count_text += passage.text.count(source_text) 44 | count_summary += passage.text.count(source_summary) 45 | passage.text = passage.text.replace(source_text, target_text) 46 | passage.text = passage.text.replace(source_summary, target_summary) 47 | 48 | print(f"Replaced {count_text} occurrences of '{source_text}' with '{target_text}'") 49 | print(f"Replaced {count_summary} occurrences of '{source_summary}' with '{target_summary}'") 50 | 51 | # Now fix offsets of label 52 | # Find digits in expression ' offset="123"/>' 53 | # Replace with digits - length change 54 | length_change = (len(target_text) - len(source_text)) + (len(target_summary) - len(source_summary)) 55 | 56 | def sum_all_offsets(input_text): 57 | re_offset = re.compile(r' offset="\d+"') 58 | offsets = re_offset.findall(input_text) 59 | return sum([int(offset.split('"')[1]) for offset in offsets]) 60 | 61 | old_sum = sum_all_offsets(str(biocxml.dumps(input))) 62 | print(f"Old sum of all offsets: {old_sum}") 63 | 64 | count_offset = 0 65 | for document in input.documents: 66 | for passage in document.passages: 67 | for annotation in passage.annotations: 68 | for location in annotation.locations: 69 | # Get offset 70 | offset = int(location.offset) 71 | # Change offset 72 | location.offset = str(offset + length_change) 73 | count_offset += 1 74 | 75 | # Check if sum of all offsets is correct 76 | assert old_sum + count_offset * length_change == sum_all_offsets(str(biocxml.dumps(input))) 77 | print(f"Changed {count_offset} offsets by {length_change}") 78 | 79 | # Write to output file 80 | with open(args.output_file, 'w') as f: 81 | biocxml.dump(input, f) 82 | 83 | 84 | 85 | if __name__ == "__main__": 86 | main() -------------------------------------------------------------------------------- /labeling/README.md: -------------------------------------------------------------------------------- 1 | # Labeling Errors in Patient Summaries 2 | 3 | ## Dataset Subselection 4 | 5 | To select more relevant data and improve labeling quality, we subselect the dataset. 6 | We filtered for reference texts with at most 4000 characters and summaries with at least 600 characters. 
7 | This can be done with the following commands: 8 | 9 | ``` 10 | python /root/patient_summaries_with_llms/labeling/select_dataset_subset_with_chars.py --input_file /root/mimic-iv-note-di-bhc/dataset/train.json --output_dir /root/mimic-iv-note-di-bhc/dataset --text_max_chars 4000 --summary_min_chars 600 11 | Writing 20931 text-summary pairs to /root/mimic-iv-note-di-bhc/dataset/train_4000_600_chars.json 12 | 13 | python /root/patient_summaries_with_llms/labeling/select_dataset_subset_with_chars.py --input_file /root/mimic-iv-note-di-bhc/dataset/valid.json --output_dir /root/mimic-iv-note-di-bhc/dataset --text_max_chars 4000 --summary_min_chars 600 14 | Writing 2608 text-summary pairs to /root/mimic-iv-note-di-bhc/dataset/valid_4000_600_chars.json 15 | 16 | python /root/patient_summaries_with_llms/labeling/select_dataset_subset_with_chars.py --input_file /root/mimic-iv-note-di-bhc/dataset/test.json --output_dir /root/mimic-iv-note-di-bhc/dataset --text_max_chars 4000 --summary_min_chars 600 17 | Writing 2639 text-summary pairs to /root/mimic-iv-note-di-bhc/dataset/test_4000_600_chars.json 18 | ``` 19 | 20 | Of these, we selected 250 training examples for labeling and 100 training examples for doing parameter tuning. 21 | 22 | ``` 23 | cd /root/mimic-iv-note-di-bhc/dataset 24 | cat train_4000_600_chars.json | awk 'NR >= 0 && NR <= 250 { print }' > train_4000_600_chars_250_labeling.json 25 | cat train_4000_600_chars.json | awk 'NR >= 251 && NR <= 350 { print }' > train_4000_600_chars_251-350_pt.json 26 | ``` 27 | 28 | ## Creating the MedTator Dataset 29 | 30 | To convert the 250 labeling examples into the right format for MedTator, we run the following command: 31 | 32 | ``` 33 | mkdir -p /root/medtator/data 34 | python /root/patient_summaries_with_llms/labeling/convert_jsonl_medtator.py --mode jsonl_to_txt_files --input /root/mimic-iv-note-di-bhc/dataset/train_4000_600_chars_250_labeling.json --output /root/medtator/data/train_4000_600_chars_250 35 | ``` 36 | 37 | This will create the files `train_4000_600_chars_250_*.txt` in the directory `/root/medtator/data/` for all 250 examples. 38 | We used these for our labeling task in MedTator. 39 | 40 | ## Creating Hallucination Free and Revised Datasets 41 | 42 | Based on the labeling results, we created three new versions of the dataset. 43 | The first one contains the original summaries, the second one the original summaries with hallucinations removed, and the third one the summaries with hallucinations removed and additional revisions. 44 | We stored them in a .txt file with the format `id:\n\nhallucination free summary\n\nrevised summary\n\n`. 45 | We provide the .txt file in our repository at physionet.org. 46 | To create the datasets from this file and store them in the `dataset` folder, run: 47 | 48 | ``` 49 | python /root/patient_summaries_with_llms/labeling/create_revised_dataset.py --input_file_examples /root/mimic-iv-note-di-bhc/dataset/train_4000_600_chars_250_labeling.json --input_file_revised_examples_txt /root/MedTator/revised_examples.txt --output_dir /root/mimic-iv-note-di-bhc/dataset/ --excluded_ids 0,1,2,3,4,5,6,7,8,9,11,12 50 | Read 112 examples from /home/s/s_hegs02/scratch/MedTator/MedTator/10_label_silver_examples_annotator_1/revised_examples.txt 51 | Read 112 according examples from /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset/train_4000_600_chars_250_labeling.json 52 | Exlcuded 12 examples. 100 examples remaining. 
53 | Wrote datasets with 100 examples to /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset 54 | ``` 55 | 56 | ## Analyzing Labeling Results 57 | 58 | We used the notebook `analyse_labelings.ipynb` to analyse the labeling results. 59 | 60 | ## TODO 61 | * Continue adding step into MedTator and back into dataset -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Data-Centric Approach To Generate Faithful and High Quality Patient Summaries with Large Language Models 2 | 3 | ![Figure1-5](https://github.com/user-attachments/assets/fa631f08-9e56-4a37-aea3-3b46fd6d31ef) 4 | 5 | This repository contains the code to reproduce the results of the paper [A Data-Centric Approach To Generate Faithful and High Quality Patient Summaries with Large Language Models](https://proceedings.mlr.press/v248/hegselmann24a.html) by Stefan Hegselmann, Shannon Zejiang Shen, Florian Gierse, Monica Agrawal, David Sontag, and Xiaoyi Jiang. 6 | 7 | We released 100 doctor-written summaries from the MIMIC-IV-Note Discharge Instructions and 100 LLM-generated patient summaries, each annotated for unsupported facts (hallucinations) by two medical experts, on PhysioNet. We also published all datasets created in our work to fully reproduce our experiments. 8 | 9 | If you find our work helpful or use our datasets, please consider citing our paper and the PhysioNet repository: 10 | 11 | ``` 12 | @InProceedings{pmlr-v248-hegselmann24a, 13 | title = {A Data-Centric Approach To Generate Faithful and High Quality Patient Summaries with Large Language Models}, 14 | author = {Hegselmann, Stefan and Shen, Zejiang and Gierse, Florian and Agrawal, Monica and Sontag, David and Jiang, Xiaoyi}, 15 | booktitle = {Proceedings of the fifth Conference on Health, Inference, and Learning}, 16 | pages = {339--379}, 17 | year = {2024}, 18 | volume = {248}, 19 | series = {Proceedings of Machine Learning Research}, 20 | month = {27--28 Jun}, 21 | publisher = {PMLR}, 22 | url = {https://proceedings.mlr.press/v248/hegselmann24a.html}, 23 | } 24 | 25 | @Misc{hegselmann_ann-pt-summ2024, 26 | title = {Medical Expert Annotations of Unsupported Facts in {Doctor}-{Written} and LLM-Generated Patient Summaries}, 27 | author = {Hegselmann, Stefan and Shen, Zejiang and Gierse, Florian and Agrawal, Monica and Sontag, David and Jiang, Xiaoyi}, 28 | booktitle = {Proceedings of the fifth Conference on Health, Inference, and Learning}, 29 | year = {2024}, 30 | publisher = {PhysioNet}, 31 | url = {https://physionet.org/content/ann-pt-summ/1.0.0/}, 32 | doi = {https://doi.org/10.13026/a66y-aa53}, 33 | } 34 | ``` 35 | 36 | ## Overview 37 | 38 | Here you will find the general procedures to set up the environment, download the data, and run the code. 39 | More detailed instructions for each component of the project can be found in the respective folders. 40 | 41 | * [gpt-4](gpt-4/README.md): All code related to the GPT-4 experiments. 42 | * [hallucination_detection](hallucination_detection/README.md): All code related to the hallucination detection experiments without gpt-4. 43 | * [labeling](labeling/README.md): Scripts to analyse and work with labeling data created with MedTator. 44 | * [notebooks](notebooks/README.md): Jupyter notebooks for different experiments, helper tasks, and analyses. 45 | * [preprocess](preprocess/README.md): Preprocessing pipeline as presented in the paper. 
46 | * [scripts](scripts/README.md): Scripts for parameter tuning of LED and Llama 2 models. 47 | * [summarization](summarization/README.md): All code related to the summarization experiments with LED and Llama 2 models. 48 | 49 | 50 | ## Setting Correct Paths 51 | 52 | We assume the root path to be `/root` in this readme and for the code. 53 | Hence, we assume the repository is cloned to `/root/patient_summaries_with_LLMs`. 54 | Please adapt the paths according to your local setup. 55 | 56 | 57 | ## Preparing the Environment 58 | 59 | We used conda to create the necessary virtual environments. For the `ps_llms` environment, we used Python 3.9.18: 60 | 61 | ``` 62 | conda create -n ps_llms python==3.9.18 63 | conda activate ps_llms 64 | ``` 65 | 66 | Next, install the necessary requirements. For installing `torch` you might need to adapt the command in the first line based on [this suggestion](https://pytorch.org). 67 | 68 | ``` 69 | pip install torch torchvision torchaudio 70 | pip install transformers bitsandbytes sentencepiece accelerate datasets peft trl py7zr scipy wandb evaluate rouge-score sacremoses sacrebleu seqeval bert_score swifter bioc medcat plotly nervaluate nbformat kaleido 71 | pip install -U spacy 72 | python -m spacy download en_core_web_sm 73 | ``` 74 | -------------------------------------------------------------------------------- /labeling/create_revised_dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from pathlib import Path 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--input_file_examples", 10 | type=str, 11 | default=None, 12 | ) 13 | parser.add_argument( 14 | "--input_file_revised_examples_txt", 15 | type=str, 16 | default=None, 17 | ) 18 | parser.add_argument( 19 | "--output_dir", 20 | type=str, 21 | default=None, 22 | ) 23 | parser.add_argument( 24 | "--excluded_ids", 25 | type=str, 26 | default="", 27 | ) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def main(): 33 | args = parse_args() 34 | # Read file 35 | with open(args.input_file_examples, "r") as f: 36 | file_examples = f.readlines() 37 | with open(args.input_file_revised_examples_txt, "r") as f: 38 | file_revised_examples = f.readlines() 39 | 40 | 41 | # Convert revised examples to dictionary 42 | # Find excluded examples 43 | # TODO: Does not match example 44 that follows another excluded example 44 | # regex = re.compile(r"(\d+): (.*)\n\n\d+:", re.MULTILINE) 45 | # for match in regex.finditer("".join(file_revised_examples)): 46 | # # print(match.groups()) 47 | # print(f"Excluded example {match.group(1)}: {match.group(2)}") 48 | 49 | # Find labeled example 50 | revised_dict = {} 51 | regex = re.compile(r"(\d+):(.*)\n\n([a-zA-Z_].*)\n\n([a-zA-Z_].*)\n", re.MULTILINE) 52 | for match in regex.finditer("".join(file_revised_examples)): 53 | # print(match.groups()) 54 | revised_dict[int(match.group(1))] = {"comment": match.group(2), "hallucination_free_summary": match.group(3), "revised_summary": match.group(4)} 55 | print(f"Read {len(revised_dict)} examples from {args.input_file_revised_examples_txt}") 56 | 57 | # Source file is jsonl format with field "text" and "summary" 58 | # Read in the lines of the source file according to the keys in revised_dict and add the entries for "text" and "summary" to revised_dict 59 | for i, line in enumerate(file_examples): 60 | if i in revised_dict.keys(): 61 | line = json.loads(line) 62 | revised_dict[i]["text"] = 
line["text"] 63 | revised_dict[i]["summary"] = line["summary"] 64 | print(f"Read {len(revised_dict)} according examples from {args.input_file_examples}") 65 | 66 | # Exclude ids excluding bad examples 67 | num_deleted = 0 68 | if args.excluded_ids != "": 69 | for index in sorted(args.excluded_ids.split(","), reverse=True): 70 | del revised_dict[int(index)] 71 | num_deleted += 1 72 | print(f"Exlcuded {num_deleted} examples. {len(revised_dict)} examples remaining.") 73 | else: 74 | print(f"No examples excluded. {len(revised_dict)} examples remaining.") 75 | 76 | 77 | # Write out "text" and "summary" to jsonl file 78 | revised_jsonl = [] 79 | output_path = Path(args.output_dir) 80 | for i in revised_dict.keys(): 81 | revised_jsonl.append(json.dumps({"text": revised_dict[i]["text"], "summary": revised_dict[i]["summary"]})) 82 | with open(output_path / 'hallucination_summaries_original.json', "w") as f: 83 | f.write("\n".join(revised_jsonl)) 84 | # Write out "text" and "hallucination_free_summary" to jsonl file 85 | revised_jsonl = [] 86 | for i in revised_dict.keys(): 87 | revised_jsonl.append(json.dumps({"text": revised_dict[i]["text"], "summary": revised_dict[i]["hallucination_free_summary"]})) 88 | with open(output_path / 'hallucination_summaries_cleaned.json', "w") as f: 89 | f.write("\n".join(revised_jsonl)) 90 | # Write out "text" and "revised_summary" to jsonl file 91 | revised_jsonl = [] 92 | for i in revised_dict.keys(): 93 | revised_jsonl.append(json.dumps({"text": revised_dict[i]["text"], "summary": revised_dict[i]["revised_summary"]})) 94 | with open(output_path / 'hallucination_summaries_cleaned_improved.json', "w") as f: 95 | f.write("\n".join(revised_jsonl)) 96 | print(f"Wrote datasets with {len(revised_dict)} examples to {output_path}") 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import string 2 | import spacy 3 | import nltk 4 | from nltk.util import ngrams 5 | import pandas as pd 6 | import numpy as np 7 | from spacy.symbols import ORTH 8 | from collections import Counter 9 | from sklearn.feature_extraction.text import TfidfVectorizer 10 | from scipy.sparse import lil_matrix 11 | from sklearn.metrics.pairwise import cosine_similarity 12 | from sklearn.preprocessing import binarize 13 | from tqdm import tqdm 14 | from src.preprocess.regular_expressions import * 15 | 16 | 17 | def get_pairwise_text_similarity(texts, batch_size=1000, threshold=0): 18 | """ Calculates the pairwise cosine similarity between the texts. Use batch sizes and threshold to reduce memory usage.""" 19 | vectorizer = TfidfVectorizer(encoding='utf-8', strip_accents='unicode', stop_words='english') # , max_features=1000) 20 | 21 | X = vectorizer.fit_transform(texts).astype(np.float32) 22 | S = lil_matrix((len(texts), len(texts)), dtype=np.float32) 23 | # Determine cosine_similarity in batches and remove unnecessary values via thresholding 24 | pbar = tqdm(total=len(range(0, len(texts), batch_size))) 25 | for i in range(0, len(texts), batch_size): 26 | S[:, i:i+batch_size] = binarize(cosine_similarity(X, X[i:i+batch_size], dense_output=False), threshold=threshold) 27 | pbar.update(1) 28 | 29 | return S.tocsr() 30 | 31 | 32 | def get_most_frequent_words(texts, limit=None): 33 | """ Use spacy to count the n most frequent words in the texts. 
""" 34 | if limit is not None: 35 | texts = texts.sample(limit, replace=True) 36 | nlp = spacy.load('en_core_web_sm') 37 | docs = [doc for doc in nlp.pipe(texts.tolist(), n_process=4)] 38 | counts = [doc.count_by(ORTH).items() for doc in docs] 39 | words = [word for count in counts for word, _ in count] 40 | counts = [count for count in counts for _, count in count] 41 | df_counts = pd.DataFrame({'word': words, 'count': counts}) 42 | # Group by words and keep word 43 | df_counts = df_counts.groupby('word', as_index=False).sum().sort_values('count', ascending=False) 44 | df_counts['word'] = df_counts['word'].apply(lambda x: nlp.vocab.strings[x]) 45 | # Sort 46 | df_counts = df_counts.sort_values('count', ascending=False, ignore_index=True) 47 | return df_counts 48 | 49 | 50 | def get_overlapping_ngram_spans(texts, n_gram_length=20, n_gram_min_occurence=20): 51 | """ Detects spans of ngrams that occur frequently in the texts.""" 52 | tokenize = lambda x: nltk.word_tokenize(x) 53 | 54 | tokens = tokenize((' ' + 10*'#' + ' ').join(texts)) 55 | n_grams = ngrams(tokens, n_gram_length) 56 | fdist = nltk.FreqDist(n_grams) 57 | frequent_n_grams = set(ngram for ngram, count in fdist.items() if count > n_gram_min_occurence) 58 | 59 | # Detect spans of ngrams in texts 60 | duplicate_spans = Counter() 61 | for summ_tokens in [list(tokenize(summ)) for summ in texts]: 62 | last_start_duplicate = -1 63 | for i in range(0, len(summ_tokens) - n_gram_length): 64 | ngram = tuple(summ_tokens[i:i+n_gram_length]) 65 | if ngram in frequent_n_grams: 66 | if last_start_duplicate == -1: 67 | # Begin of new duplicate span 68 | last_start_duplicate = i 69 | if not (ngram in frequent_n_grams) or i == len(summ_tokens) - n_gram_length - 1: 70 | # End of duplicate span 71 | if last_start_duplicate > -1: 72 | duplicate_spans[tuple(summ_tokens[last_start_duplicate:i+n_gram_length-1])] += 1 73 | last_start_duplicate = -1 74 | return duplicate_spans 75 | 76 | 77 | def split_into_paragraphs(summary): 78 | # If paragraph smaller than minimum length combin with previous paragraph 79 | min_num_words = 12 80 | anonymization = ['___'] 81 | punctuation = list(string.punctuation) 82 | # TODO: Could add stemming 83 | # TODO: Could add stopwords 84 | # stopwords = nltk.corpus.stopwords.words('english') 85 | ignore_words = punctuation + anonymization 86 | tokenize = lambda x: [res for res in [t.lower() for t in nltk.word_tokenize(x)] if res not in ignore_words] 87 | paragraphs = re_paragraph.split(summary) 88 | paragraphs.reverse() 89 | for i in range(len(paragraphs)-1, 0, -1): 90 | if len(tokenize(paragraphs[i])) < min_num_words: 91 | paragraphs[i-1] = paragraphs[i] + '\n\n' + paragraphs[i-1] 92 | paragraphs = paragraphs[:i] + paragraphs[i+1:] 93 | # Check first paragraph 94 | if len(paragraphs) > 1 and len(tokenize(paragraphs[0])) < min_num_words: 95 | paragraphs[1] = paragraphs[1] + '\n\n' + paragraphs[0] 96 | paragraphs = paragraphs[1:] 97 | paragraphs.reverse() 98 | return paragraphs 99 | -------------------------------------------------------------------------------- /preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Create MIMIC-IV-Note-DI Dataset 2 | 3 | ## Prepare the MIMIC-IV-Note Database 4 | 5 | We used the data from the [MIMIC-IV-Note](https://physionet.org/content/mimic-iv-note/2.2/) dataset version 2.2. 6 | Please applpy for access and download the data. 7 | We assume the dataset is located in `/root/physionet.org/files/mimic-iv-note`. 
8 | Extract the files and move the database to the `/root` path to simplify the use of our scripts. 9 | 10 | ``` 11 | gunzip /root/physionet.org/files/mimic-iv-note/2.2/note/*.gz 12 | mv /root/physionet.org/files/mimic-iv-note /root 13 | rm -r /root/physionet.org 14 | ``` 15 | 16 | ## Process the MIMIC-IV Summaries 17 | 18 | First, create a folder to store the newly created dataset in. 19 | Then, we set the `PYTHONPATH` to the repository and execute the preprocessing script. 20 | The script goes through several steps. 21 | Using `--start_from_step 1` will start it from scratch. 22 | 23 | ``` 24 | mkdir /root/mimic-iv-note-di 25 | mkdir /root/mimic-iv-note-di/dataset 26 | 27 | export PYTHONPATH=/root/patient_summaries_with_llms 28 | cd /root/patient_summaries_with_llms 29 | python /root/patient_summaries_with_llms/preprocess/process_mimic_summaries.py \ 30 | --start_from_step 1 \ 31 | --input_file /root/mimic-iv-note/2.2/note/discharge.csv \ 32 | --output_dir /root/mimic-iv-note-di/dataset 33 | ``` 34 | 35 | You should get an output starting with: 36 | 37 | ``` 38 | Found total of 331793 texts 39 | 40 | Step 1: Remove exact duplicates, and only keep most recent note per hospital stay. 41 | Removed 0 / 331793 exact duplicates. 42 | Removed 0 / 331793 notes from same hospital stay. 43 | Pandas Apply: 100%|█████████████████████████████████████████████████████████████████████████| 331793/331793 [00:01<00:00, 326553.88it/s] 44 | Pandas Apply: 100%|████████████████████████████████████████████████████████████████████████████████| 318/318 [00:00<00:00, 91795.50it/s] 45 | 46 | [...] 47 | 48 | Total entries: 100175. 49 | 50 | Output data to /root/mimic-iv-note-di/dataset 51 | ``` 52 | 53 | The resulting csv still contains all columns of the MIMIC-IV-Note database and should be stored in the `output_dir`. 54 | Based on this, we can create different datasets using the complete hospital course or only the brief hospital course as a reference. 55 | We will create both here. 56 | 57 | ## Select Dataset Columns and Create Splits 58 | 59 | To select the relevant columns and create dataset splits, we use a separate script `split_dataset.py`. 60 | The following command will use the full `hospital course` as a reference and the preprocessed `summary` column as summary. 
61 | 62 | ``` 63 | python /root/patient_summaries_with_llms/preprocess/split_dataset.py \ 64 | --input_file /root/mimic-iv-note-di/dataset/mimic_processed_summaries.csv \ 65 | --output_dir /root/mimic-iv-note-di/dataset \ 66 | --hospital_course_column hospital_course \ 67 | --summary_column summary 68 | 69 | Found total of 100175 texts 70 | Wrote 80140 train, 10017 valid, and 10017 test examples to /root/mimic-iv-note-di/dataset 71 | ``` 72 | 73 | To create a separate version of the dataset using the shorter `brief_hospital_course` as a reference, execute: 74 | 75 | ``` 76 | mkdir /root/mimic-iv-note-di-bhc 77 | mkdir /root/mimic-iv-note-di-bhc/dataset 78 | python /root/patient_summaries_with_llms/preprocess/split_dataset.py \ 79 | --input_file /root/mimic-iv-note-di/dataset/mimic_processed_summaries.csv \ 80 | --output_dir /root/mimic-iv-note-di-bhc/dataset \ 81 | --hospital_course_column brief_hospital_course \ 82 | --summary_column summary 83 | Found total of 100175 texts 84 | Wrote 80140 train, 10017 valid, and 10017 test examples to /root/mimic-iv-note-di-bhc/dataset 85 | ``` 86 | 87 | As a consequence, we have the jsonl files `all.json`, `train.json`, `valid.json`, and `test.json` in the directories `/root/mimic-iv-note-di/dataset` and `/root/mimic-iv-note-di-bhc/dataset` for the full hospital course and the brief hospital course as references, respectively. 88 | In this work we focus on the brief hospital course. 89 | 90 | ## Create Dataset Embeddings 91 | 92 | To create the t-SNE embeddings of the summaries, we used the Jupyter notebook `visualize_summary_embeddings.ipynb`. 93 | Based on the dataset splits above, one has to adapt the following paths at the top of the notebook. 94 | The file `all_unprocessed.json` contains the unprocessed summaries. 95 | It was obtained by running the `split_dataset.py` script on the summaries output by the preprocessing directly after the split into hospital course and summary. 96 | Hence, these summaries were not altered. 97 | 98 | ``` 99 | mimic4_unfiltered_path = '/root/mimic-iv-note-di/dataset/all_unprocessed.json' 100 | mimic4_filtered_path = '/root/mimic-iv-note-di/dataset/all.json' 101 | ``` 102 | 103 | There are different ways to label the t-SNE embeddings. 104 | By default, the dummy label `1` for every summary can be used. 105 | To create an embedding labeled by the medical service, we repurposed the `hospital_course_column` option of the dataset splits to hold the medical service instead of the reference text. 106 | 107 | ``` 108 | python /root/patient_summaries_with_llms/preprocess/split_dataset.py \ 109 | --input_file /root/mimic-iv-note-di/dataset/mimic_processed_summaries.csv \ 110 | --output_dir /root/mimic-iv-note-di/dataset \ 111 | --hospital_course_column service \ 112 | --summary_column summary 113 | ``` 114 | 115 | Assuming this was done for the processed summaries (`all_services.json`) and the unprocessed summaries (`all_unprocessed_services.json`), set the paths in the notebook to the following and be sure to use the labels for `Medical Services`. 
116 | 117 | ``` 118 | mimic4_unfiltered_path = '/root/mimic-iv-note-di/dataset/all_unprocessed_services.json' 119 | mimic4_filtered_path = '/root/mimic-iv-note-di/dataset/all_services.json' 120 | ``` 121 | -------------------------------------------------------------------------------- /notebooks/determine_token_length.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Helper script to determine token size of different texts" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "from transformers import AutoTokenizer\n", 18 | "import pandas as pd" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Set tokenizer\n", 28 | "# Very common tokenizer, also used by MPT models\n", 29 | "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n", 30 | "\n", 31 | "def num_tokens(text):\n", 32 | " return len(tokenizer.tokenize(text))" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Get texts and collect them in a dictionary\n", 42 | "\n", 43 | "df_mimic_iv_avs = pd.read_json('/home/s/s_hegs02/scratch/mimic-iv-note-di/dataset/all.json', orient='records', lines=True)\n", 44 | "df_mimic_iv_avs_brief_hc = pd.read_json('/home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset/all.json', orient='records', lines=True)\n", 45 | "\n", 46 | "# Put text and summary columns in a dict of pandas series\n", 47 | "text_dict = {}\n", 48 | "text_dict['mimic-iv-avs text'] = df_mimic_iv_avs['text']\n", 49 | "text_dict['mimic-iv-avs summary'] = df_mimic_iv_avs['summary']\n", 50 | "text_dict['mimic-iv-avs_brief_hc text'] = df_mimic_iv_avs_brief_hc['text']\n", 51 | "text_dict['mimic-iv-avs_brief_hc summary'] = df_mimic_iv_avs_brief_hc['summary']\n" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# Determine num tokens for all texts\n", 61 | "for key, value in text_dict.items():\n", 62 | " text_dict[key] = value.apply(num_tokens)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 6, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "mimic-iv-avs text\n", 75 | "count 100175.000000\n", 76 | "mean 4367.086758\n", 77 | "std 1625.275635\n", 78 | "min 554.000000\n", 79 | "25% 3277.000000\n", 80 | "50% 4141.000000\n", 81 | "75% 5180.000000\n", 82 | "max 21542.000000\n", 83 | "Name: text, dtype: float64\n", 84 | "mimic-iv-avs summary\n", 85 | "count 100175.000000\n", 86 | "mean 145.393891\n", 87 | "std 61.365393\n", 88 | "min 67.000000\n", 89 | "25% 103.000000\n", 90 | "50% 128.000000\n", 91 | "75% 169.000000\n", 92 | "max 960.000000\n", 93 | "Name: summary, dtype: float64\n", 94 | "mimic-iv-avs_brief_hc text\n", 95 | "count 100175.000000\n", 96 | "mean 858.606059\n", 97 | "std 498.301372\n", 98 | "min 106.000000\n", 99 | "25% 504.000000\n", 100 | "50% 753.000000\n", 101 | "75% 1096.000000\n", 102 | "max 7768.000000\n", 103 | "Name: text, dtype: float64\n", 104 | "mimic-iv-avs_brief_hc summary\n", 105 | "count 100175.000000\n", 106 | "mean 145.393891\n", 107 | "std 61.365393\n", 108 | "min 67.000000\n", 109 | "25% 
103.000000\n", 110 | "50% 128.000000\n", 111 | "75% 169.000000\n", 112 | "max 960.000000\n", 113 | "Name: summary, dtype: float64\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# Print statistics for each series\n", 119 | "for key, value in text_dict.items():\n", 120 | " print(key)\n", 121 | " print(value.describe())" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "2200.0\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "print(text_dict['mimic-iv-avs_brief_hc text'].quantile(0.98))" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 14, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "322.0\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "print(text_dict['mimic-iv-avs_brief_hc summary'].quantile(0.98))" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 15, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "370.0\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "print(text_dict['mimic-iv-avs_brief_hc summary'].quantile(0.99))" 173 | ] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "avs_gen", 179 | "language": "python", 180 | "name": "python3" 181 | }, 182 | "language_info": { 183 | "codemirror_mode": { 184 | "name": "ipython", 185 | "version": 3 186 | }, 187 | "file_extension": ".py", 188 | "mimetype": "text/x-python", 189 | "name": "python", 190 | "nbconvert_exporter": "python", 191 | "pygments_lexer": "ipython3", 192 | "version": "3.9.18" 193 | }, 194 | "orig_nbformat": 4 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /preprocess/create_dataset_statistics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Create Dataset Statistics" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "from nltk.tokenize import word_tokenize, sent_tokenize\n", 18 | "import statistics\n", 19 | "import json\n", 20 | "from transformers import AutoTokenizer" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 7, 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "Total entries: 100175\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "# Load dataset\n", 38 | "# mimic4_path = '/home/s_hegs02/mimic-iv-note-di/dataset/all.json'\n", 39 | "# Short references (BHC)\n", 40 | "mimic4_path = '/home/s_hegs02/mimic-iv-note-di-bhc/dataset/all.json'\n", 41 | "\n", 42 | "dataset = []\n", 43 | "with open(mimic4_path, 'r') as f:\n", 44 | " for line in f:\n", 45 | " dataset.append(json.loads(line))\n", 46 | " \n", 47 | "# Print total entries\n", 48 | "print(f\"Total entries: {len(dataset)}\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 12, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# Select only 10000 notes\n", 58 | "# dataset = dataset[:1000]" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 8, 64 | "metadata": {}, 65 | 
"outputs": [], 66 | "source": [ 67 | "# Load Llama 2 tokenizer to determine number of tokens\n", 68 | "model_name = 'meta-llama/Llama-2-7b-hf'\n", 69 | "hf_token = ''\n", 70 | "\n", 71 | "tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token) " 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 9, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "text - mean number of sentences: 33.0\n", 84 | "text - mean number of words: 552.0\n", 85 | "text - mean number of chars: 3029.9\n", 86 | "text - mean number of tokens: 858.6\n", 87 | "text - mean number of deidentified fields: 11.5\n", 88 | "text - std of number of sentences: 19.0\n", 89 | "text - std of number of words: 314.0\n", 90 | "text - std of number of chars: 1736.4\n", 91 | "text - std of number of tokens: 498.3\n", 92 | "text - std of number of deidentified fields: 9.7\n", 93 | "\n", 94 | "summary - mean number of sentences: 6.5\n", 95 | "summary - mean number of words: 113.2\n", 96 | "summary - mean number of chars: 604.4\n", 97 | "summary - mean number of tokens: 145.4\n", 98 | "summary - mean number of deidentified fields: 1.1\n", 99 | "summary - std of number of sentences: 2.6\n", 100 | "summary - std of number of words: 47.4\n", 101 | "summary - std of number of chars: 251.0\n", 102 | "summary - std of number of tokens: 61.4\n", 103 | "summary - std of number of deidentified fields: 1.7\n", 104 | "\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "keys = ['text', 'summary']\n", 110 | "deidentified_field = '___'\n", 111 | "\n", 112 | "for k in keys:\n", 113 | "\n", 114 | " num_sentences = []\n", 115 | " num_words = []\n", 116 | " num_chars = []\n", 117 | " num_tokens = []\n", 118 | " num_deidentified_fields = []\n", 119 | "\n", 120 | " for i in range(len(dataset)):\n", 121 | " entry = dataset[i][k]\n", 122 | " num_sentences.append(len(sent_tokenize(entry)))\n", 123 | " num_words.append(len(word_tokenize(entry)))\n", 124 | " num_chars.append(len(entry))\n", 125 | " num_tokens.append(len(tokenizer.tokenize(entry)))\n", 126 | " num_deidentified_fields.append(entry.count(deidentified_field))\n", 127 | " \n", 128 | " # Determine average and standard deviation using statisitcs module\n", 129 | " # Round by one digit\n", 130 | " print(f\"{k} - mean number of sentences: {statistics.mean(num_sentences):.1f}\")\n", 131 | " print(f\"{k} - mean number of words: {statistics.mean(num_words):.1f}\")\n", 132 | " print(f\"{k} - mean number of chars: {statistics.mean(num_chars):.1f}\")\n", 133 | " print(f\"{k} - mean number of tokens: {statistics.mean(num_tokens):.1f}\")\n", 134 | " print(f\"{k} - mean number of deidentified fields: {statistics.mean(num_deidentified_fields):.1f}\")\n", 135 | " \n", 136 | " print(f\"{k} - std of number of sentences: {statistics.stdev(num_sentences):.1f}\")\n", 137 | " print(f\"{k} - std of number of words: {statistics.stdev(num_words):.1f}\")\n", 138 | " print(f\"{k} - std of number of chars: {statistics.stdev(num_chars):.1f}\")\n", 139 | " print(f\"{k} - std of number of tokens: {statistics.stdev(num_tokens):.1f}\")\n", 140 | " print(f\"{k} - std of number of deidentified fields: {statistics.stdev(num_deidentified_fields):.1f}\")\n", 141 | " print()" 142 | ] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "avs_gen", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | 
}, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.9.18" 162 | }, 163 | "orig_nbformat": 4 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 2 167 | } 168 | -------------------------------------------------------------------------------- /notebooks/evaluate_gpt_4_performance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# General\n", 10 | "import json\n", 11 | "import numpy as np\n", 12 | "from collections import defaultdict\n", 13 | "import evaluate\n", 14 | "from rouge_score import rouge_scorer" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Use custom rouge function to obtain rouge 3/4 which are not available in huggingface\n", 24 | "def get_rouge_score(gold, pred):\n", 25 | " rouge_scores = ['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL']\n", 26 | " scorer = rouge_scorer.RougeScorer(rouge_scores, use_stemmer=True)\n", 27 | " scores = scorer.score(gold, pred)\n", 28 | " return {k: scores[k].fmeasure * 100 for k in rouge_scores}\n", 29 | "\n", 30 | "def compute_custom_metrics(srcs, golds, preds, device):\n", 31 | " scores = defaultdict(list)\n", 32 | " bertscore = evaluate.load(\"bertscore\")\n", 33 | " sari = evaluate.load(\"sari\")\n", 34 | " \n", 35 | " # For rouge and length go over examples one by one and determine mean\n", 36 | " for gold, pred in zip(golds, preds):\n", 37 | " for k, v in get_rouge_score(gold, pred).items():\n", 38 | " scores[k].append(v)\n", 39 | " scores['words'].append(len(pred.split(' ')))\n", 40 | " for k, v in scores.items():\n", 41 | " scores[k] = np.mean(v)\n", 42 | "\n", 43 | " # This is the default call using model_type=\"roberta-large\"\n", 44 | " # This is the same as in the paper \"Generation of Patient After-Visit Summaries to Support Physicians\" (AVS_gen/eval_summarization.py) using the libary SummerTime\n", 45 | " scores['bert_score'] = np.mean((bertscore.compute(predictions=preds, references=golds, lang=\"en\", device=device))['f1']) * 100\n", 46 | " # BERTScore authors recommend \"microsoft/deberta-large-mnli\" (https://github.com/Tiiiger/bert_score)\n", 47 | " scores['bert_score_deberta-large'] = np.mean((bertscore.compute(predictions=preds, references=golds, device=device, model_type=\"microsoft/deberta-large-mnli\"))['f1']) * 100\n", 48 | " scores['sari'] = sari.compute(sources=srcs, predictions=preds, references=[[g] for g in golds])['sari']\n", 49 | " # scores['sari'] = scores['sari'][0]\n", 50 | " # Importing readability for dallc score not working: https://pypi.org/project/py-readability-metrics/ \n", 51 | "\n", 52 | " return scores\n", 53 | "\n", 54 | "def print_metrics_as_latex(metrics):\n", 55 | " # Print latex table row\n", 56 | " order = ['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL', 'bert_score', 'bert_score_deberta-large', 'sari', 'words']\n", 57 | " print(' & '.join([f'${metrics[k]:.2f}$' for k in order]))" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 12, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# Files\n", 67 | "# test_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_4_test.json\"\n", 68 | "# validation_examples = 
\"2_short_BHC_summary_prediction/valid_4000_600_chars.json\"\n", 69 | "# preds_data_file = \"/home/s_hegs02/patient_summaries_with_llms/data/hallucination_evaluation/gpt-4_exp4_results_3shot.jsonl\"\n", 70 | "# preds_data_file = \"/home/s_hegs02/patient_summaries_with_llms/data/qualitative_evaluation/gpt-4_exp6_results_3shot.jsonl\"\n", 71 | "\n", 72 | "# Experiment 1 and 2\n", 73 | "test_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_1_test.json\"\n", 74 | "# preds_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/performance_results/gpt-4_exp1_results_prompt3.1_0shot.jsonl\"\n", 75 | "preds_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/performance_results/gpt-4_exp1_results_prompt3_5shot.jsonl\"\n", 76 | "# test_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_2_test.json\"\n", 77 | "# preds_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/performance_results/gpt-4_exp2_results_prompt3.1_0shot.jsonl\"\n", 78 | "# preds_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/performance_results/gpt-4_exp2_results_prompt3_5shot.jsonl\"\n", 79 | "\n", 80 | "\n", 81 | "\n", 82 | "# Read jsonl files\n", 83 | "def read_jsonl(file_name):\n", 84 | " with open(file_name, \"r\") as f:\n", 85 | " return [json.loads(line) for line in f]\n", 86 | " \n", 87 | "# Read jsonl files\n", 88 | "test_data = read_jsonl(test_data_file)\n", 89 | "preds_data = read_jsonl(preds_data_file)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# Print included valid examples with indices\n", 99 | "for i in range(0, 3):\n", 100 | " print(i)\n", 101 | " print(test_data[i][\"text\"])\n", 102 | " print(test_data[i][\"summary\"])\n", 103 | " print(preds_data[i][\"summary\"])\n", 104 | " print()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 14, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stderr", 114 | "output_type": "stream", 115 | "text": [ 116 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", 117 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 118 | ] 119 | }, 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "Test metrics rounded:\n", 125 | "{'rouge1': 38.8, 'rouge2': 10.78, 'rouge3': 3.55, 'rouge4': 1.12, 'rougeL': 21.98, 'words': 131.86, 'bert_score': 86.67, 'bert_score_deberta-large': 61.3, 'sari': 42.88}\n", 126 | "$38.80$ & $10.78$ & $3.55$ & $1.12$ & $21.98$ & $86.67$ & $61.30$ & $42.88$ & $131.86$\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "srcs = [e[\"text\"] for e in test_data]\n", 132 | "golds = [e[\"summary\"] for e in test_data]\n", 133 | "preds = [e[\"summary\"] for e in preds_data]\n", 134 | "metrics_test = compute_custom_metrics(srcs, golds, preds, \"cuda\")\n", 135 | "\n", 136 | "metrics_test = {k: round(v, 2) for k, v in metrics_test.items()}\n", 137 | "print(\"Test metrics rounded:\")\n", 138 | "print(metrics_test)\n", 139 | "print_metrics_as_latex(metrics_test)" 140 | ] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "avs_gen", 146 | "language": "python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": 
"ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.9.18" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 2 164 | } 165 | -------------------------------------------------------------------------------- /notebooks/evalute_gpt_4_prompt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Evaluate GPT 4 prompts" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from pathlib import Path\n", 17 | "import json\n", 18 | "import numpy as np\n", 19 | "from collections import defaultdict\n", 20 | "import evaluate\n", 21 | "from rouge_score import rouge_scorer" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Results for prompt 1 - 1 IC-example, prompt 1 - 3 IC-examples, prompt 1 - 5 IC-examples, prompt 2 - 1 IC-example, prompt 2 - 3 IC-examples, prompt 2 - 5 IC-examples, prompt 3 - 1 IC-example, prompt 3 - 3 IC-examples, prompt 3 - 5 IC-examples\n", 31 | "prefix = Path('/home/s_hegs02/patient_summaries_with_llms/gpt-4/prompt_tuning/')\n", 32 | "files_paths = [\n", 33 | " # 'gpt-4_exp4_results_prompt1_1shot.jsonl',\n", 34 | " # 'gpt-4_exp4_results_prompt1_3shot.jsonl',\n", 35 | " # 'gpt-4_exp4_results_prompt1_5shot.jsonl',\n", 36 | " # 'gpt-4_exp4_results_prompt2_1shot.jsonl',\n", 37 | " # 'gpt-4_exp4_results_prompt2_3shot.jsonl',\n", 38 | " # 'gpt-4_exp4_results_prompt2_5shot.jsonl',\n", 39 | " # # Missing\n", 40 | " # 'gpt-4_exp4_results_prompt3_1shot.jsonl',\n", 41 | " # 'gpt-4_exp4_results_prompt3_3shot.jsonl',\n", 42 | " # 'gpt-4_exp4_results_prompt3_5shot.jsonl',\n", 43 | " 'gpt-4_exp4_results_prompt3_0shot.jsonl',\n", 44 | " 'gpt-4_exp4_results_prompt3_5shot.jsonl',\n", 45 | "]\n", 46 | "\n", 47 | "# Read jsonl files\n", 48 | "def read_jsonl(file_name):\n", 49 | " with open(file_name, \"r\") as f:\n", 50 | " return [json.loads(line) for line in f]\n", 51 | " \n", 52 | "files = [read_jsonl(prefix / file_path) for file_path in files_paths]\n", 53 | "\n", 54 | "test_data_file = \"/home/s_hegs02/patient_summaries_with_llms/gpt-4/summarization_data/exp_4_test.json\"\n", 55 | "test_data = read_jsonl(test_data_file)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Use custom rouge function to obtain rouge 3/4 which are not available in huggingface\n", 65 | "def get_rouge_score(gold, pred):\n", 66 | " rouge_scores = ['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL']\n", 67 | " scorer = rouge_scorer.RougeScorer(rouge_scores, use_stemmer=True)\n", 68 | " scores = scorer.score(gold, pred)\n", 69 | " return {k: scores[k].fmeasure * 100 for k in rouge_scores}\n", 70 | "\n", 71 | "def compute_custom_metrics(srcs, golds, preds, device):\n", 72 | " scores = defaultdict(list)\n", 73 | " bertscore = evaluate.load(\"bertscore\")\n", 74 | " sari = evaluate.load(\"sari\")\n", 75 | " \n", 76 | " # For rouge and length go over examples one by one and determine mean\n", 77 | " for gold, pred in zip(golds, preds):\n", 78 | " for k, v in get_rouge_score(gold, pred).items():\n", 79 | " scores[k].append(v)\n", 80 | " 
scores['words'].append(len(pred.split(' ')))\n", 81 | " for k, v in scores.items():\n", 82 | " scores[k] = np.mean(v)\n", 83 | "\n", 84 | " # This is the default call using model_type=\"roberta-large\"\n", 85 | " # This is the same as in the paper \"Generation of Patient After-Visit Summaries to Support Physicians\" (AVS_gen/eval_summarization.py) using the libary SummerTime\n", 86 | " scores['bert_score'] = np.mean((bertscore.compute(predictions=preds, references=golds, lang=\"en\", device=device))['f1']) * 100\n", 87 | " # BERTScore authors recommend \"microsoft/deberta-large-mnli\" (https://github.com/Tiiiger/bert_score)\n", 88 | " scores['bert_score_deberta-large'] = np.mean((bertscore.compute(predictions=preds, references=golds, device=device, model_type=\"microsoft/deberta-large-mnli\"))['f1']) * 100\n", 89 | " scores['sari'] = sari.compute(sources=srcs, predictions=preds, references=[[g] for g in golds])['sari']\n", 90 | " # scores['sari'] = scores['sari'][0]\n", 91 | " # Importing readability for dallc score not working: https://pypi.org/project/py-readability-metrics/ \n", 92 | "\n", 93 | " return scores\n", 94 | "\n", 95 | "def get_metrics_as_latex(metrics):\n", 96 | " # Print latex table row\n", 97 | " order = ['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL', 'bert_score', 'bert_score_deberta-large', 'sari', 'words']\n", 98 | " return ' & '.join([f'${metrics[k]:.2f}$' for k in order])" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "# Print performance\n", 108 | "srcs = [e[\"text\"] for e in test_data][:len(files[0])]\n", 109 | "golds = [e[\"summary\"] for e in test_data][:len(files[0])]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stderr", 119 | "output_type": "stream", 120 | "text": [ 121 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", 122 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 123 | ] 124 | }, 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "\n" 130 | ] 131 | }, 132 | { 133 | "name": "stderr", 134 | "output_type": "stream", 135 | "text": [ 136 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", 137 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 138 | ] 139 | }, 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "\n", 145 | "$42.50$ & $11.95$ & $4.37$ & $2.09$ & $21.49$ & $86.30$ & $61.36$ & $45.70$ & $214.40$\n", 146 | "$41.99$ & $12.83$ & $5.22$ & $2.26$ & $22.67$ & $86.95$ & $62.35$ & $43.55$ & $138.70$\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "results = []\n", 152 | "for i , f in enumerate(files):\n", 153 | " preds = [e[\"summary\"] for e in f]\n", 154 | " metrics = compute_custom_metrics(srcs, golds, preds, \"cuda\")\n", 155 | " metrics = {k: round(v, 2) for k, v in metrics.items()}\n", 156 | " results.append(get_metrics_as_latex(metrics))\n", 157 | " print()\n", 158 | " \n", 159 | "print('\\n'.join(results))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | 
"execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "# Print examples along with summaries\n", 169 | "\n", 170 | "for i in range(0, len(files[0])):\n", 171 | " print(f\"Example {i+1}\")\n", 172 | " print(f\"Source: {srcs[i]}\\n\")\n", 173 | " print(f\"Gold: {golds[i]}\\n\")\n", 174 | " for j, f in enumerate(files):\n", 175 | " print(f\"Summary {j+1}: {' '.join(f[i]['summary'].split())}\\n\")\n", 176 | " print()" 177 | ] 178 | } 179 | ], 180 | "metadata": { 181 | "kernelspec": { 182 | "display_name": "avs_gen", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.9.18" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 2 201 | } 202 | -------------------------------------------------------------------------------- /gpt-4/run_summarization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | import datasets 8 | import fire 9 | import guidance 10 | import yaml 11 | from tqdm import tqdm 12 | 13 | ALL_PROMPTS = { 14 | "prompt_1": """ 15 | {{#system~}} 16 | You are a helpful assistant. 17 | {{~/system}} 18 | 19 | {{#user~}} 20 | You will be given a doctor's note and you will need to summarize the patient's brief hospital course. 21 | 22 | Let's do a practice round. 23 | {{~/user}} 24 | 25 | {{#assistant~}} 26 | Sounds great! 27 | {{~/assistant}} 28 | 29 | {{#each icl_examples}} 30 | {{#user}}Here is the doctor's note on a patient's brief hospital course: 31 | 32 | {{this.text}} 33 | 34 | Summarize for the patient what happened during the hospital stay based on this doctor's note. Please make it short and concise and only include key events and findings. 35 | {{/user}} 36 | {{#assistant}} 37 | {{this.summary}} 38 | {{/assistant}} 39 | {{/each}} 40 | 41 | 42 | {{#user~}} 43 | Here is the doctor's note on a patient's brief hospital course: 44 | 45 | {{final_text}} 46 | 47 | Summarize for the patient what happened during the hospital stay based on this doctor's note. Please make it short and concise and only include key events and findings. 48 | {{~/user}} 49 | 50 | {{#assistant~}} 51 | {{gen 'summary' max_tokens=600 temperature=0}} 52 | {{~/assistant}} 53 | """, 54 | "prompt_2": """ 55 | {{#system~}} 56 | You are helping with a resident working at a large urban academic medical center. 57 | {{~/system}} 58 | 59 | {{#user~}} 60 | You task is to help summarize a patient's brief hospital course based on the doctor's note. Please make it short and concise and only include key events and findings. 61 | 62 | Here are some examples: 63 | 64 | {{#each icl_examples}} 65 | DOCUMENT: 66 | {{this.text}} 67 | 68 | SUMMARY: 69 | {{this.summary}} 70 | {{/each}} 71 | 72 | Here is another doctor note on a patient's brief hospital course: 73 | 74 | DOCUMENT: {{final_text}} 75 | {{~/user}} 76 | 77 | {{#assistant~}} 78 | {{gen 'summary' max_tokens=600 temperature=0}} 79 | {{~/assistant}} 80 | """, 81 | "prompt_3": """ 82 | {{#system~}} 83 | You are a helpful assistant that helps patients understand their medical records. 
84 | {{~/system}} 85 | 86 | {{#user~}} 87 | You will be given some doctor's notes and you will need to summarize the patient's brief hospital course in one paragraph. Please only include key events and findings and avoid using medical jargons, and you MUST start the summary with "You were admitted". 88 | 89 | {{#if icl_examples}} 90 | Here are some examples: 91 | 92 | {{#each icl_examples}} 93 | DOCUMENT: 94 | {{this.text}} 95 | 96 | SUMMARY: 97 | {{this.summary}} 98 | {{/each}} 99 | {{/if}} 100 | 101 | DOCUMENT: {{final_text}} 102 | {{~/user}} 103 | 104 | {{#assistant~}} 105 | {{gen 'summary' max_tokens=600 temperature=0}} 106 | {{~/assistant}} 107 | """, 108 | "prompt_3.1": """ 109 | {{#system~}} 110 | You are a helpful assistant that helps patients understand their medical records. 111 | {{~/system}} 112 | 113 | {{#user~}} 114 | You will be given some doctor's notes and you will need to summarize the patient's brief hospital course in ONE paragraph with a few sentences. Please only include key events and findings and avoid using medical jargons, and you MUST start the summary with "You were admitted". 115 | 116 | DOCUMENT: {{final_text}} 117 | {{~/user}} 118 | 119 | {{#assistant~}} 120 | {{gen 'summary' max_tokens=600 temperature=0}} 121 | {{~/assistant}} 122 | """, 123 | } 124 | 125 | 126 | def read_jsonl(file_name): 127 | with open(file_name, "r") as f: 128 | return [json.loads(line) for line in f] 129 | 130 | 131 | def write_jsonl(file_name, data): 132 | with open(file_name, "w") as f: 133 | for line in data: 134 | f.write(json.dumps(line) + "\n") 135 | 136 | 137 | def load_oai_model(model_name, max_calls_per_min=60): 138 | with open("config.yaml", "r") as f: 139 | config = yaml.safe_load(f) 140 | 141 | common_kwargs = { 142 | "max_calls_per_min": max_calls_per_min, # Maximum number of calls that can be made per minute (default is 60) 143 | } 144 | 145 | assert model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo-16k"] 146 | 147 | if config["openai_api_mode"] == "openai": 148 | os.environ["OPENAI_API_KEY"] = config["openai_api_key"] 149 | model = guidance.llms.OpenAI( 150 | model_name, **common_kwargs, organization=config["openai_organization"] 151 | ) 152 | elif config["openai_api_mode"] == "azure": 153 | deployment_id = model_name 154 | if model_name == "gpt-3.5-turbo": 155 | deployment_id = "gpt-35-turbo" 156 | elif model_name == "gpt-3.5-turbo-16k": 157 | deployment_id = "gpt-35-turbo-16k" 158 | 159 | model = guidance.llms.OpenAI( 160 | model_name, 161 | api_type="azure", 162 | api_key=config["azure_api_key"], 163 | api_base=config["azure_api_base"], 164 | api_version=config["azure_api_version"], 165 | deployment_id=deployment_id, 166 | **common_kwargs, 167 | ) 168 | return model 169 | 170 | 171 | def run_summarization( 172 | task_id: int, 173 | prompt_id: int, 174 | model_name: str, 175 | n_shot: int, 176 | save_path: Optional[str] = None, 177 | what_for: str = "exp", 178 | verbose: bool = False, 179 | debug: bool = False, 180 | ): 181 | demonstrations = read_jsonl( 182 | f"summarization_data/{what_for}_{task_id}_in-context.json" 183 | ) 184 | test_examples = read_jsonl(f"summarization_data/{what_for}_{task_id}_test.json") 185 | 186 | bad_demonstration_ids = [] 187 | for i, demonstration in enumerate(demonstrations): 188 | if demonstration["summary"].startswith("He came to the"): 189 | bad_demonstration_ids.append(i) 190 | 191 | assert len(demonstrations) >= n_shot 192 | if n_shot < len(demonstrations): 193 | random.seed(32) 194 | indices = 
list(range(len(demonstrations))) 195 | random.shuffle(indices) 196 | indices = [i for i in indices if i not in bad_demonstration_ids] 197 | icl_examples = [demonstrations[i] for i in indices[:n_shot]] 198 | else: 199 | icl_examples = demonstrations 200 | 201 | llm = load_oai_model(model_name) 202 | 203 | used_prompt = ALL_PROMPTS[f"prompt_{prompt_id}"] 204 | summarization_program_nshot = guidance(used_prompt) 205 | 206 | if verbose: 207 | print(f"Using {len(icl_examples)} ICL examples") 208 | print(icl_examples) 209 | 210 | if debug: 211 | test_examples = test_examples[:10] 212 | 213 | if save_path is None: 214 | save_path = f"summarization_results/{model_name}_{what_for}{task_id}_results_prompt{prompt_id}_{n_shot}shot.jsonl" 215 | Path(save_path).parent.mkdir(parents=True, exist_ok=True) 216 | write_jsonl(save_path.replace(".jsonl", "_icl.jsonl"), icl_examples) 217 | 218 | with open(save_path.replace(".jsonl", "_prompt.txt"), "w") as f: 219 | f.write(used_prompt) 220 | 221 | failure_indices = [] 222 | all_results = [] 223 | 224 | for example_idx in tqdm(range(len(test_examples))): 225 | example = test_examples[example_idx] 226 | 227 | gen_answer = summarization_program_nshot( 228 | icl_examples=icl_examples, 229 | final_text=example["text"], 230 | llm=llm, 231 | verbose=verbose, 232 | ) 233 | 234 | try: 235 | summary = gen_answer["summary"] 236 | except: 237 | print(f"Failed to generate answer for example {example_idx}") 238 | summary = "" 239 | failure_indices.append(example_idx) 240 | 241 | all_results.append( 242 | { 243 | "index": example_idx, 244 | "question": example["text"], 245 | "summary": summary, 246 | } 247 | ) 248 | if verbose: 249 | print(f"Text: {example['text']}") 250 | print(f"Summary: {summary}") 251 | print("=====================================") 252 | 253 | write_jsonl(save_path, all_results) 254 | 255 | with open(save_path.replace(".jsonl", "_failures.json"), "w") as f: 256 | json.dump(failure_indices, f, indent=2) 257 | 258 | 259 | if __name__ == "__main__": 260 | fire.Fire(run_summarization) 261 | -------------------------------------------------------------------------------- /notebooks/average_test_runs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# For a given set of test metrics, determine mean and standard deviation and round to 2 decimal places" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "## Llama 7B - short data\n", 26 | "results = {\n", 27 | " 'Llama 7B - short data': [\n", 28 | " {'rouge1': 38.18726652084839, 'rouge2': 12.522356326993874, 'rouge3': 5.300379506691891, 'rouge4': 2.6844528863937667, 'rougeL': 23.425086591887453, 'words': 104.15, 'bert_score': 84.68257343769073, 'bert_score_deberta-large': 58.7152236700058, 'sari': 43.39504555237768},\n", 29 | " {'rouge1': 36.55885611308579, 'rouge2': 11.770071884328013, 'rouge3': 4.96413117973415, 'rouge4': 2.3760043053086815, 'rougeL': 22.520381900800416, 'words': 100.56, 'bert_score': 82.0653041601181, 'bert_score_deberta-large': 56.56696653366089, 'sari': 41.94318296286558},\n", 30 | " {'rouge1': 36.865581109022166, 'rouge2': 12.035793003259116, 'rouge3': 5.477773544226487, 'rouge4': 2.831216527519416, 'rougeL': 
22.85251778524732, 'words': 94.03, 'bert_score': 81.28817218542099, 'bert_score_deberta-large': 56.447242498397834, 'sari': 42.631882102051144},\n", 31 | " {'rouge1': 36.45585401833812, 'rouge2': 11.507301855132011, 'rouge3': 4.979182106965462, 'rouge4': 2.525371873858721, 'rougeL': 22.339373505129323, 'words': 100.47, 'bert_score': 82.16391408443451, 'bert_score_deberta-large': 56.96498113870621, 'sari': 42.04524834049528},\n", 32 | " {'rouge1': 36.698585402455606, 'rouge2': 11.78729495150544, 'rouge3': 4.868574815166207, 'rouge4': 2.2399180637599803, 'rougeL': 22.530625996895047, 'words': 103.47, 'bert_score': 81.9952797293663, 'bert_score_deberta-large': 56.640745639801025, 'sari': 42.02560221098953}\n", 33 | " ],\n", 34 | " 'Llama 7B - long data': [\n", 35 | " {'rouge1': 38.8343594182866, 'rouge2': 12.97963450275313, 'rouge3': 5.273994688087981, 'rouge4': 2.388816250393242, 'rougeL': 24.811244853956904, 'words': 71.92, 'bert_score': 86.49026721715927, 'bert_score_deberta-large': 60.77595293521881, 'sari': 44.27485176179068},\n", 36 | " {'rouge1': 39.47507417776543, 'rouge2': 13.508055870992669, 'rouge3': 5.474262681411284, 'rouge4': 2.4299066163652707, 'rougeL': 25.201416964411987, 'words': 77.79, 'bert_score': 84.75045335292816, 'bert_score_deberta-large': 60.05594950914382, 'sari': 44.487989828395825},\n", 37 | " {'rouge1': 38.167512768514875, 'rouge2': 12.432498880471945, 'rouge3': 5.166171874893524, 'rouge4': 2.2279453931573583, 'rougeL': 24.912878164301006, 'words': 69.49, 'bert_score': 86.43438655138016, 'bert_score_deberta-large': 60.743097960948944, 'sari': 44.387922498396605},\n", 38 | " {'rouge1': 38.474542505246575, 'rouge2': 12.59729023685818, 'rouge3': 5.032502045215025, 'rouge4': 2.08290036294411, 'rougeL': 24.771698689284257, 'words': 71.22, 'bert_score': 86.25028610229492, 'bert_score_deberta-large': 60.663696229457855, 'sari': 43.91790431156874},\n", 39 | " {'rouge1': 36.83328812388325, 'rouge2': 11.758537910375646, 'rouge3': 4.723083386770241, 'rouge4': 2.0583976340860497, 'rougeL': 23.944460740287273, 'words': 75.22, 'bert_score': 84.4628956913948, 'bert_score_deberta-large': 58.90815430879592, 'sari': 43.55259789375479}\n", 40 | " ],\n", 41 | " 'Llama 70B - short data': [\n", 42 | " {'rouge1': 42.1991438443541, 'rouge2': 13.563337998083213, 'rouge3': 5.7237180188916845, 'rouge4': 2.6722184607364134, 'rougeL': 24.777895343443596, 'words': 111.31, 'bert_score': 87.0364066362381, 'bert_score_deberta-large': 61.91603672504426, 'sari': 44.16604762927569},\n", 43 | " {'rouge1': 41.67500471009928, 'rouge2': 13.840969941914366, 'rouge3': 5.947293259680258, 'rouge4': 2.844910436528473, 'rougeL': 24.836899438706126, 'words': 120.9, 'bert_score': 85.90889036655426, 'bert_score_deberta-large': 60.57257956266403, 'sari': 44.07321356056379},\n", 44 | " {'rouge1': 41.88252074463912, 'rouge2': 14.074882768448884, 'rouge3': 6.039503368150073, 'rouge4': 2.8148856591135196, 'rougeL': 25.643550342550594, 'words': 112.34, 'bert_score': 87.158194065094, 'bert_score_deberta-large': 62.24099487066269, 'sari': 44.35337536946284},\n", 45 | " {'rouge1': 41.83728919015555, 'rouge2': 13.09889608467368, 'rouge3': 5.487954667140488, 'rouge4': 2.4686725964920826, 'rougeL': 24.435506275256007, 'words': 112.07, 'bert_score': 86.0190686583519, 'bert_score_deberta-large': 61.00947117805481, 'sari': 43.20317982433489},\n", 46 | " {'rouge1': 41.5120357352563, 'rouge2': 13.574013168402391, 'rouge3': 5.627399063893296, 'rouge4': 2.4955528332014443, 'rougeL': 24.480105189357122, 'words': 113.79, 
'bert_score': 86.04067796468735, 'bert_score_deberta-large': 60.975450932979584, 'sari': 43.51100259877742}\n", 47 | " ],\n", 48 | " 'Llama 70B - long data': [\n", 49 | " {'rouge1': 40.5555217296232, 'rouge2': 14.380567615873911, 'rouge3': 6.145952549654063, 'rouge4': 2.9839005077717538, 'rougeL': 25.812246308651094, 'words': 76.46, 'bert_score': 86.00773721933365, 'bert_score_deberta-large': 61.7718161046505, 'sari': 45.00554637223204},\n", 50 | " {'rouge1': 40.54058585622222, 'rouge2': 14.166953530916869, 'rouge3': 5.695409156176871, 'rouge4': 2.3296289556276344, 'rougeL': 26.20942324367626, 'words': 80.28, 'bert_score': 86.76197403669357, 'bert_score_deberta-large': 62.13716307282448, 'sari': 44.9823358329503},\n", 51 | " {'rouge1': 40.34610201932814, 'rouge2': 14.411568668138443, 'rouge3': 6.339125512527877, 'rouge4': 2.953904152985041, 'rougeL': 26.32607735907873, 'words': 75.4, 'bert_score': 86.87447029352188, 'bert_score_deberta-large': 62.064355462789536, 'sari': 45.71918948150783},\n", 52 | " {'rouge1': 40.1893462851726, 'rouge2': 14.055680359041899, 'rouge3': 5.956359809720371, 'rouge4': 2.5952588267124463, 'rougeL': 25.947485210072838, 'words': 74.67, 'bert_score': 85.03679966926575, 'bert_score_deberta-large': 61.02019691467285, 'sari': 44.80799764604085},\n", 53 | " {'rouge1': 41.26608906368885, 'rouge2': 14.527248259553481, 'rouge3': 6.298794377100548, 'rouge4': 2.8508339445193314, 'rougeL': 26.633192382147357, 'words': 77.67, 'bert_score': 86.83461207151413, 'bert_score_deberta-large': 62.4643184542656, 'sari': 45.28758563082025}\n", 54 | " ]\n", 55 | "\n", 56 | "}" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "Llama 7B - short data & $36.95$ & $11.92$ & $5.12$ & $2.53$ & $22.73$ & $82.44$ & $57.07$ & $42.41$ & $100.54$ \\\\\n", 69 | "Llama 7B - long data & $38.36$ & $12.66$ & $5.13$ & $2.24$ & $24.73$ & $85.68$ & $60.23$ & $44.12$ & $73.13$ \\\\\n", 70 | "Llama 70B - short data & $41.82$ & $13.63$ & $5.77$ & $2.66$ & $24.83$ & $86.43$ & $61.34$ & $43.86$ & $114.08$ \\\\\n", 71 | "Llama 70B - long data & $40.58$ & $14.31$ & $6.09$ & $2.74$ & $26.19$ & $86.30$ & $61.89$ & $45.16$ & $76.90$ \\\\\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "# Put into dataframe\n", 77 | "with_std = False\n", 78 | "for model, model_results in results.items():\n", 79 | " \n", 80 | " df = pd.DataFrame(model_results)\n", 81 | " # Change order in dataframe columns\n", 82 | " df = df[['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL', 'bert_score', 'bert_score_deberta-large', 'sari', 'words']]\n", 83 | " # Assume first row: Model & R-1 & R-2 & R-3 & R-4 & R-L & BERTScore & Deberta & SARI & Words \\\\ \\midrule\n", 84 | " means = df.mean().round(2)\n", 85 | " stds = df.std().round(2)\n", 86 | " if with_std:\n", 87 | " print(model + \" & \" + \" & \".join([f\"${m:.2f}$ (${s:.2f}$)\" for m, s in zip(means, stds)]) + \" \\\\\\\\\")\n", 88 | " else:\n", 89 | " print(model + \" & \" + \" & \".join([f\"${m:.2f}$\" for m in means]) + \" \\\\\\\\\")\n" 90 | ] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "avs_gen", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": 
"ipython3", 109 | "version": "3.9.18" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 2 114 | } 115 | -------------------------------------------------------------------------------- /gpt-4/create_hallucination_icl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "2906c69a-68f0-4b15-9a45-e34e7ca1f49b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from run_hallucination_detection import *\n", 11 | "demonstrations = read_jsonl(DATASET_PATHS[\"valid_mimic\"])" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "4f218439-db46-46ae-a305-561dabcaea06", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "icl_examples = []" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "6acb5cd5-a60e-43dd-a091-3d334e17ad26", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import re\n", 32 | "def remove_class_label(text):\n", 33 | " return re.sub(r' class=\"[^\"]*\"', '', text)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "fdd435ab-dfc7-4482-8ee3-9b9473004415", 39 | "metadata": {}, 40 | "source": [ 41 | "## V1 Format of the prompts " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "26378110-7128-4208-9acd-c4fcda89d953", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "ex = demonstrations[0]\n", 52 | "print(create_icl_example_v1(ex, add_hallucination_type=True))\n", 53 | "\n", 54 | "cot_description = \"\"\"\n", 55 | "- \"Your red blood cell count was followed and was stable.\" The BHC does not state that the red blood cell count was followed. 
Instead the hematocrit remained stable according to the BHC.\n", 56 | "- \"You were treated with 2 days of antibiotics which were stopped prior to discharge.\" There is no clear time interval for antibiotic treatment in the BHC.\n", 57 | "\"\"\"\n", 58 | "\n", 59 | "cot_no_label = remove_class_label(cot_description)\n", 60 | "ex['cot_description'] = cot_no_label.strip()\n", 61 | "ex['cot_description_with_label'] = cot_description.strip()\n", 62 | "\n", 63 | "print(cot_no_label)\n", 64 | "icl_examples.append(ex)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "94fdd3fe-ceff-4240-a0da-d1e20a5d4dc8", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "ex = demonstrations[1]\n", 75 | "print(create_icl_example_v1(ex, add_hallucination_type=True))\n", 76 | "\n", 77 | "cot_description = \"\"\"\n", 78 | "- \"You were also given blood because you lost a fair amount in your stool.\" The BHC does not state that the patient received blood.\n", 79 | "- \"Please hold off from taking your Isosorbide mononitrate (Imdur) and losartan until you meet with your primary care physician within the week.\" The BHC includes the information that the patient should meet with his primary care physician within one week (\"1 week\") which is different to within the week, which only includes the remainder of the current week.\n", 80 | "- \"Also, hold from taking your torsemide (unless you notice significant weight gain in the next few days) until you meet with your primary care physician within the week\" There are no specific instructions in the BHC stating the the patient should start the Torsemid by himself; and the BHC includes the information that the patient should meet with his primary care physician within one week (\"1 week\") which is different to within the week, which only includes the remainder of the current week.\n", 81 | "\"\"\"\n", 82 | "cot_no_label = remove_class_label(cot_description)\n", 83 | "ex['cot_description'] = cot_no_label.strip()\n", 84 | "ex['cot_description_with_label'] = cot_description.strip()\n", 85 | "\n", 86 | "print(cot_no_label)\n", 87 | "icl_examples.append(ex)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "74996cb5-a186-4609-8d26-0c8a7b041778", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "ex = demonstrations[2]\n", 98 | "print(create_icl_example_v1(ex, add_hallucination_type=True))\n", 99 | "\n", 100 | "cot_description = \"\"\"\n", 101 | "No errors detected in the AVS based on the provided BHC.\n", 102 | "\"\"\"\n", 103 | "\n", 104 | "cot_no_label = remove_class_label(cot_description)\n", 105 | "ex['cot_description'] = cot_no_label.strip()\n", 106 | "ex['cot_description_with_label'] = cot_description.strip()\n", 107 | "\n", 108 | "print(cot_no_label)\n", 109 | "icl_examples.append(ex)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "c305c928-d21c-4883-8cce-b6b8906e1448", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "ex = demonstrations[3]\n", 120 | "print(create_icl_example_v1(ex, add_hallucination_type=True))\n", 121 | "\n", 122 | "cot_description = \"\"\"\n", 123 | "- \"We monitored you in the ___ and on the medical floors.\" There is no information in the BHC that the patient was monitored on the medical floors. 
The patient could have been monitored somewhere else in the hospital.\n", 124 | "- \"You remained stable and you received Valium as needed for withdrawl symptoms.\" The BHC states that the patient did not require Valium.\n", 125 | "- \"These can be done as an outpatient and we have placed orders in the computer for you to have them done.\" Unclear whether the orders for an MRI and EEG were already placed in the computer.\n", 126 | "- \"They also recommended seeing a neurologist as an outpatient and your primary care provider is aware and ___ help set up the tests and appointment with neurology.\" The neurologist in the hospital recommend doing an MRI of the head and an EEG as an outpatient, which is already stated in the Discharge Instructions. However, they do not specifically recommend seeing a neurologist as an outpatient; The BHC does not state that the primary care provider is aware of the recommendations.\n", 127 | "\"\"\"\n", 128 | "\n", 129 | "cot_no_label = remove_class_label(cot_description)\n", 130 | "ex['cot_description'] = cot_no_label.strip()\n", 131 | "ex['cot_description_with_label'] = cot_description.strip()\n", 132 | "\n", 133 | "print(cot_no_label)\n", 134 | "icl_examples.append(ex)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "c63a7a30-adf9-4ca8-8f90-2ca047bfb9ee", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "ex = demonstrations[4]\n", 145 | "print(create_icl_example_v1(ex, add_hallucination_type=True))\n", 146 | "\n", 147 | "cot_description = \"\"\"\n", 148 | "- \"You can take an anti-nausea medicine called zofran as needed.\" The patient received Zofran in the hospital, but Zofran is not presribed as needed for own use.\n", 149 | "- \"You were seen ___ the hospital by the nutritionist, who recommended that you take a nutritional supplement with each meal, such as Boost or Carnation.\" Nutrional supplements are recommended three times a day, but it is not stated that they should be taken with each meal.\n", 150 | "\"\"\"\n", 151 | "\n", 152 | "cot_no_label = remove_class_label(cot_description)\n", 153 | "ex['cot_description'] = cot_no_label.strip()\n", 154 | "ex['cot_description_with_label'] = cot_description.strip()\n", 155 | "\n", 156 | "print(cot_no_label)\n", 157 | "icl_examples.append(ex)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "910a8f71-c928-4359-8e16-6f6ac6cfd90e", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "write_jsonl(\"hallucination_detection_data/icl_v1.jsonl\", icl_examples)" 168 | ] 169 | } 170 | ], 171 | "metadata": { 172 | "kernelspec": { 173 | "display_name": "Python 3 (ipykernel)", 174 | "language": "python", 175 | "name": "python3" 176 | }, 177 | "language_info": { 178 | "codemirror_mode": { 179 | "name": "ipython", 180 | "version": 3 181 | }, 182 | "file_extension": ".py", 183 | "mimetype": "text/x-python", 184 | "name": "python", 185 | "nbconvert_exporter": "python", 186 | "pygments_lexer": "ipython3", 187 | "version": "3.10.8" 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 5 192 | } 193 | -------------------------------------------------------------------------------- /gpt-4/summarization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# GPT-4 Summarization\n", 8 | "\n", 9 | "## Creating a Prompt\n", 10 | "\n", 11 | "For creating a prompt, I will 
give you 10 training examples of the original text-summary pairs and 10 validation examples.\n", 12 | "I will also provide code to check the performance of the 10 validation examples below.\n", 13 | "You have to input the output of GPT-4 for these by hand.\n", 14 | "Not all experiments use exactly the same data as the original text-summary pairs (see below), but I think these are good to get a sense of the performance and create a prompt for all experiments.\n", 15 | "\n", 16 | "## Experiments To Run\n", 17 | "\n", 18 | "All other experiments come with their own 10 in-context examples.\n", 19 | "\n", 20 | "### For quantitative performance estimates\n", 21 | "\n", 22 | "1. Summarization of 100 original text-summary pairs\n", 23 | "2. Summarization of 100 original text-summary pairs with short text (<4000 chars) and long summaries (>600 chars)\n", 24 | " * I did not mention this to you, but we also have to get the performance on this data.\n", 25 | " * This is a subset of 20% of the data I had to work with to make the human annotation feasible. Texts that were too long were impossible to annotate.\n", 26 | " * Basically I just want to show that this subselection makes no difference in performance.\n", 27 | "3. Not high priority, but could be useful: Summarization of 100 _cleaned and improved_ text-summary pairs when using 10 cleaned and improved in-context examples (10 validation _cleaned and improved data_)\n", 28 | "\n", 29 | "### For annotating hallucinations and determining hallucination rates\n", 30 | "\n", 31 | "4. Summarization of 25 examples when using in-context examples with unsupported facts (10 validation _original data_)\n", 32 | " * I will give you 50 test examples to have some for debugging\n", 33 | "5. Summarization of 25 examples when using in-context examples with unsupported facts removed (10 validation _cleaned data_)\n", 34 | " * I will give you 50 test examples to have some for debugging\n", 35 | "\n", 36 | "### For qualitative results with human annotation\n", 37 | "\n", 38 | "6. 
Summarization of 25 examples when using in-context examples with unsupported facts removed and improved text such as deidentification removed (10 validation _cleaned and improved data_)\n", 39 | " * I will give you 50 test examples to have some for debugging" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# Imports\n", 49 | "import json\n", 50 | "import random\n", 51 | "import numpy as np\n", 52 | "from collections import defaultdict\n", 53 | "import evaluate\n", 54 | "from rouge_score import rouge_scorer" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 9, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# Read all files\n", 64 | "def read_jsonl(file_name):\n", 65 | " with open(file_name, \"r\") as f:\n", 66 | " return [json.loads(line) for line in f]\n", 67 | " \n", 68 | "prompt_train = read_jsonl('summarization_data/prompt_train.json')\n", 69 | "prompt_valid = read_jsonl('summarization_data/prompt_valid.json')\n", 70 | "\n", 71 | "exp_1_in_context = read_jsonl('summarization_data/exp_1_in-context.json')\n", 72 | "exp_1_test = read_jsonl('summarization_data/exp_1_test.json')\n", 73 | "exp_2_in_context = read_jsonl('summarization_data/exp_2_in-context.json')\n", 74 | "exp_2_test = read_jsonl('summarization_data/exp_2_test.json')\n", 75 | "exp_3_in_context = read_jsonl('summarization_data/exp_3_in-context.json')\n", 76 | "exp_3_test = read_jsonl('summarization_data/exp_3_test.json')\n", 77 | "\n", 78 | "exp_4_in_context = read_jsonl('summarization_data/exp_4_in-context.json')\n", 79 | "exp_4_test = read_jsonl('summarization_data/exp_4_test.json')\n", 80 | "exp_5_in_context = read_jsonl('summarization_data/exp_5_in-context.json')\n", 81 | "exp_5_test = read_jsonl('summarization_data/exp_5_test.json')\n", 82 | "\n", 83 | "exp_6_in_context = read_jsonl('summarization_data/exp_6_in-context.json')\n", 84 | "exp_6_test = read_jsonl('summarization_data/exp_6_test.json')\n", 85 | "\n", 86 | "assert len(prompt_train) == 10\n", 87 | "assert len(prompt_valid) == 10\n", 88 | "# Assert length of in-context always 10\n", 89 | "assert len(exp_1_in_context) == 10\n", 90 | "assert len(exp_2_in_context) == 10\n", 91 | "assert len(exp_3_in_context) == 10\n", 92 | "assert len(exp_4_in_context) == 10\n", 93 | "assert len(exp_5_in_context) == 10\n", 94 | "assert len(exp_6_in_context) == 10\n", 95 | "# Assert length of test\n", 96 | "assert len(exp_1_test) == 100\n", 97 | "assert len(exp_2_test) == 100\n", 98 | "assert len(exp_3_test) == 100\n", 99 | "assert len(exp_4_test) == 50\n", 100 | "assert len(exp_5_test) == 50\n", 101 | "assert len(exp_6_test) == 50" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 10, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# Use custom rouge function to obtain rouge 3/4 which are not available in huggingface\n", 111 | "def get_rouge_score(gold, pred):\n", 112 | " rouge_scores = ['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL']\n", 113 | " scorer = rouge_scorer.RougeScorer(rouge_scores, use_stemmer=True)\n", 114 | " scores = scorer.score(gold, pred)\n", 115 | " return {k: scores[k].fmeasure * 100 for k in rouge_scores}\n", 116 | "\n", 117 | "def compute_custom_metrics(srcs, golds, preds, device):\n", 118 | " scores = defaultdict(list)\n", 119 | " bertscore = evaluate.load(\"bertscore\")\n", 120 | " sari = evaluate.load(\"sari\")\n", 121 | " \n", 122 | " # For rouge and length go over examples one by 
one and determine mean\n", 123 | " for gold, pred in zip(golds, preds):\n", 124 | " for k, v in get_rouge_score(gold, pred).items():\n", 125 | " scores[k].append(v)\n", 126 | " scores['words'].append(len(pred.split(' ')))\n", 127 | " for k, v in scores.items():\n", 128 | " scores[k] = np.mean(v)\n", 129 | "\n", 130 | " # This is the default call using model_type=\"roberta-large\"\n", 131 | " # This is the same as in the paper \"Generation of Patient After-Visit Summaries to Support Physicians\" (AVS_gen/eval_summarization.py) using the library SummerTime\n", 132 | " scores['bert_score'] = np.mean((bertscore.compute(predictions=preds, references=golds, lang=\"en\", device=device))['f1']) * 100\n", 133 | " # BERTScore authors recommend \"microsoft/deberta-large-mnli\" (https://github.com/Tiiiger/bert_score)\n", 134 | " scores['bert_score_deberta-large'] = np.mean((bertscore.compute(predictions=preds, references=golds, device=device, model_type=\"microsoft/deberta-large-mnli\"))['f1']) * 100\n", 135 | " scores['sari'] = sari.compute(sources=srcs, predictions=preds, references=[[g] for g in golds])['sari']\n", 136 | " # scores['sari'] = scores['sari'][0]\n", 137 | " # Importing readability for Dale-Chall score not working: https://pypi.org/project/py-readability-metrics/ \n", 138 | "\n", 139 | " return {k: round(v, 2) for k, v in scores.items()}" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 11, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Evaluate on 1 validation examples.\n" 152 | ] 153 | }, 154 | { 155 | "name": "stderr", 156 | "output_type": "stream", 157 | "text": [ 158 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", 159 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 160 | ] 161 | }, 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "{'rouge1': 5.19,\n", 166 | " 'rouge2': 0.0,\n", 167 | " 'rouge3': 0.0,\n", 168 | " 'rouge4': 0.0,\n", 169 | " 'rougeL': 5.19,\n", 170 | " 'words': 5.0,\n", 171 | " 'bert_score': 83.34,\n", 172 | " 'bert_score_deberta-large': 43.23,\n", 173 | " 'sari': 50.95}" 174 | ] 175 | }, 176 | "execution_count": 11, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "# Creating prompt\n", 183 | "\n", 184 | "# To obtain the valid performance on the 10 validation examples\n", 185 | "prompt_valid_gpt_predictions = []\n", 186 | "prompt_valid_gpt_predictions.append(\"This is a test prediction.\")\n", 187 | "prompt_valid_gpt_predictions.append(\"\")\n", 188 | "prompt_valid_gpt_predictions.append(\"\")\n", 189 | "prompt_valid_gpt_predictions.append(\"\")\n", 190 | "prompt_valid_gpt_predictions.append(\"\")\n", 191 | "prompt_valid_gpt_predictions.append(\"\")\n", 192 | "prompt_valid_gpt_predictions.append(\"\")\n", 193 | "prompt_valid_gpt_predictions.append(\"\")\n", 194 | "prompt_valid_gpt_predictions.append(\"\")\n", 195 | "prompt_valid_gpt_predictions.append(\"\")\n", 196 | "\n", 197 | "srcs = []\n", 198 | "golds = []\n", 199 | "preds = []\n", 200 | "for i, pred in enumerate(prompt_valid_gpt_predictions):\n", 201 | " if pred != \"\":\n", 202 | " srcs.append(prompt_valid[i]['text'])\n", 203 | " golds.append(prompt_valid[i]['summary'])\n", 204 | " preds.append(pred)\n", 205 | " \n", 206 | 
"print(f\"Evaluate on {len(srcs)} validation examples.\")\n", 207 | "compute_custom_metrics(srcs, golds, preds, \"cuda\")\n", 208 | "\n", 209 | "# Model & R-1 & R-2 & R-3 & R-L & BERTScore & Deberta & SARI & Words \\\\ \\midrule\n", 210 | "# Llama 2 70B (100 training ex.) & 43 & 15 & 6 & 25 & 87 & 62 & 44.24 & 125 \\\\" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "avs_gen", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.9.18" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 2 235 | } 236 | -------------------------------------------------------------------------------- /hallucination_detection/create_medcat_entities_and_sapbert_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Script to Extract Medcat UMLS Entities and Determine Embeddings" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import json\n", 17 | "from medcat.cat import CAT\n", 18 | "import pandas as pd\n", 19 | "from pathlib import Path\n", 20 | "import numpy as np\n", 21 | "import torch\n", 22 | "from transformers import AutoTokenizer, AutoModel \n", 23 | "from tqdm import tqdm" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# Experiment name\n", 33 | "experiment_name = \"umls_large\"\n", 34 | "\n", 35 | "# Define files and parameters\n", 36 | "bioc_labelled_hallucinations_10_valid_mimic_summaries_path = '/home/s_hegs02/MedTator/13_agreed_label_silver_validation_examples/hallucinations_10_valid_mimic_agreed.jsonl'\n", 37 | "bioc_labelled_hallucinations_100_mimic_summaries_path = '/home/s_hegs02/MedTator/12_agreed_label_silver_examples/hallucinations_100_mimic_agreed.jsonl'\n", 38 | "# TODO: Replace with the agreed dataset\n", 39 | "bioc_labelled_hallucinations_100_generated_summaries = '/home/s_hegs02/MedTator/20_label_halus_qualitatative_annotator_1/hallucinations_100_generated_annotator_1.jsonl'\n", 40 | "dataset_paths = {'valid_mimic': bioc_labelled_hallucinations_10_valid_mimic_summaries_path, 'test_mimic': bioc_labelled_hallucinations_100_mimic_summaries_path, 'test_generated': bioc_labelled_hallucinations_100_generated_summaries}\n", 41 | "entities_output_path = \"/home/s_hegs02/mimic-iv-note-di-bhc/entities/\"\n", 42 | "\n", 43 | "\n", 44 | "# MedCat model\n", 45 | "# Small model: UMLS Small (A modelpack containing a subset of UMLS (disorders, symptoms, medications...). Trained on MIMIC-III)\n", 46 | "# cat_model_path = \"/home/s_hegs02/medcat/models/umls_sm_pt2ch.zip\"\n", 47 | "# Large model: UMLS Full. >4MM concepts trained self-supervsied on MIMIC-III. 
v2022AA of UMLS.\n", 48 | "cat_model_path = \"/home/s_hegs02/medcat/models/umls_self_train_model.zip\"\n", 49 | "num_cpus = 4\n", 50 | "\n", 51 | "# Semantic types of Griffin's \"What's in a Summary\" paper\n", 52 | "# Disorders, Chemicals & Drugs, Procedures semantic groups, Lab Results \n", 53 | "# See groups here: https://lhncbc.nlm.nih.gov/ii/tools/MetaMap/Docs/SemGroups_2018.txt\n", 54 | "filtered_semantic_types = [\n", 55 | " 'T020', 'T190', 'T049', 'T019', 'T047', 'T050', 'T033', 'T037', 'T048', 'T191', 'T046', 'T184',\n", 56 | " 'T116', 'T195', 'T123', 'T122', 'T103', 'T120', 'T104', 'T200', 'T196', 'T126', 'T131', 'T125', 'T129', 'T130', 'T197', 'T114', 'T109', 'T121', 'T192', 'T127',\n", 57 | " 'T060', 'T065', 'T058', 'T059', 'T063', 'T062', 'T061', \n", 58 | " 'T034'\n", 59 | " ]\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Load dataset\n", 69 | "def read_jsonl(path):\n", 70 | " input = []\n", 71 | " with open(path) as f:\n", 72 | " for line in f:\n", 73 | " input.append(json.loads(line))\n", 74 | " return input\n", 75 | "\n", 76 | "datasets = {k: read_jsonl(v) for k, v in dataset_paths.items()}\n", 77 | "\n", 78 | "# Verify that all labels are correctly located\n", 79 | "for dataset_name, dataset in datasets.items():\n", 80 | " for i, doc in enumerate(dataset):\n", 81 | " for label in doc['labels']:\n", 82 | " assert label['start'] >= 0 and label['end'] <= len(doc['summary']), f\"Label {label} in dataset {dataset_name} is out of bounds for text of length {len(doc['summary'])} in document {i}\"\n", 83 | " assert doc['summary'][label['start']:label['end']] == label['text'], f\"Label {label} in dataset {dataset_name} does not match text in document {i}\"" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Extract entities for all concepts in the texts and summaries\n", 93 | "\n", 94 | "# Load medcat model\n", 95 | "cat = CAT.load_model_pack(cat_model_path)\n", 96 | "\n", 97 | "# Get entities for all texts and summaries in the datasets\n", 98 | "for dataset_name, dataset in datasets.items():\n", 99 | " output_file = Path(entities_output_path) / f\"medcat_entities_{dataset_name}_{experiment_name}.json\"\n", 100 | "\n", 101 | " if output_file.exists():\n", 102 | " print(f\"File {output_file} already exists. 
Skipping...\")\n", 103 | " continue\n", 104 | " \n", 105 | " print(f\"Extracting medcat entities for dataset {dataset_name}...\")\n", 106 | "\n", 107 | " # Load input json as pandas dataframe\n", 108 | " df = pd.DataFrame(dataset)[['text', 'summary']]\n", 109 | " assert df.notnull().values.all()\n", 110 | "\n", 111 | " # Prepare input data to MedCat by extracting all texts into one list\n", 112 | " i = 0\n", 113 | " in_data = []\n", 114 | " for col in df.columns:\n", 115 | " for _, text in enumerate(df[col].values):\n", 116 | " # Extract entities\n", 117 | " in_data.append((i, text))\n", 118 | " i += 1\n", 119 | "\n", 120 | " # Perform concept extraction\n", 121 | " # out_data is a dictionary for all input texts including a dictionay \"entities\" with all extracted entities for this text\n", 122 | " out_data = cat.multiprocessing(in_data, nproc=num_cpus)\n", 123 | " print(f'Total number of entities extracted: {sum([len(text[\"entities\"]) for text in out_data.values()])}')\n", 124 | "\n", 125 | " # Count occurrences of semantic types in semantic_types in the extracted entities\n", 126 | " semantic_types_counts = {}\n", 127 | " for text in out_data.values():\n", 128 | " for entity in text['entities'].values():\n", 129 | " for semantic_type in entity['type_ids']:\n", 130 | " if semantic_type in filtered_semantic_types:\n", 131 | " semantic_types_counts[semantic_type] = semantic_types_counts.get(semantic_type, 0) + 1\n", 132 | " print(f'Number of entities per inculded semantic type:')\n", 133 | " print({s: semantic_types_counts.get(s, 0) for s in filtered_semantic_types})\n", 134 | "\n", 135 | " # Filter out entities that are not in the semantic types\n", 136 | " for text in out_data.values():\n", 137 | " text['entities'] = {idx: entity for idx, entity in text['entities'].items() if any([s in entity['type_ids'] for s in filtered_semantic_types])}\n", 138 | " print(f'Total number of entities extracted after filtering: {sum([len(text[\"entities\"]) for text in out_data.values()])}')\n", 139 | "\n", 140 | " # Write back all extracted entities into the same format as the input\n", 141 | " i = 0\n", 142 | " for col in df.columns:\n", 143 | " for j, _ in enumerate(df[col].values):\n", 144 | " df[col][j] = [out_data[i]]\n", 145 | " i += 1\n", 146 | "\n", 147 | " # Save output to json\n", 148 | " df.to_json(output_file, orient='records', indent=4)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# Load SapBERT model\n", 158 | "tokenizer = AutoTokenizer.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\") \n", 159 | "model = AutoModel.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\").cuda()" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 6, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "Extracting sapbert embeddings for dataset valid_mimic...\n", 172 | "Total number of entities: 519\n" 173 | ] 174 | }, 175 | { 176 | "name": "stderr", 177 | "output_type": "stream", 178 | "text": [ 179 | " 0%| | 0/9 [00:00= overlap_ratio and (not same_label or a['label'] == b['label']):\n", 141 | " agreement_list.append((overlap, a, b))\n", 142 | " return agreement_list\n", 143 | "\n", 144 | "def get_labels_no_agreement(labels, agreement, labeller):\n", 145 | " # Get the labels of a list of annotations that are not in agreement\n", 146 | " labels_not_in_agreement = []\n", 147 | " for 
label in labels:\n", 148 | " in_agreement = False\n", 149 | " for agreement_tuple in agreement:\n", 150 | " # agreement_tuple contains overlap ratio, annotation of labeller 1, annotation of labeller 2\n", 151 | " if label == agreement_tuple[labeller]:\n", 152 | " in_agreement = True\n", 153 | " break\n", 154 | " if not in_agreement:\n", 155 | " labels_not_in_agreement.append(label)\n", 156 | " return labels_not_in_agreement\n", 157 | "\n", 158 | "\n", 159 | "data['agreement_diff'] = data.apply(lambda row: get_agreement_list(row['labels_1'], row['labels_2']), axis=1)\n", 160 | "data['agreement_same'] = data.apply(lambda row: get_agreement_list(row['labels_1'], row['labels_2'], same_label=True), axis=1)\n", 161 | "data['labels_1_no_agreement_diff'] = data.apply(lambda row: get_labels_no_agreement(row['labels_1'], row['agreement_diff'], 1), axis=1)\n", 162 | "data['labels_2_no_agreement_diff'] = data.apply(lambda row: get_labels_no_agreement(row['labels_2'], row['agreement_diff'], 2), axis=1)\n", 163 | "data['labels_1_no_agreement_same'] = data.apply(lambda row: get_labels_no_agreement(row['labels_1'], row['agreement_same'], 1), axis=1)\n", 164 | "data['labels_2_no_agreement_same'] = data.apply(lambda row: get_labels_no_agreement(row['labels_2'], row['agreement_same'], 2), axis=1)\n", 165 | "\n", 166 | "# Check for labeller 1 and labeller 2 that number of labels in agreement and not in agreement are the same as the total number of labels\n", 167 | "assert data.apply(lambda row: len(row['labels_1']) == len(row['agreement_diff']) + len(row['labels_1_no_agreement_diff']), axis=1).all()\n", 168 | "assert data.apply(lambda row: len(row['labels_2']) == len(row['agreement_diff']) + len(row['labels_2_no_agreement_diff']), axis=1).all()\n", 169 | "assert data.apply(lambda row: len(row['labels_1']) == len(row['agreement_same']) + len(row['labels_1_no_agreement_same']), axis=1).all()\n", 170 | "assert data.apply(lambda row: len(row['labels_2']) == len(row['agreement_same']) + len(row['labels_2_no_agreement_same']), axis=1).all()" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 56, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "Included 12 documents ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12])\n", 183 | "Total labels rater 1: 24\n", 184 | "Total labels rater 2: 8\n", 185 | "Total labels in 0.8 agreement w/ diff labels: 3 (12.50%, 37.50%)\n", 186 | "Total labels in 0.8 agreement w/ same labels: 2 (8.33%, 25.00%)\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "# Print general statistics\n", 192 | "total_labels_1 = data['labels_1'].apply(len).sum()\n", 193 | "total_labels_2 = data['labels_2'].apply(len).sum()\n", 194 | "total_agreement_diff = data['agreement_diff'].apply(len).sum()\n", 195 | "total_agreement_same = data['agreement_same'].apply(len).sum()\n", 196 | "\n", 197 | "print(f\"Included {len(included_ids)} documents ({included_ids})\")\n", 198 | "print(f\"Total labels rater 1: {total_labels_1}\")\n", 199 | "print(f\"Total labels rater 2: {total_labels_2}\")\n", 200 | "print(f\"Total labels in {overlap_ratio} agreement w/ diff labels: {total_agreement_diff} ({total_agreement_diff / total_labels_1 * 100:.2f}%, {total_agreement_diff / total_labels_2 * 100:.2f}%)\")\n", 201 | "print(f\"Total labels in {overlap_ratio} agreement w/ same labels: {total_agreement_same} ({total_agreement_same / total_labels_1 * 100:.2f}%, {total_agreement_same / total_labels_2 * 100:.2f}%)\")" 202 | ] 203 | }, 204 | 
{ 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "# Print statistics per document with the following format:\n", 211 | "# document id\n", 212 | "# Rater 1: total labels rater 1 (total labels in agreement w/ diff labels, total labels in agreement w/ same labels)\n", 213 | "# Rater 2: total labels rater 2 (total labels in agreement w/ diff labels, total labels in agreement w/ same labels)\n", 214 | "# Both Raters: list of labels annotated by both raters with different labels\n", 215 | "# Only Rater 1: list of labels only annotated by rater 1 with different labels\n", 216 | "# Only Rater 2: list of labels only annotated by rater 2 with different labels\n", 217 | "\n", 218 | "def format_label(label):\n", 219 | " # Label is a dict with keys start, end, length, label, text\n", 220 | " return f\"{label['text']} {label['start']}-{label['end']} ({label['label']})\"\n", 221 | "\n", 222 | "def format_labels(labels):\n", 223 | " # Format as numbered list\n", 224 | " return '\\n'.join([f\"\\t{i+1}. {format_label(l)}\" for i, l in enumerate(labels)])\n", 225 | "\n", 226 | "def format_agreement(agreement):\n", 227 | " # agreement is a tuple with overlap ratio, annotation of labeller 1, annotation of labeller 2\n", 228 | " text = f\"{agreement[0]:.2f} {agreement[1]['text']} vs. {agreement[2]['text']} {agreement[1]['start']}-{agreement[1]['end']}/{agreement[2]['start']}-{agreement[2]['end']}\"\n", 229 | " if agreement[1]['label'] != agreement[2]['label']:\n", 230 | " text += f\" ({agreement[1]['label']} vs. {agreement[2]['label']})\"\n", 231 | " return text\n", 232 | "\n", 233 | "def format_agreements(agreements):\n", 234 | " # Format as numbered list\n", 235 | " return ('\\n' if len(agreements) > 0 else '') + '\\n'.join([f\"\\t{i+1}. 
{format_agreement(a)}\" for i, a in enumerate(agreements)])\n", 236 | "\n", 237 | "for index, row in data.iterrows():\n", 238 | " print(f\"Document {row['id']}\")\n", 239 | " print(f\" Rater 1: {len(row['labels_1'])}\")\n", 240 | " print(f\" Rater 2: {len(row['labels_2'])}\")\n", 241 | " print(f\" Both Raters ({len(row['agreement_diff'])}):{format_agreements(row['agreement_diff'])}\")\n", 242 | " print(f\" Only Rater 1 ({len(row['labels_1_no_agreement_diff'])}):{format_labels(row['labels_1_no_agreement_diff'])}\")\n", 243 | " print(f\" Only Rater 2 ({len(row['labels_2_no_agreement_diff'])}):{format_labels(row['labels_2_no_agreement_diff'])}\")\n", 244 | " print()\n", 245 | " " 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 42, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "001 resection of an inguinal mass procedure_unsupported condition_unsupported\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "# Print labels for each document that have an overlap of 80% or more\n", 263 | "for document_id in labeling_1_dict:\n", 264 | " if document_id in labeling_2_dict:\n", 265 | " # Get annotations for document\n", 266 | " annotations_1 = labeling_1_dict[document_id]\n", 267 | " annotations_2 = labeling_2_dict[document_id]\n", 268 | " # Get all annotations that overlap 80% or more\n", 269 | " overlapping_annotations = []\n", 270 | " for a1 in annotations_1:\n", 271 | " for a2 in annotations_2:\n", 272 | " if a1['start'] >= a2['start'] and a1['start'] <= a2['end']:\n", 273 | " overlapping_annotations.append((a1, a2))\n", 274 | " break\n", 275 | " elif a2['start'] >= a1['start'] and a2['start'] <= a1['end']:\n", 276 | " overlapping_annotations.append((a1, a2))\n", 277 | " break\n", 278 | " # Print annotations\n", 279 | " for a1, a2 in overlapping_annotations:\n", 280 | " print(document_id, a1['text'], a1['label'], a2['label'])" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "avs_gen", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.9.18" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /hallucination_detection/convert_bioc_to_json_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Create JSON datasets from raw bioc labelings" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import json\n", 18 | "import re\n", 19 | "from utils import read_bioc, parse_text_labels\n", 20 | "from pathlib import Path" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# Define files and parameters\n", 30 | "data_path = '/home/s_hegs02/MedTator'\n", 31 | "data_path = Path(data_path)\n", 32 | "\n", 33 | "dataset_paths = {\n", 34 | " # Experiment 1: label mimic summaries\n", 35 | " 'hallucinations_100_mimic_annotator_1': 
data_path / '10_label_silver_examples_annotator_1' / 'hallucinations_100_mimic_annotator_1.xml',\n", 36 | " 'hallucinations_100_mimic_annotator_2': data_path / '11_label_silver_examples_annotator_2' / 'hallucinations_100_mimic_annotator_2.xml',\n", 37 | " 'hallucinations_100_mimic_agreed': data_path / '12_agreed_label_silver_examples' / 'hallucinations_100_mimic_agreed.xml',\n", 38 | " 'hallucinations_10_valid_mimic_agreed': data_path / '13_agreed_label_silver_validation_examples' / 'hallucinations_10_valid_mimic_agreed.xml',\n", 39 | " # Experiment 2: label generated summaries\n", 40 | " 'hallucinations_100_generated_annotator_1': data_path / '20_label_halus_qualitative_annotator_1' / 'hallucinations_100_generated_annotator_1.xml',\n", 41 | " 'hallucinations_100_generated_annotator_2': data_path / '21_label_halus_qualitative_annotator_2' / 'hallucinations_100_generated_annotator_2.xml',\n", 42 | " 'hallucinations_100_generated_agreed': data_path / '22_label_halus_qualitative_agreed' / 'hallucinations_100_generated_agreed.xml',\n", 43 | "}\n", 44 | "\n", 45 | "# Randomization for Experiment 2\n", 46 | "hallucination_random_models = {0: [0, 4, 3, 1, 2], 1: [4, 1, 3, 2, 0], 2: [0, 2, 1, 3, 4], 3: [1, 0, 3, 4, 2], 4: [3, 2, 4, 1, 0], 5: [3, 2, 1, 4, 0], 6: [1, 2, 3, 0, 4], 7: [1, 3, 2, 4, 0], 8: [4, 0, 1, 3, 2], 9: [0, 3, 2, 4, 1], 10: [0, 4, 3, 2, 1], 11: [1, 0, 4, 3, 2], 12: [2, 4, 1, 0, 3], 13: [3, 1, 0, 4, 2], 14: [4, 2, 0, 1, 3], 15: [0, 2, 4, 3, 1], 16: [1, 4, 2, 3, 0], 17: [2, 3, 1, 0, 4], 18: [4, 0, 3, 2, 1], 19: [0, 3, 1, 2, 4], 20: [4, 0, 2, 3, 1], 21: [0, 4, 2, 1, 3], 22: [0, 2, 4, 3, 1], 23: [1, 0, 3, 4, 2], 24: [3, 1, 0, 4, 2], 25: [2, 0, 3, 4, 1], 26: [4, 3, 0, 1, 2], 27: [3, 4, 2, 1, 0], 28: [4, 2, 3, 1, 0], 29: [4, 1, 3, 0, 2], 30: [2, 3, 0, 1, 4], 31: [4, 2, 0, 3, 1], 32: [3, 0, 2, 1, 4], 33: [2, 3, 4, 1, 0], 34: [4, 1, 3, 2, 0], 35: [0, 4, 1, 3, 2], 36: [4, 1, 3, 0, 2], 37: [3, 1, 0, 4, 2], 38: [3, 2, 4, 1, 0], 39: [1, 0, 3, 4, 2], 40: [4, 3, 0, 1, 2], 41: [2, 3, 4, 0, 1], 42: [2, 4, 3, 1, 0], 43: [4, 1, 2, 0, 3], 44: [0, 4, 3, 1, 2], 45: [3, 2, 0, 1, 4], 46: [2, 4, 0, 3, 1], 47: [2, 1, 0, 4, 3], 48: [4, 2, 3, 1, 0], 49: [3, 1, 4, 2, 0]}\n", 47 | "\n", 48 | "# Define markers\n", 49 | "re_text_start_mimic_old_key = re.compile('### JSON Key: text\\n', re.MULTILINE)\n", 50 | "re_summary_start_mimic_old_key = re.compile('### JSON Key: summary\\n', re.MULTILINE)\n", 51 | "re_text_start_mimic = re.compile('Text:\\n', re.MULTILINE)\n", 52 | "re_summary_start_mimic = re.compile('Summary:\\n', re.MULTILINE)\n", 53 | "re_text_start_generated = re.compile('Text:\\n', re.MULTILINE)\n", 54 | "re_summary_start_generated = re.compile('Summary \\d:\\n', re.MULTILINE)\n", 55 | "markers = {k: (re_text_start_mimic, re_summary_start_mimic) if 'mimic' in k else (re_text_start_generated, re_summary_start_generated) for k in dataset_paths.keys()}\n", 56 | "# This two medtator datasets still used the old key\n", 57 | "markers['hallucinations_100_mimic_annotator_1'] = (re_text_start_mimic_old_key, re_summary_start_mimic_old_key)\n", 58 | "markers['hallucinations_100_mimic_annotator_2'] = (re_text_start_mimic_old_key, re_summary_start_mimic_old_key)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "datasets_bioc = {k: read_bioc(v) for k, v in dataset_paths.items()}\n", 68 | "datasets_unprocessed = {k: parse_text_labels(v) for k, v in datasets_bioc.items()}\n", 69 | "datasets = {k: [] for k in 
datasets_unprocessed.keys()}\n", 70 | "\n", 71 | "# Print included ids\n", 72 | "for k, v in datasets_unprocessed.items():\n", 73 | " print(f\"{k}: {len(v)} examples\")" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 9, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# The datasets still contain the text (BHC) and the summary (discharge instructions) and the label positions are based on both texts.\n", 83 | "# Additionally, the generated examples contain one text and 5 randomized generations\n", 84 | "# Must split this data and correct the label positions to be based on the summaries only" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Added 100 examples with 239 labels to hallucinations_100_mimic_annotator_1\n", 97 | "Added 100 examples with 282 labels to hallucinations_100_mimic_annotator_2\n", 98 | "Added 100 examples with 286 labels to hallucinations_100_mimic_agreed\n", 99 | "Added 10 examples with 23 labels to hallucinations_10_valid_mimic_agreed\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "# 1. Labeling: First split text and summaries in mimic examples and correct label positions\n", 105 | "\n", 106 | "for dataset_name in [k for k in datasets_unprocessed.keys() if 'mimic' in k]:\n", 107 | " # print(k)\n", 108 | " # Get all keys sorted\n", 109 | " sorted_keys = (list(datasets_unprocessed[dataset_name].keys()))\n", 110 | " sorted_keys.sort()\n", 111 | " for key in sorted_keys:\n", 112 | " # print(key)\n", 113 | " # print(example)\n", 114 | " re_text_start, re_summary_start = markers[dataset_name]\n", 115 | " example = datasets_unprocessed[dataset_name][key]\n", 116 | " text_start = re_text_start.search(example['text']).span()[1]\n", 117 | " text_end = re_summary_start.search(example['text']).span()[0]\n", 118 | " summary_start = re_summary_start.search(example['text']).span()[1]\n", 119 | "\n", 120 | " text = example['text'][text_start:text_end].strip()\n", 121 | " summary = example['text'][summary_start:].rstrip()\n", 122 | " assert len(summary.lstrip()) == len(summary)\n", 123 | " # Debug\n", 124 | " # print(text)\n", 125 | " # print(summary)\n", 126 | " \n", 127 | " label_offset = summary_start\n", 128 | " labels = []\n", 129 | " for label in example['labels']:\n", 130 | " new_label = label.copy()\n", 131 | " new_label['start'] -= label_offset\n", 132 | " new_label['end'] -= label_offset\n", 133 | " # print(label, new_label)\n", 134 | " # Verify correct label\n", 135 | " assert example['text'][label['start']:label['end']] == label['text']\n", 136 | " assert summary[new_label['start']:new_label['end']] == label['text']\n", 137 | " labels.append(new_label)\n", 138 | " \n", 139 | " datasets[dataset_name].append({'text': text, 'summary': summary, 'labels': labels})\n", 140 | " print(f\"Added {len(datasets[dataset_name])} examples with {sum([len(ex['labels']) for ex in datasets[dataset_name]])} labels to {dataset_name}\")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 6, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "Added 100 examples with 123 labels to hallucinations_100_generated_annotator_1\n", 153 | "Added 100 examples with 118 labels to hallucinations_100_generated_annotator_2\n", 154 | "Added 100 examples with 114 labels to hallucinations_100_generated_agreed\n" 155 | ] 156 | } 
157 | ], 158 | "source": [ 159 | "# 2. Labeling: Second de-randomize generated summaries and put them into separate text-summary pairs\n", 160 | "\n", 161 | "for dataset_name in [k for k in datasets_unprocessed.keys() if 'generated' in k]:\n", 162 | "\n", 163 | " num_unfolded_generated_examples = 0\n", 164 | " unfolded_generated_examples = [[] for i in range(5)]\n", 165 | " for id, example in datasets_unprocessed[dataset_name].items():\n", 166 | " # Get text entry between re_text_start_generated and re_summary_start_generated\n", 167 | " source = example['text']\n", 168 | " labels = example['labels']\n", 169 | " text = source[re_text_start_generated.search(source).end():re_summary_start_generated.search(source).start()].strip()\n", 170 | " summaries_labels = []\n", 171 | " \n", 172 | " source_summaries_offset = re_summary_start_generated.search(source).start()\n", 173 | " summary_label_len = len('Summary X:\\n')\n", 174 | " source_summaries = source[source_summaries_offset:]\n", 175 | " # Get all positions of re_text_start_generated\n", 176 | " summary_start_positions = [m.start() for m in re_summary_start_generated.finditer(source_summaries)] + [len(source_summaries)]\n", 177 | " summary_lens = [summary_start_positions[i+1] - summary_start_positions[i] for i in range(5)]\n", 178 | " \n", 179 | " # print(summary_start_positions)\n", 180 | " # for i in range(5):\n", 181 | " # print('---' + source_summaries[summary_start_positions[i] + summary_label_len:summary_start_positions[i+1]] + '---')\n", 182 | " \n", 183 | " randomized_summaries_labels = []\n", 184 | " processed_labels = []\n", 185 | " for i in range(5):\n", 186 | " summary_content_start = source_summaries_offset + summary_start_positions[i] + summary_label_len\n", 187 | " summary_content_end = source_summaries_offset + summary_start_positions[i+1]\n", 188 | " summary = source[summary_content_start:summary_content_end]\n", 189 | " summaries_labels = []\n", 190 | "\n", 191 | " # Get all labels for this summary\n", 192 | " for label in labels:\n", 193 | " if label['start'] >= summary_content_start and label['end'] <= summary_content_end:\n", 194 | " # Verify label\n", 195 | " assert source[label['start']:label['end']] == label['text']\n", 196 | " # Copy label\n", 197 | " new_label = label.copy()\n", 198 | " # Correct the label position\n", 199 | " new_label['start'] = label['start'] - summary_content_start\n", 200 | " new_label['end'] = label['end'] - summary_content_start\n", 201 | " # Check label at correct position in extracted summary\n", 202 | " assert summary[new_label['start']:new_label['end']] == label['text']\n", 203 | " summaries_labels.append(new_label)\n", 204 | " processed_labels.append(label)\n", 205 | " randomized_summaries_labels.append({'summary': summary, 'labels': summaries_labels})\n", 206 | " \n", 207 | " # Check that all labels were processed\n", 208 | " assert processed_labels == labels\n", 209 | " assert sum([len(ex['labels']) for ex in randomized_summaries_labels]) == len(labels)\n", 210 | " # Check all characters of source were processed\n", 211 | " assert source_summaries_offset + sum([len(ex['summary']) for ex in randomized_summaries_labels]) + 5 * summary_label_len == len(source)\n", 212 | " # Now remove trailing whitespaces for summaries and check no leading whitespaces\n", 213 | " for i in range(5):\n", 214 | " assert len(randomized_summaries_labels[i]['summary']) == len(randomized_summaries_labels[i]['summary'].lstrip())\n", 215 | " randomized_summaries_labels[i]['summary'] = 
randomized_summaries_labels[i]['summary'].rstrip()\n", 216 | " \n", 217 | " # De-randomize\n", 218 | " summaries_labels = [''] * 5\n", 219 | " for i in range(5):\n", 220 | " summaries_labels[hallucination_random_models[id][i]] = randomized_summaries_labels[i]\n", 221 | " assert [e != '' for e in summaries_labels].count(True) == 5\n", 222 | " \n", 223 | " # Debug:\n", 224 | " # for e in summaries_labels:\n", 225 | " # print(e['summary'])\n", 226 | " # print(e['labels'])\n", 227 | " # print('---')\n", 228 | " \n", 229 | " # Move examples with text-summary format into unpacked\n", 230 | " for i in range(5):\n", 231 | " unfolded_generated_examples[i].append({'text': text, 'summary': summaries_labels[i]['summary'], 'labels': summaries_labels[i]['labels']})\n", 232 | " num_unfolded_generated_examples += 1\n", 233 | " \n", 234 | " # Combine all lists into one\n", 235 | " assert num_unfolded_generated_examples == 5 * len(datasets_unprocessed[dataset_name])\n", 236 | " datasets[dataset_name] = unfolded_generated_examples[0] + unfolded_generated_examples[1] + unfolded_generated_examples[2] + unfolded_generated_examples[3] + unfolded_generated_examples[4]\n", 237 | " print(f\"Added {len(datasets[dataset_name])} examples with {sum([len(ex['labels']) for ex in datasets[dataset_name]])} labels to {dataset_name}\")" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 10, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "# Write out datasets as jsonl into same folders as original bioc files\n", 247 | "output_dir = Path('/home/s_hegs02/patient_summaries_with_llms')\n", 248 | "for dataset_name in datasets.keys():\n", 249 | " file_name = str(dataset_paths[dataset_name]).split('/')[-1]\n", 250 | " file_name = file_name.replace('.xml', '.jsonl')\n", 251 | " with open(output_dir / file_name, 'w') as f:\n", 252 | " for example in datasets[dataset_name]:\n", 253 | " f.write(json.dumps(example) + '\\n')" 254 | ] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "avs_gen", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.9.18" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 2 278 | } 279 | -------------------------------------------------------------------------------- /summarization/README.md: -------------------------------------------------------------------------------- 1 | # Summarization Models 2 | 3 | This folder contains the scripts to run and evaluate the summarization models LED and Llama 2. 
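The evaluation sections below paste the metric dictionaries printed by each run; for the Llama models, five repeated test runs are listed per setting. The snippet below is only an illustrative sketch (not one of the scripts in this folder) of how such repeated runs could be averaged; the `runs` list, its values, and the variable names are placeholders to be filled in by hand from the printed dictionaries.

```
# Illustrative helper (not part of the repository scripts): average the metric
# dictionaries of repeated test runs. The values below are placeholders only.
import numpy as np

runs = [
    {'rouge1': 40.0, 'rougeL': 25.0, 'bert_score': 86.0, 'sari': 45.0, 'words': 76.0},
    {'rouge1': 41.0, 'rougeL': 26.0, 'bert_score': 87.0, 'sari': 44.0, 'words': 80.0},
]

# Mean and standard deviation per metric across runs, rounded to two decimals
mean_scores = {k: round(float(np.mean([r[k] for r in runs])), 2) for k in runs[0]}
std_scores = {k: round(float(np.std([r[k] for r in runs])), 2) for k in runs[0]}
print(mean_scores)
print(std_scores)
```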
4 | 5 | ## Final Performance Runs 6 | 7 | * Choose best model according to BERTScore on 100 validation examples 8 | * Then evluate on 100 test examples 9 | 10 | ### LED-base 11 | 12 | * Training with paramter tuning on wandb 13 | 14 | Evaluation on long data test set: 15 | ``` 16 | python summarization/run_summarization.py --model_name_or_path ~/scratch/mimic-iv-note-di-bhc/models/led-base-16384/mimic-iv-note-di-bhc_led-base-16384_long_data_100_valid_done/dropout_0.05_learning_rate_5e-5/ --do_predict --test_file ~/scratch/mimic-iv-note-di-bhc/dataset/test_last_100.json --output_dir ~/scratch/mimic-iv-note-di-bhc/models/led-base-16384/mimic-iv-note-di-bhc_led-base-16384_long_data_100_valid_done/test_160000_output --per_device_train_batch_size=1 --per_device_eval_batch_size=1 --predict_with_generate --max_source_length 4096 --max_target_length 350 17 | 18 | {'rouge1': 43.32018605114362, 'rouge2': 17.04703537496032, 'rouge3': 8.255309078577513, 'rouge4': 4.301309886142394, 'rougeL': 29.214412778548695, 'words': 74.36, 'bert_score': 87.98089069128036, 'bert_score_deberta-large': 63.5219652056694, 'sari': 46.390301194434564} 19 | $43.32$ & $17.05$ & $8.26$ & $4.30$ & $29.21$ & $87.98$ & $63.52$ & $46.39$ & $74.36$ 20 | ``` 21 | 22 | 23 | Evaluation on 4000_600_chars test set: 24 | ``` 25 | python summarization/run_summarization.py --model_name_or_path ~/scratch/mimic-iv-note-di-bhc/models/led-base-16384/mimic-iv-note-di-bhc_led-base-16384_4000_600_chars_100_valid_done/dropout_0.2_learning_rate_1e-5/ --do_predict --test_file ~/scratch/mimic-iv-note-di-bhc/dataset/test_4000_600_chars_last_100.json --output_dir ~/scratch/mimic-iv-note-di-bhc/models/led-base-16384/mimic-iv-note-di-bhc_led-base-16384_4000_600_chars_100_valid_done/test_160000_output --per_device_train_batch_size=1 --per_device_eval_batch_size=1 --predict_with_generate --max_source_length 4096 --max_target_length 350 26 | 27 | {'rouge1': 42.2980476882839, 'rouge2': 14.976131359106965, 'rouge3': 7.043920762758955, 'rouge4': 3.872723295137177, 'rougeL': 26.502820195241306, 'words': 117.81, 'bert_score': 86.71253889799118, 'bert_score_deberta-large': 60.84960976243019, 'sari': 44.38245205873611} 28 | $42.30$ & $14.98$ & $7.04$ & $3.87$ & $26.50$ & $86.71$ & $60.85$ & $44.38$ & $117.81$ 29 | ``` 30 | 31 | ### LED-large 32 | 33 | * Training with paramter tuning on wandb 34 | 35 | Evaluation on long data test set: 36 | ``` 37 | CUDA_VISIBLE_DEVICES=0 python summarization/run_summarization.py --model_name_or_path ~/scratch/mimic-iv-note-di-bhc/models/led-large-16384/mimic-iv-note-di-bhc_led-large-16384_long_data_100_valid_done/dropout_0.2_learning_rate_5e-5/ --do_predict --test_file ~/scratch/mimic-iv-note-di-bhc/dataset/test_last_100.json --output_dir ~/scratch/mimic-iv-note-di-bhc/models/led-large-16384/mimic-iv-note-di-bhc_led-large-16384_long_data_100_valid_done/test_160000_output --per_device_train_batch_size=1 --per_device_eval_batch_size=1 --predict_with_generate --max_source_length 4096 --max_target_length 350 38 | 39 | {'rouge1': 43.822986792431145, 'rouge2': 17.331937668711536, 'rouge3': 8.846417725899846, 'rouge4': 4.9231165894123174, 'rougeL': 29.887635210786037, 'words': 76.99, 'bert_score': 88.10762703418732, 'bert_score_deberta-large': 64.11935889720917, 'sari': 46.71329387536656} 40 | $43.82$ & $17.33$ & $8.85$ & $4.92$ & $29.89$ & $88.11$ & $64.12$ & $46.71$ & $76.99$ 41 | ``` 42 | 43 | 44 | Evaluation on 4000_600_chars test set: 45 | ``` 46 | CUDA_VISIBLE_DEVICES=2 python summarization/run_summarization.py 
--model_name_or_path ~/scratch/mimic-iv-note-di-bhc/models/led-large-16384/mimic-iv-note-di-bhc_led-large-16384_4000_600_chars_100_valid_done/dropout_0.2_learning_rate_1e-5/ --do_predict --test_file ~/scratch/mimic-iv-note-di-bhc/dataset/test_4000_600_chars_last_100.json --output_dir ~/scratch/mimic-iv-note-di-bhc/models/led-large-16384/mimic-iv-note-di-bhc_led-large-16384_4000_600_chars_100_valid_done/test_160000_output --per_device_train_batch_size=1 --per_device_eval_batch_size=1 --predict_with_generate --max_source_length 4096 --max_target_length 350 47 | 48 | {'rouge1': 46.206983410265465, 'rouge2': 17.376355377659934, 'rouge3': 8.715205915075776, 'rouge4': 5.13851072281858, 'rougeL': 28.865386144461212, 'words': 117.59, 'bert_score': 87.4952802658081, 'bert_score_deberta-large': 63.51826936006546, 'sari': 45.84487176324999} 49 | $46.21$ & $17.38$ & $8.72$ & $5.14$ & $28.87$ & $87.50$ & $63.52$ & $45.84$ & $117.59$ 50 | ``` 51 | 52 | ### LED-large long data 53 | 54 | 55 | 56 | ### LLaMA 7B 57 | 58 | * Training with paramter tuning on wandb 59 | 60 | #### Long data 61 | * Best model uses only _steps=80_ (best model in best_val_loss folder) 62 | 63 | Evaluation on long data test set: 64 | ``` 65 | # With data_files["test"] = test_last_100.json 66 | 67 | python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-70b-hf --evaluation --evaluation_model_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/models/Llama-2-70b-hf/mimic-iv-note-di-bhc_Llama-2-70b-hf_long_data_100_valid_done/lora_rank_32_lora_alpha_32_lora_dropout_0.05_num_target_modules_2_learning_rate_2e-4/best_val_loss --data_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset --output_path /home/s/s_hegs02/scratch/debug --num_train_examples 100 --num_val_examples 100 --num_test_examples 100 68 | 69 | {'rouge1': 38.8343594182866, 'rouge2': 12.97963450275313, 'rouge3': 5.273994688087981, 'rouge4': 2.388816250393242, 'rougeL': 24.811244853956904, 'words': 71.92, 'bert_score': 86.49026721715927, 'bert_score_deberta-large': 60.77595293521881, 'sari': 44.27485176179068} 70 | {'rouge1': 39.47507417776543, 'rouge2': 13.508055870992669, 'rouge3': 5.474262681411284, 'rouge4': 2.4299066163652707, 'rougeL': 25.201416964411987, 'words': 77.79, 'bert_score': 84.75045335292816, 'bert_score_deberta-large': 60.05594950914382, 'sari': 44.487989828395825} 71 | {'rouge1': 38.167512768514875, 'rouge2': 12.432498880471945, 'rouge3': 5.166171874893524, 'rouge4': 2.2279453931573583, 'rougeL': 24.912878164301006, 'words': 69.49, 'bert_score': 86.43438655138016, 'bert_score_deberta-large': 60.743097960948944, 'sari': 44.387922498396605} 72 | {'rouge1': 38.474542505246575, 'rouge2': 12.59729023685818, 'rouge3': 5.032502045215025, 'rouge4': 2.08290036294411, 'rougeL': 24.771698689284257, 'words': 71.22, 'bert_score': 86.25028610229492, 'bert_score_deberta-large': 60.663696229457855, 'sari': 43.91790431156874} 73 | {'rouge1': 36.83328812388325, 'rouge2': 11.758537910375646, 'rouge3': 4.723083386770241, 'rouge4': 2.0583976340860497, 'rougeL': 23.944460740287273, 'words': 75.22, 'bert_score': 84.4628956913948, 'bert_score_deberta-large': 58.90815430879592, 'sari': 43.55259789375479} 74 | ``` 75 | 76 | 77 | #### 4000_600_chars 78 | 79 | * Training with paramter tuning on wandb 80 | * Best model uses steps=100 81 | 82 | Evaluation on 4000_600_chars test set: 83 | ``` 84 | # With data_files["test"] = test_4000_600_chars_last_100.json 85 | 86 | python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-7b-hf 
--evaluation --evaluation_model_path /home/s_hegs02/mimic-iv-note-di-bhc/models/Llama-2-7b-hf/mimic-iv-note-di-bhc_Llama-2-7b-hf_4000_600_chars_100_valid_done/lora_rank_8_lora_alpha_32_lora_dropout_0.1_num_target_modules_2_learning_rate_2e-5/best_val_loss --data_path /home/s_hegs02/mimic-iv-note-di-bhc/dataset --output_path /home/s_hegs02/mimic-iv-note-di-bhc/models/Llama-2-7b-hf/mimic-iv-note-di-bhc_Llama-2-7b-hf_4000_600_chars_100_valid_done/lora_rank_8_lora_alpha_32_lora_dropout_0.1_num_target_modules_2_learning_rate_2e-5_test --num_train_examples 100 --num_val_examples 100 --num_test_examples 100 87 | 88 | {'rouge1': 38.18726652084839, 'rouge2': 12.522356326993874, 'rouge3': 5.300379506691891, 'rouge4': 2.6844528863937667, 'rougeL': 23.425086591887453, 'words': 104.15, 'bert_score': 84.68257343769073, 'bert_score_deberta-large': 58.7152236700058, 'sari': 43.39504555237768} 89 | {'rouge1': 36.55885611308579, 'rouge2': 11.770071884328013, 'rouge3': 4.96413117973415, 'rouge4': 2.3760043053086815, 'rougeL': 22.520381900800416, 'words': 100.56, 'bert_score': 82.0653041601181, 'bert_score_deberta-large': 56.56696653366089, 'sari': 41.94318296286558} 90 | {'rouge1': 36.865581109022166, 'rouge2': 12.035793003259116, 'rouge3': 5.477773544226487, 'rouge4': 2.831216527519416, 'rougeL': 22.85251778524732, 'words': 94.03, 'bert_score': 81.28817218542099, 'bert_score_deberta-large': 56.447242498397834, 'sari': 42.631882102051144} 91 | {'rouge1': 36.45585401833812, 'rouge2': 11.507301855132011, 'rouge3': 4.979182106965462, 'rouge4': 2.525371873858721, 'rougeL': 22.339373505129323, 'words': 100.47, 'bert_score': 82.16391408443451, 'bert_score_deberta-large': 56.96498113870621, 'sari': 42.04524834049528} 92 | {'rouge1': 36.698585402455606, 'rouge2': 11.78729495150544, 'rouge3': 4.868574815166207, 'rouge4': 2.2399180637599803, 'rougeL': 22.530625996895047, 'words': 103.47, 'bert_score': 81.9952797293663, 'bert_score_deberta-large': 56.640745639801025, 'sari': 42.02560221098953} 93 | ``` 94 | 95 | 96 | #### Train on hallucination dataset 97 | 98 | Important Parameters: 99 | * Set max_steps to 100 (best in parameter tuning) 100 | * Set gradient_accumulation_steps to 16 101 | 102 | Set training data (`data_files["train"]`) depending on setting: 103 | * Original: `hallucination_summaries_original_test_100.json` 104 | * Cleaned: `hallucination_summaries_cleaned_test_100.json` 105 | * Cleaned and Improved: `hallucination_summaries_cleaned_improved_test_100.json` 106 | 107 | Careful: Set `output_path` accordingly for each setting! 
108 | 109 | ``` 110 | CUDA_VISIBLE_DEVICES=0 python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-7b-hf --data_path /home/s_hegs02/mimic-iv-note-di-bhc/dataset --output_path /home/s_hegs02/mimic-iv-note-di-bhc/models/Llama-2-7b-hf/mimic-iv-note-di-bhc_Llama-2-7b-hf_4000_600_chars_100_hallucinations/original --device cuda --max_steps 100 --save_and_logging_steps 100 --batch_size 1 --gradient_accumulation_steps 16 --lora_rank 8 --lora_alpha 32 --lora_dropout 0.1 --num_target_modules 2 --learning_rate 2e-5 --num_train_examples 100 --num_val_examples 100 --num_test_examples 100 111 | ``` 112 | 113 | 114 | ### LLaMA 70B 115 | 116 | * Training with paramter tuning on wandb 117 | * Both 70B models best with steps=20 (folder checkpoint_20 = best_val_loss) 118 | 119 | #### Long data 120 | 121 | * Best model uses only _steps=20_ 122 | 123 | Evaluation on long data test set: 124 | ``` 125 | # With data_files["test"] = test_last_100.json 126 | 127 | python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-70b-hf --evaluation --evaluation_model_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/models/Llama-2-70b-hf/mimic-iv-note-di-bhc_Llama-2-70b-hf_long_data_100_valid_done/lora_rank_32_lora_alpha_32_lora_dropout_0.05_num_target_modules_2_learning_rate_2e-4/best_val_loss --data_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset --output_path /home/s/s_hegs02/scratch/debug --num_train_examples 100 --num_val_examples 100 --num_test_examples 100 128 | 129 | 130 | {'rouge1': 40.5555217296232, 'rouge2': 14.380567615873911, 'rouge3': 6.145952549654063, 'rouge4': 2.9839005077717538, 'rougeL': 25.812246308651094, 'words': 76.46, 'bert_score': 86.00773721933365, 'bert_score_deberta-large': 61.7718161046505, 'sari': 45.00554637223204} 131 | {'rouge1': 40.54058585622222, 'rouge2': 14.166953530916869, 'rouge3': 5.695409156176871, 'rouge4': 2.3296289556276344, 'rougeL': 26.20942324367626, 'words': 80.28, 'bert_score': 86.76197403669357, 'bert_score_deberta-large': 62.13716307282448, 'sari': 44.9823358329503} 132 | {'rouge1': 40.34610201932814, 'rouge2': 14.411568668138443, 'rouge3': 6.339125512527877, 'rouge4': 2.953904152985041, 'rougeL': 26.32607735907873, 'words': 75.4, 'bert_score': 86.87447029352188, 'bert_score_deberta-large': 62.064355462789536, 'sari': 45.71918948150783} 133 | {'rouge1': 40.1893462851726, 'rouge2': 14.055680359041899, 'rouge3': 5.956359809720371, 'rouge4': 2.5952588267124463, 'rougeL': 25.947485210072838, 'words': 74.67, 'bert_score': 85.03679966926575, 'bert_score_deberta-large': 61.02019691467285, 'sari': 44.80799764604085} 134 | {'rouge1': 41.26608906368885, 'rouge2': 14.527248259553481, 'rouge3': 6.298794377100548, 'rouge4': 2.8508339445193314, 'rougeL': 26.633192382147357, 'words': 77.67, 'bert_score': 86.83461207151413, 'bert_score_deberta-large': 62.4643184542656, 'sari': 45.28758563082025} 135 | ``` 136 | 137 | 138 | #### 4000_600_chars 139 | 140 | * Best model uses only _steps=20_ 141 | 142 | Evaluation on 4000_600_chars test set: 143 | ``` 144 | # With data_files["test"] = test_4000_600_chars_last_100.json 145 | 146 | python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-70b-hf --evaluation --evaluation_model_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/models/Llama-2-70b-hf/mimic-iv-note-di-bhc_Llama-2-70b-hf_4000_600_chars_100_valid_done/lora_rank_32_lora_alpha_32_lora_dropout_0.1_num_target_modules_2_learning_rate_2e-4/best_val_loss --data_path 
/home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset --output_path /home/s/s_hegs02/scratch/debug --num_train_examples 100 --num_val_examples 100 --num_test_examples 100 147 | 148 | {'rouge1': 42.1991438443541, 'rouge2': 13.563337998083213, 'rouge3': 5.7237180188916845, 'rouge4': 2.6722184607364134, 'rougeL': 24.777895343443596, 'words': 111.31, 'bert_score': 87.0364066362381, 'bert_score_deberta-large': 61.91603672504426, 'sari': 44.16604762927569} 149 | {'rouge1': 41.67500471009928, 'rouge2': 13.840969941914366, 'rouge3': 5.947293259680258, 'rouge4': 2.844910436528473, 'rougeL': 24.836899438706126, 'words': 120.9, 'bert_score': 85.90889036655426, 'bert_score_deberta-large': 60.57257956266403, 'sari': 44.07321356056379} 150 | {'rouge1': 41.88252074463912, 'rouge2': 14.074882768448884, 'rouge3': 6.039503368150073, 'rouge4': 2.8148856591135196, 'rougeL': 25.643550342550594, 'words': 112.34, 'bert_score': 87.158194065094, 'bert_score_deberta-large': 62.24099487066269, 'sari': 44.35337536946284} 151 | {'rouge1': 41.83728919015555, 'rouge2': 13.09889608467368, 'rouge3': 5.487954667140488, 'rouge4': 2.4686725964920826, 'rougeL': 24.435506275256007, 'words': 112.07, 'bert_score': 86.0190686583519, 'bert_score_deberta-large': 61.00947117805481, 'sari': 43.20317982433489} 152 | {'rouge1': 41.5120357352563, 'rouge2': 13.574013168402391, 'rouge3': 5.627399063893296, 'rouge4': 2.4955528332014443, 'rougeL': 24.480105189357122, 'words': 113.79, 'bert_score': 86.04067796468735, 'bert_score_deberta-large': 60.975450932979584, 'sari': 43.51100259877742} 153 | ``` 154 | 155 | 156 | #### Train on hallucination dataset 157 | 158 | Important Parameters: 159 | * Set max_steps to 20 (best in parameter tuning) 160 | * Set gradient_accumulation_steps to 16 161 | 162 | Set training data (`data_files["train"]`) depending on setting: 163 | * Original: `hallucination_summaries_original_test_100.json` 164 | * Cleaned: `hallucination_summaries_cleaned_test_100.json` 165 | * Cleaned and Improved: `hallucination_summaries_cleaned_improved_test_100.json` 166 | 167 | Careful: Set `output_path` accordingly for each setting! 168 | 169 | ``` 170 | CUDA_VISIBLE_DEVICES=0,1 python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-70b-hf --data_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset --output_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/models/Llama-2-70b-hf/mimic-iv-note-di-bhc_Llama-2-70b-hf_4000_600_chars_100_hallucinations/original --device cuda --max_steps 20 --save_and_logging_steps 20 --batch_size 1 --gradient_accumulation_steps 16 --lora_rank 32 --lora_alpha 32 --lora_dropout 0.1 --num_target_modules 2 --learning_rate 2e-4 --num_train_examples 100 --num_val_examples 100 --num_test_examples 100 171 | ``` 172 | 173 | Prediction on hallucination dataset: 174 | 175 | Careful: Set `data_files["train"] = test_4000_600_chars_last_50.json` for each setting! 
176 | 177 | ``` 178 | CUDA_VISIBLE_DEVICES=0 python summarization/fine_tune_llama.py --model_name_or_path meta-llama/Llama-2-70b-hf --evaluation --evaluation_model_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/models/Llama-2-70b-hf/mimic-iv-note-di-bhc_Llama-2-70b-hf_4000_600_chars_100_hallucinations/original/checkpoint-20 --data_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/dataset --output_path /home/s/s_hegs02/scratch/mimic-iv-note-di-bhc/models/Llama-2-70b-hf/mimic-iv-note-di-bhc_Llama-2-70b-hf_4000_600_chars_100_hallucinations/original_50_test --num_train_examples 100 --num_val_examples 100 --num_test_examples 50 179 | ``` -------------------------------------------------------------------------------- /notebooks/build_annotation_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Build datasets for Hallucination Evaluation and Qualitative Evaluation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import json\n", 17 | "import random\n", 18 | "import os\n", 19 | "import shutil\n", 20 | "import datetime\n", 21 | "import zipfile\n", 22 | "import re" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Read files\n", 32 | "# Read jsonl files\n", 33 | "def read_jsonl(file_name):\n", 34 | " with open(file_name, \"r\") as f:\n", 35 | " return [json.loads(line) for line in f]\n", 36 | " \n", 37 | "def read_txt(file_name):\n", 38 | " list = []\n", 39 | " with open(file_name, \"r\") as f:\n", 40 | " list = f.read().splitlines()\n", 41 | " return [{\"summary\": s} for s in list]\n", 42 | "\n", 43 | "def format_gpt4(gpt4_examples):\n", 44 | " return [{\"text\": e[\"question\"], \"summary\": e[\"summary\"]} for e in gpt4_examples]\n", 45 | "\n", 46 | "# Hallucination evaluation\n", 47 | "llama_70b_original = read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/hallucination_evaluation/llama_70b_original_predictions_test_dict.jsonl\")\n", 48 | "llama_70b_cleaned = read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/hallucination_evaluation/llama_70b_cleaned_predictions_test_dict.jsonl\")\n", 49 | "gpt4_zero_shot = read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/hallucination_evaluation/gpt-4_exp4_results_prompt3.1_0shot.jsonl\")\n", 50 | "gpt4_orig = format_gpt4(read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/hallucination_evaluation/gpt-4_exp4_results_prompt3_5shot.jsonl\"))\n", 51 | "gpt4_cleaned = format_gpt4(read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/hallucination_evaluation/gpt-4_exp5_results_prompt3_5shot.jsonl\"))\n", 52 | "hallucination_models = [llama_70b_original, llama_70b_cleaned, gpt4_zero_shot, gpt4_orig, gpt4_cleaned]\n", 53 | "num_hallucination_models = len(hallucination_models)\n", 54 | "\n", 55 | "# Qualitative evaluation\n", 56 | "original_examples = read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/qualitative_evaluation/orig_test_4000_600_chars_last_50.json\")\n", 57 | "led_large_original = read_txt(\"/home/s_hegs02/patient_summaries_with_llms/data/qualitative_evaluation/led-large_predictions_test_dict.txt\")\n", 58 | "llama_70b_cleaned_improved = 
read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/qualitative_evaluation/llama_70b_cleaned_improved_predictions_test_dict.jsonl\")\n", 59 | "gpt4_zero_shot = read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/qualitative_evaluation/gpt-4_exp4_results_prompt3.1_0shot.jsonl\")\n", 60 | "gpt4_cleaned_improved = format_gpt4(read_jsonl(\"/home/s_hegs02/patient_summaries_with_llms/data/qualitative_evaluation/gpt-4_exp6_results_prompt3_5shot.jsonl\"))\n", 61 | "qualitative_models = [original_examples, led_large_original, llama_70b_cleaned_improved, gpt4_zero_shot, gpt4_cleaned_improved]\n", 62 | "num_qualitative_models = len(qualitative_models)\n", 63 | "\n", 64 | "\n", 65 | "num_examples = len(original_examples)\n", 66 | "# Assert each model has the same number of examples\n", 67 | "assert all(len(model) == num_examples for model in qualitative_models)\n", 68 | "assert all(len(model) == num_examples for model in hallucination_models)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# Debug print summaries with cetain id\n", 78 | "# Print summary at position 33 and 34 for all hallucination models\n", 79 | "# Was done to check if hallucination ratios after permutation - labelingreverse permutation are correct\n", 80 | "\n", 81 | "# for model in hallucination_models:\n", 82 | "# print(model[33][\"summary\"])\n", 83 | "# print()\n", 84 | "# for model in hallucination_models:\n", 85 | "# print(model[34][\"summary\"])" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Hallucination models: \n", 98 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 2}\n", 99 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 9}\n", 100 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 0}\n", 101 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 0}\n", 102 | "{'summary prefixes': 3, 'newlines': 3, 'repeated_spaces': 0}\n", 103 | "Qualitative models: \n", 104 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 0}\n", 105 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 0}\n", 106 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 6}\n", 107 | "{'summary prefixes': 0, 'newlines': 0, 'repeated_spaces': 0}\n", 108 | "{'summary prefixes': 3, 'newlines': 3, 'repeated_spaces': 0}\n", 109 | "\n", 110 | "Cleaning summaries...\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "# Count the number of newlines and repeated spaces in each summary\n", 116 | "def count_newlines_and_repeated_spaces(examples):\n", 117 | " total_summary_prefixes = 0\n", 118 | " total_newlines = 0\n", 119 | " total_repeated_spaces = 0\n", 120 | " for e in examples:\n", 121 | " summary = e[\"summary\"]\n", 122 | " # Count SUMMMARY: at beginning of summary\n", 123 | " total_summary_prefixes += summary.lower().startswith(\"summary:\\n\")\n", 124 | " total_newlines += summary.count(\"\\n\")\n", 125 | " total_repeated_spaces += summary.count(\" \")\n", 126 | " return {\"summary prefixes\": total_summary_prefixes, \"newlines\": total_newlines, \"repeated_spaces\": total_repeated_spaces}\n", 127 | "\n", 128 | "print(\"Hallucination models: \")\n", 129 | "print('\\n'.join([str(count_newlines_and_repeated_spaces(model)) for model in hallucination_models]))\n", 130 | "print(\"Qualitative models: \")\n", 131 | 
"print('\\n'.join([str(count_newlines_and_repeated_spaces(model)) for model in qualitative_models]))\n", 132 | "\n", 133 | "# Clean all summaries from newlines and repeated spaces, change to single spaces\n", 134 | "def clean_summaries(examples):\n", 135 | " for e in examples:\n", 136 | " # Some gpt-4 examples start with SUMMARY:, to prevent identifying the model, remove it\n", 137 | " if e[\"summary\"].lower().startswith(\"summary:\\n\"):\n", 138 | " e[\"summary\"] = e[\"summary\"][9:]\n", 139 | " # Remove newlines and repeated spaces\n", 140 | " e[\"summary\"] = \" \".join(e[\"summary\"].split())\n", 141 | " \n", 142 | " return examples\n", 143 | "\n", 144 | "print(\"\\nCleaning summaries...\")\n", 145 | "hallucination_models = [clean_summaries(model) for model in hallucination_models]\n", 146 | "qualitative_models = [clean_summaries(model) for model in qualitative_models]" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 4, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "# Debugging\n", 156 | "num_examples = 50" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 5, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "hallucination_random_models: {0: [0, 4, 3, 1, 2], 1: [4, 1, 3, 2, 0], 2: [0, 2, 1, 3, 4], 3: [1, 0, 3, 4, 2], 4: [3, 2, 4, 1, 0], 5: [3, 2, 1, 4, 0], 6: [1, 2, 3, 0, 4], 7: [1, 3, 2, 4, 0], 8: [4, 0, 1, 3, 2], 9: [0, 3, 2, 4, 1], 10: [0, 4, 3, 2, 1], 11: [1, 0, 4, 3, 2], 12: [2, 4, 1, 0, 3], 13: [3, 1, 0, 4, 2], 14: [4, 2, 0, 1, 3], 15: [0, 2, 4, 3, 1], 16: [1, 4, 2, 3, 0], 17: [2, 3, 1, 0, 4], 18: [4, 0, 3, 2, 1], 19: [0, 3, 1, 2, 4], 20: [4, 0, 2, 3, 1], 21: [0, 4, 2, 1, 3], 22: [0, 2, 4, 3, 1], 23: [1, 0, 3, 4, 2], 24: [3, 1, 0, 4, 2], 25: [2, 0, 3, 4, 1], 26: [4, 3, 0, 1, 2], 27: [3, 4, 2, 1, 0], 28: [4, 2, 3, 1, 0], 29: [4, 1, 3, 0, 2], 30: [2, 3, 0, 1, 4], 31: [4, 2, 0, 3, 1], 32: [3, 0, 2, 1, 4], 33: [2, 3, 4, 1, 0], 34: [4, 1, 3, 2, 0], 35: [0, 4, 1, 3, 2], 36: [4, 1, 3, 0, 2], 37: [3, 1, 0, 4, 2], 38: [3, 2, 4, 1, 0], 39: [1, 0, 3, 4, 2], 40: [4, 3, 0, 1, 2], 41: [2, 3, 4, 0, 1], 42: [2, 4, 3, 1, 0], 43: [4, 1, 2, 0, 3], 44: [0, 4, 3, 1, 2], 45: [3, 2, 0, 1, 4], 46: [2, 4, 0, 3, 1], 47: [2, 1, 0, 4, 3], 48: [4, 2, 3, 1, 0], 49: [3, 1, 4, 2, 0]}\n", 169 | "qualitative_random_models: {0: [2, 4, 3, 0, 1], 1: [4, 3, 2, 1, 0], 2: [3, 1, 2, 0, 4], 3: [1, 2, 3, 0, 4], 4: [3, 1, 4, 2, 0], 5: [3, 2, 4, 1, 0], 6: [3, 2, 1, 4, 0], 7: [2, 0, 3, 4, 1], 8: [4, 1, 3, 0, 2], 9: [2, 4, 0, 3, 1], 10: [0, 2, 1, 4, 3], 11: [1, 0, 3, 4, 2], 12: [3, 0, 1, 2, 4], 13: [0, 3, 4, 1, 2], 14: [2, 3, 4, 0, 1], 15: [1, 0, 4, 3, 2], 16: [2, 0, 4, 3, 1], 17: [0, 3, 1, 4, 2], 18: [0, 1, 4, 2, 3], 19: [1, 4, 2, 3, 0], 20: [4, 1, 0, 3, 2], 21: [2, 4, 0, 1, 3], 22: [0, 3, 2, 1, 4], 23: [2, 4, 3, 0, 1], 24: [3, 1, 4, 2, 0], 25: [3, 1, 2, 0, 4], 26: [2, 3, 0, 1, 4], 27: [2, 3, 0, 4, 1], 28: [4, 0, 3, 1, 2], 29: [0, 1, 3, 2, 4], 30: [1, 3, 0, 4, 2], 31: [1, 3, 4, 0, 2], 32: [3, 0, 1, 2, 4], 33: [4, 2, 3, 0, 1], 34: [2, 3, 4, 1, 0], 35: [1, 0, 4, 3, 2], 36: [4, 3, 1, 2, 0], 37: [3, 0, 2, 1, 4], 38: [0, 2, 4, 1, 3], 39: [3, 2, 0, 4, 1], 40: [4, 3, 2, 1, 0], 41: [4, 2, 0, 1, 3], 42: [4, 0, 2, 3, 1], 43: [0, 1, 3, 4, 2], 44: [1, 2, 4, 0, 3], 45: [2, 3, 0, 1, 4], 46: [3, 1, 2, 4, 0], 47: [4, 3, 1, 2, 0], 48: [4, 0, 1, 2, 3], 49: [3, 0, 4, 1, 2]}\n" 170 | ] 171 | } 172 | ], 173 | "source": [ 174 | "# Add randomness\n", 175 | "# Set reproducible seed\n", 176 | 
"random.seed(2)\n", 177 | "\n", 178 | "def get_random_permutation(max_num):\n", 179 | " return random.sample(range(max_num), max_num)\n", 180 | " # Debug\n", 181 | " # return list(range(max_num))\n", 182 | "\n", 183 | "hallucination_random_models = {}\n", 184 | "qualitative_random_models = {}\n", 185 | "\n", 186 | "for id in range(num_examples):\n", 187 | " hallucination_random_models[id] = get_random_permutation(num_hallucination_models)\n", 188 | " qualitative_random_models[id] = get_random_permutation(num_qualitative_models)\n", 189 | "\n", 190 | "print(\"hallucination_random_models:\", hallucination_random_models)\n", 191 | "print(\"qualitative_random_models:\", qualitative_random_models)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 6, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# Get hallucination examples\n", 201 | "hallucination_summaries = []\n", 202 | "for id in range(num_examples):\n", 203 | " text = original_examples[id][\"text\"]\n", 204 | " summaries = []\n", 205 | " for j in range(num_hallucination_models):\n", 206 | " random_index = hallucination_random_models[id][j]\n", 207 | " summaries.append(hallucination_models[random_index][id][\"summary\"])\n", 208 | " hallucination_summaries.append({\"text\": text, \"summaries\": summaries})\n", 209 | "\n", 210 | "# Get qualitative examples\n", 211 | "qualitative_summaries = []\n", 212 | "for id in range(num_examples):\n", 213 | " text = original_examples[id][\"text\"]\n", 214 | " summaries = []\n", 215 | " for j in range(num_qualitative_models):\n", 216 | " random_index = qualitative_random_models[id][j]\n", 217 | " summaries.append(qualitative_models[random_index][id][\"summary\"])\n", 218 | " qualitative_summaries.append({\"text\": text, \"summaries\": summaries})" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "# Print hallucination examples\n", 228 | "\n", 229 | "for i in range(num_examples):\n", 230 | " print(f\"Text:\\n{hallucination_summaries[i]['text']}\\n\")\n", 231 | " for j in range(num_hallucination_models):\n", 232 | " print(f\"Summary {j}:\\n{hallucination_summaries[i]['summaries'][j]}\\n\")\n", 233 | " print(\"\\n\\n\")" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "# Print qualitative examples\n", 243 | "\n", 244 | "for i in range(num_examples):\n", 245 | " print(f\"Text:\\n{qualitative_summaries[i]['text']}\\n\")\n", 246 | " for j in range(num_qualitative_models):\n", 247 | " print(f\"Summary {j}:\\n{qualitative_summaries[i]['summaries'][j]}\\n\")\n", 248 | " print(\"\\n\\n\")" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 34, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "# Clean results folder and store examples in the folder\n", 258 | "hallucination_folder = \"/home/s_hegs02/patient_summaries_with_llms/data/results_hallucinations\"\n", 259 | "qualitative_folder = \"/home/s_hegs02/patient_summaries_with_llms/data/results_qualitative\"\n", 260 | "\n", 261 | "# Clean results folder\n", 262 | "def clean_results_folder(folder):\n", 263 | " if os.path.exists(folder):\n", 264 | " shutil.rmtree(folder)\n", 265 | " os.makedirs(folder)\n", 266 | " \n", 267 | "clean_results_folder(hallucination_folder)\n", 268 | "clean_results_folder(qualitative_folder)\n", 269 | "\n", 270 | "# Store hallucination examples\n", 
271 | "for i in range(num_examples):\n", 272 | " hallucination_file = os.path.join(hallucination_folder, f\"{i}_hallucination.txt\")\n", 273 | " with open(hallucination_file, \"w\") as f:\n", 274 | " f.write(f\"Text:\\n{hallucination_summaries[i]['text']}\\n\\n\")\n", 275 | " for j in range(num_hallucination_models):\n", 276 | " f.write(f\"Summary {j}:\\n{hallucination_summaries[i]['summaries'][j]}\\n\\n\")\n", 277 | " f.write(\"\\n\\n\")\n", 278 | " \n", 279 | "# Store qualitative examples\n", 280 | "for i in range(num_examples):\n", 281 | " qualitative_file = os.path.join(qualitative_folder, f\"{i}_qualitative.txt\")\n", 282 | " with open(qualitative_file, \"w\") as f:\n", 283 | " f.write(f\"Text:\\n{qualitative_summaries[i]['text']}\\n\\n\")\n", 284 | " for j in range(num_qualitative_models):\n", 285 | " f.write(f\"Summary {j}:\\n{qualitative_summaries[i]['summaries'][j]}\\n\\n\")\n", 286 | " f.write(\"\\n\\n\")\n", 287 | "\n", 288 | "def zipdir(path, ziph):\n", 289 | " # ziph is zipfile handle\n", 290 | " for root, dirs, files in os.walk(path):\n", 291 | " for file in files:\n", 292 | " # ziph.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(path, '..')))\n", 293 | " ziph.write(os.path.join(root, file), file)\n", 294 | " \n", 295 | "zip_hallucination_file_name = f\"hallucinations_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.zip\"\n", 296 | "zip_qualitative_file_name = f\"qualitative_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.zip\"\n", 297 | "\n", 298 | "zipf_hallucination = zipfile.ZipFile(zip_hallucination_file_name, 'w', zipfile.ZIP_DEFLATED)\n", 299 | "zipdir(hallucination_folder, zipf_hallucination)\n", 300 | "zipf_hallucination.close()\n", 301 | "zipf_qualitative = zipfile.ZipFile(zip_qualitative_file_name, 'w', zipfile.ZIP_DEFLATED)\n", 302 | "zipdir(qualitative_folder, zipf_qualitative)\n", 303 | "zipf_qualitative.close()" 304 | ] 305 | } 306 | ], 307 | "metadata": { 308 | "kernelspec": { 309 | "display_name": "avs_gen", 310 | "language": "python", 311 | "name": "python3" 312 | }, 313 | "language_info": { 314 | "codemirror_mode": { 315 | "name": "ipython", 316 | "version": 3 317 | }, 318 | "file_extension": ".py", 319 | "mimetype": "text/x-python", 320 | "name": "python", 321 | "nbconvert_exporter": "python", 322 | "pygments_lexer": "ipython3", 323 | "version": "3.9.18" 324 | } 325 | }, 326 | "nbformat": 4, 327 | "nbformat_minor": 2 328 | } 329 | -------------------------------------------------------------------------------- /preprocess/regular_expressions.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | from src.preprocess.constants import * 4 | 5 | # Globally used regexes 6 | re_whitespace = re.compile(r'\s+', re.MULTILINE) 7 | re_multiple_whitespace = re.compile(r' +', re.MULTILINE) 8 | re_paragraph = re.compile(r'\n{2,}', re.MULTILINE) 9 | re_line_punctuation = re.compile(r'^(?:\.|!|\"|#|\$|%|&|\'|\(|\)|\*|\+|,|\/|:|;|<|=|>|\?|@|\[|\\|\]|\^|_|`|\{|\||\}|\||~|»|«|“|”|-|_)+$', re.MULTILINE) 10 | re_line_punctuation_wo_fs = re.compile(r'^(?:!|\"|#|\$|%|&|\'|\(|\)|\*|\+|,|\/|:|;|<|=|>|\?|@|\[|\\|\]|\^|_|`|\{|\||\}|\||~|»|«|“|”|-|_)+$', re.MULTILINE) 11 | re_line_punctuation_wo_underscore = re.compile(r'^(?:\.|!|\"|#|\$|%|&|\'|\(|\)|\*|\+|,|\/|:|;|<|=|>|\?|@|\[|\\|\]|\^|`|\{|\||\}|\||~|»|«|“|”|-)+$', re.MULTILINE) 12 | re_ds_punctuation_wo_underscore = 
re.compile(r'^(?:\.|!|\"|#|\$|%|&|\'|\(|\)|\*|\+|,|\/|:|;|<|=|>|\?|@|\[|\\|\]|\^|`|\{|\||\}|\||~|»|«|“|”|-)+') 13 | re_fullstop = re.compile(r'^(?:\.)+$', re.MULTILINE) 14 | re_newline_in_text = re.compile(r'(?<=\w)\n(?=\w)', re.MULTILINE) 15 | re_incomplete_sentence_at_end = re.compile(r'(?<=\.)[^\.]+$', re.DOTALL) 16 | re_item_element = r'(?:-|\. |\*|•|\d+ |\d+\.|\d\)|\(\d+\)|\d\)\.|o |# )' 17 | re_heading_general = r'[^\.\:\n]*(?::\n{1,2}|\?\n{1,2}|[^,]\n)' 18 | 19 | # Itemize elements identified in MIMIC notes 20 | ITEMIZE_ELEMENTS = [r'-', r'\. ', r'\*', r'•', r'\d+ ', r'\d+\.', r'\d\)', r'\(\d+\)', r'\d\)\.', r'o ', r'# '] 21 | 22 | # Determined common patterns at the beginning of summaries that contain no information and should be removed. 23 | # Manually inspected 1000 summaries and found the following patterns: 24 | UNNECESSARY_SUMMARY_PREFIXES = { 25 | 'template separator': re.compile(r'^={5,40}', re.MULTILINE), 26 | 'template heading': re.compile(r'\A(?:Patient |CCU )?Discharge (?:Instructions|Worksheet):?\s*', re.IGNORECASE|re.DOTALL), 27 | 'salutations': re.compile(r'\A(?:___,|(?:Dear|Hello|Hi|Ms|Mrs|Miss|Mr|Dr)(?: Ms| Mrs| Miss| Mr| Dr)?\.{0,1} (?:___)?(?: and family| family)?(?:,|\.|:|;| ){0,3}|)\s*', re.IGNORECASE), 28 | # allow up to one sentence (.,!,:,;) before thank you and two before pleasure (more specific) and remove until end of sentence. 29 | 'thank you': re.compile(r'\A(?:[^\.!:;]*\.){0,1}[^\.!:;]*thank you[^\.!:;]*(?:\.|!|:|;)\s*', re.IGNORECASE|re.DOTALL), 30 | 'pleasure': re.compile(r'\A(?:[^\.!:;]*\.){0,2}[^\.!:;]*(?:pleasure|priviledge|privilege)[^\.!:;]*(?:\.|!|:|;)\s*', re.IGNORECASE|re.DOTALL), 31 | } 32 | 33 | # focus on occurrences with typical headings followed by a dashed list, leave the others as they are 34 | WHY_WHAT_NEXT_HEADINGS = (r'^-{0,4}[^\S\r\n]{0,4}_{0,4}[^\S\r\n]{0,4}' # optional start via -- ___ ...
35 | r'(?:why .* admitted|why .* hospital|what brought .* hospital|why .* here|where .* hospital|why .* hospitalized|' # first section 36 | r'what was done|what .* hospital|was I .* hospital|when you .* hospital|what .* here|what .* admitted|what .* for you|what .* hospitalization|what .* stay|what happened .* ___|while .* here|' # second section 37 | r'what should .* next|what should .* hospital|what .* for me|when .* leave|what .* leave|what .* home|when .* home|what should .* leaving|what .* to do|' # third section 38 | r'when .* hospital|when .* come back|what .* discharge|what .* discharged)' 39 | r'(?:\?|:|\?:)?\n{1,2}') 40 | # add dash to ensure that item follows 41 | WHY_WHAT_NEXT_HEADINGS_DASHED_LIST = re.compile(WHY_WHAT_NEXT_HEADINGS + r'-', re.MULTILINE|re.IGNORECASE) 42 | 43 | # These are some common suffixes of 'you' referring to the patient and allowing to replace likely anonymization of '___' with 'you' 44 | YOU_SUFFIXES = ['were admitted', 'were here', 'were followed', 'were started', 'were found', 'were maintained', 'were able', 'were seen', 45 | 'were treated', 'were given', 'were told', 'were advised', 'were asked', 'were instructed', 'were recommended', 46 | 'were initially evaluated', 'were hospitalized', 'were complaining', 'were discharged', 'were also', 'were at', 47 | 'will not need', 'will need to follow', 'will need to', 'will start this', 48 | 'should hear', 'should follow', 49 | 'have recovered', 'have healed', 50 | 'are now ready', 'unfortunately developed', 'had chest pain', 'suffered', 'hit your head', 'vomited', 'can expect to see', 'tolerated the procedure'] 51 | SIMPLE_DEIDENTIFICATION_PATTERNS = [ 52 | ('You ', re.compile(r'(?:^|\. )___ (?=' + '|'.join(YOU_SUFFIXES) + r')', re.MULTILINE|re.IGNORECASE)), 53 | (' you ', re.compile(r'(?!' + ENCODE_STRINGS_DURING_PREPROCESSING['Dr.'] + ') ___ (?=' + '|'.join(YOU_SUFFIXES) + r')', re.MULTILINE|re.IGNORECASE)), # Prevent Dr. ___ from being replaced 54 | (' you', re.compile(r'(?:(?<=giving)|(?<=giving thank)|(?<=giving we wish)|(?<=giving scheduled)|(?<=giving will call)|(?<=we assessed)) ___', re.MULTILINE|re.IGNORECASE)), 55 | (' your ', re.compile(r' ___ (?=discharge|admission)', re.MULTILINE|re.IGNORECASE)), 56 | (' your ', re.compile(r'(?=directs all the other parts of|the brain is the part of|see occasional blood in) ___ ', re.MULTILINE|re.IGNORECASE)), # from neurology stroke / urology template 57 | ] 58 | 59 | 60 | # Identified patterns that mark the end of a summary with no relevant information following it. 61 | def create_heading_rs(heading): 62 | # Either delimited by : or starting and ending with newline 63 | return [heading + r':', r'(?:^|\n)' + heading + '\n'] 64 | SUFFIXES_DICT = { 65 | # follow-up 66 | "followup headings": create_heading_rs(r'follow(?:-| ||)(?:up)? instructions'), 67 | "followup sentences": [r'(?:you should|you have|you will|please)[^\.]{0,50} follow(?:-| ||)(?:up)?', 68 | r'(?:call|see|visit|attend)[^\.]{0,200} follow(?:-| ||)(?:up)?', 69 | r'follow(?:-| ||)(?:up)? 
with[^\.]{0,50} (?:primary care|pcp|doctor|neurologist|cardiologist)', 70 | r'you will [^\.]{10,50} (?:primary care|pcp|doctor|neurologist|cardiologist)', 71 | r'The number for [^\.]{10,200} is listed below'], 72 | # discharge 73 | "discharge headings": create_heading_rs(r'discharge instructions') + create_heading_rs(r'[^\.]{0,200} surgery discharge instructions'), 74 | "discharge sentences": [r'Please follow [^\.]{0,30}discharge instructions', 75 | r'(?:cleared|ready for)[^\.]{0,50} discharge', 76 | r'(?:are|were|being|will be) (?:discharge|sending)[^\.]{0,200} (?:home|rehab|facility|assisted|house)', 77 | r'(?:note|take)[^\.]{0,100} discharge (?:instruction|paperwork)', 78 | r'Below are your discharge instructions regarding'], 79 | # farewell 80 | "farewell pleasure": [r'It [^\.]{3,20} pleasure', r'was a pleasure'], 81 | "farewell priviledge": [r'It [^\.]{3,20} priviled?ge'], 82 | "farewell wish you": [r'wish(?:ing)? you', r'Best wishes', r'wish(?:ing)? [^\.]{0,20} luck'], 83 | "farewell general": [r'Sincerely', r'Warm regards', r'Thank you', r'Your[^\.]{0,10} care team', r'Your[^\.]{0,10} (?:doctor|PCP)'], 84 | 85 | # activity 86 | "activity headings": create_heading_rs(r'Activity') + create_heading_rs(r'Activity and [^\.]{4,20}'), 87 | # ama 88 | "ama sentences": [r'You [^\.]{0,60}decided to leave the hospital'], 89 | # appointments 90 | "appointments sentences": [r'(?:keep|follow|attend|go to|continue)[^\.]{1,100} (?:appointment|follow(?:-| ||)up)', 91 | r'(?:appointment|follow(?:-| ||)up)[^\.]{1,100} (?:arranged|scheduled|made)', 92 | r'(?:contact|call|in touch)[^\.]{1,100} (?:appointment|follow(?:-| ||)up)', 93 | r'have[^\.]{0,100} (?:appointments?|follow(?:-| ||)up) with ', 94 | r'see[^\.]{0,100} (?:appointments?|follow(?:-| ||)up) below', 95 | r'provide[^\.]{0,100} phone number', 96 | r'getting an appointment for you'], 97 | # case manager 98 | "case manager sentences": [r'contact[^\.]{0,100} case manager', 99 | r'case manager[^\.]{0,100} (?:contact|call|in touch|give|arrange|schedule|make)'], 100 | # diet 101 | "diet headings": create_heading_rs(r'Diet') + create_heading_rs(r'Diet and [^\.]{4,20}'), 102 | # forward info 103 | "forward info sentences": [r'forward[^\.]{0,100} (?:information|info|paper(?: |-||)work)'], 104 | # instructions 105 | "instructions sentences": [r'Please (?:review|follow|check)[^\.]{1,100} instructions?', 106 | r'should discuss this further with '], 107 | # medication 108 | "medication headings": create_heading_rs(r'(?:medications?|medicines?|antibiotics?|pills?)') +\ 109 | create_heading_rs(r'(?:medications?|medicines?|antibiotics?|pills?) 
?(?:changes|list|as follows|on discharge|for [^\.]{0,80})') +\ 110 | create_heading_rs(r'(?:take|administer|give|prescribe|order|direct|start|continue)[^\.]{0,100} doses') +\ 111 | create_heading_rs(r'schedule for[^\.]{0,100}') +\ 112 | [r'(?:take|administer|give|prescribe|order|direct|start|continue)[^\.]{0,50} (?:medications?|medicines?|antibiotics?|pills?)[^\.]{0,100} (?:prescribe|list|as follows)'], 113 | "medication sentences": [r'(?:following|not make|not make any|not make a|no) change[^\.]{0,100} (?:medications?|medicines?|antibiotics?|pills?)', 114 | r'(?:medications?|medicines?|antibiotics?|pills?)[^\.]{0,100} (?:prescribed|directed|ordered|listed below|change)', 115 | r'(?:continue|resume|take)[^\.]{0,100} (?:all|other|your)[^\.]{0,50} (?:medications?|medicines?|antibiotics?|pills?)', 116 | r'see[^\.]{0,100} list[^\.]{0,100} (?:medications?|medicines?|antibiotics?|pills?)', 117 | r'were given[^\.]{0,50} (?:presecription|prescription)'], 118 | "medication items": [r'^(?:please)? (?:start|stop|continue) take'], 119 | # questions 120 | "questions sentences": [r'call [^\.]{1,200} (?:questions|question|concerns|concern|before leave)', 121 | r'If [^\.]{1,200} (?:questions|question|concerns|concern)', 122 | r'Please do not hesitate to contact us'], 123 | # home 124 | "home sentences": [r'(?:ready|when|safe)[^\.]{0,30} home'], 125 | # surgery or procedure 126 | "surgery procedure headings": create_heading_rs(r'Surgery[^\.]{0,10}Procedure') + create_heading_rs(r'Surgery') + create_heading_rs(r'Procedure') + 127 | create_heading_rs(r'Your Surgery') + create_heading_rs(r'Your Procedure') + 128 | create_heading_rs(r'Recent Surgery') + create_heading_rs(r'Recent Procedure'), 129 | # warning signs 130 | "warning signs sentences": [r'please seek medical (?:care|attention)', 131 | r'to[^\.]{0,100} (?:ED(?:\.|,|;| )|ER(?:\.|,|;| )|Emergency Department|Emergency Room)', # Required seperation after ED/ER 132 | r'(?:call|contact|experience|develop) [^\.]{1,200} following', 133 | r'(?:call|contact)[^\.]{0,100} (?:develop|experience|concerning symptom|if weight|weight goes|doctor|physician|surgeon|provider|nurse|clinic|office|neurologist|cardiologist|hospital)', 134 | r'Please (?:call|contact|seek)[^\.]{0,200} if', 135 | r'If[^\.]{0,100} (?:develop|experience|concerning symptoms|worse)'], 136 | # wound care 137 | "wound care headings": create_heading_rs(r'Wound Care') + create_heading_rs(r'Wound Care Instructions?'), 138 | "wound care sentences": [r'GENERAL INSTRUCTIONS WOUND CARE You or a family member should inspect', 139 | r'GENERAL INSTRUCTIONS WOUND CARE\nYou or a family member should inspect', 140 | r'Please shower daily including washing incisions gently with mild soap[^\.]{0,10} no baths or swimming[^\.]{0,10} and look at your incisions', 141 | r'wash incisions gently with mild soap[^\.]{0,10} no baths or swimming[^\.]{0,10} look at your incisions daily', 142 | r'Do not smoke\. 
No pulling up, lifting more than 10 lbs\., or excessive bending or twisting\.', 143 | r'Have a friend/family member check your incision daily for signs of infection'], 144 | 145 | # general instructions 146 | "other headings": list(itertools.chain(*[create_heading_rs(h) for h in [r'Anticoagulation', r'Pain control', r'Prevena dressing instructions', 147 | r'your bowels', r'Dressings', r'Pain management', r'Incision care', 148 | r'What to expect', r'orthopaedic surgery', r'Physical Therapy', r'Treatment Frequency', 149 | r'IMPORTANT PATIENT DETAILS', r'IMPORTANT PATIENT DETAILS 1\.']])) +\ 150 | create_heading_rs(r'Please see below[^\.]{1,50} hospitalization') +\ 151 | create_heading_rs(r'[^\.]{0,50} in the hospital we') +\ 152 | [r'CRITICAL THAT YOU QUIT SMOKING'], 153 | 154 | # Added after t-sne analysis: Picked outlying clusters with at least 5 examples and 80% rouge4 performance 155 | # For each template selected 5 examples with very high rouge4 scores 156 | # Stroke template 157 | "stroke template sentences": [r'a condition (?:where|in which) a blood vessel providing oxygen and nutrients to the brain (?:is blocked|bleed)', 158 | r'The brain is the part of your body that controls? and directs all the other parts of your body', 159 | r'damage to the brain[^\.]{0,200} can result in a variety of symptoms', 160 | r'can have many different causes, so we assessed you for medical conditions', 161 | r'In order to prevent future strokes,? we plan to modify those risk factors'], 162 | "stone template sentences": [r'You can expect to see occasional blood in your urine and to possibly experience some urgency and frequency', 163 | r'You can expect to see blood in your urine for at least 1 week and to experience some pain with urination, urgency and frequency', 164 | r'The kidney stone may or may not [^\.]{0,30} AND\/or there may fragments\/others still in the process of passing', 165 | r'You may experiences? some pain associated with spasm? of your ureter'], 166 | "aortic graft template sentences": [r'You tolerated the procedure well and are now ready to be discharged from the hospital', 167 | r'Please follow the recommendations below to ensure a speedy and uneventful recovery', 168 | r'Division of Vascular and Endovascular Surgery[^\.]{0,200}please note'], 169 | "caotic endarterectomy template sentences": [r'You tolerated the procedure well and are now ready to be discharged from the hospital', 170 | r'You are doing well and are now ready to be discharged from the hospital', 171 | r'Please follow the recommendations below to ensure a speedy and uneventful recovery'], 172 | "neck surgery template sentences": [r'Rest is important and will help you feel better\. Walking is also important\. It will help prevent problems'], 173 | "TAVR template sentences": [r'If you stop these medications or miss[^\.]{0,30}, you risk causing a blood clot forming on your new valve', 174 | r'These medications help to prevent blood clots from forming on the new valve'], 175 | "appendicitis template sentences": [r' preparing for discharge home with the following instructions'], 176 | "bowel obstruction template sentences": [r'You may return home to finish your recovery\. 
Please monitor', 177 | r'may or may not have had a bowel movement prior to[^\.]{0,20} discharge which is acceptable[^\.]{0,5} however it is important that[^\.]{0,30} have a bowel movement in'], 178 | "small bowel obstruction template sentences": [r'You have tolerated a regular diet, are passing gas [^\.]{0,30} (?:not taking any pain medications|pain is controlled with pain medications by mouth)\.'], 179 | 180 | # These are the most general regexes, motivated by the fact that usually no meaningful fluent text follows after a list 181 | # Had to split because lookbehind is not compatible with variable length 182 | "general headings": [r'^\w' + re_heading_general + re_item_element, 183 | r'(?<=\. )' + re_heading_general + re_item_element], 184 | # When at least two items with the same delimiter occur, cut from there on 185 | "at least two items": [r'^(?:' + item + r'(?:[^\n]+\n){1,2}\n?){2,}' for item in ITEMIZE_ELEMENTS], 186 | } 187 | 188 | def create_delimiter_regex(delimiter_list): 189 | # Be careful with wildcards because they include newlines (hence no DOTALL flag) 190 | return re.compile('|'.join(delimiter_list), re.IGNORECASE|re.MULTILINE) 191 | RE_SUFFIXES_DICT = {delimiter_name: create_delimiter_regex(delimiter_list) for delimiter_name, delimiter_list in SUFFIXES_DICT.items()} 192 | --------------------------------------------------------------------------------
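To illustrate how the compiled prefix and suffix patterns in `preprocess/regular_expressions.py` might be applied, here is a hedged sketch of a cleaning helper. The function name, the import path, and the truncation strategy are assumptions for illustration only; the repository's actual preprocessing code in `preprocess/` may work differently:

```
# Hypothetical usage sketch - not part of the repository code.
# The import path mirrors the module's own `src.preprocess` imports; adjust to your setup.
from src.preprocess.regular_expressions import RE_SUFFIXES_DICT, UNNECESSARY_SUMMARY_PREFIXES

def clean_summary(summary: str) -> str:
    """Strip known boilerplate prefixes, then cut at the earliest end-of-summary pattern."""
    # Remove uninformative openings (template headings, salutations, thank-you phrases, ...).
    for regex in UNNECESSARY_SUMMARY_PREFIXES.values():
        summary = regex.sub("", summary, count=1)
    # Truncate at the first match of any suffix pattern (follow-up, medications, farewell, ...).
    cut = len(summary)
    for regex in RE_SUFFIXES_DICT.values():
        match = regex.search(summary)
        if match:
            cut = min(cut, match.start())
    return summary[:cut].rstrip()
```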