├── src ├── retriever │ ├── ANCE │ │ ├── embedding │ │ │ └── embedding │ │ ├── output │ │ │ └── ance.best.pt │ │ ├── retrieval │ │ │ └── retrieved_reports.json │ │ ├── train.sh │ │ ├── gen_embedings.sh │ │ ├── utils.py │ │ ├── multi_model.py │ │ ├── data.py │ │ ├── gen_embeddings.py │ │ └── train.py │ ├── DPR │ │ ├── embedding │ │ │ └── embedding │ │ ├── output │ │ │ └── dpr.best.pt │ │ ├── retrieval │ │ │ └── retrieved_reports.json │ │ ├── retrieve.sh │ │ ├── train.sh │ │ ├── gen_hard_negatives.sh │ │ ├── utils.py │ │ ├── gen_embedings.sh │ │ ├── retrieval.py │ │ ├── gen_hard_negatives.py │ │ ├── data.py │ │ ├── multi_model.py │ │ ├── gen_embeddings.py │ │ ├── evaluate_retriever.py │ │ └── train.py │ └── checkpoint │ │ └── model.best.pt ├── evaluation.sh ├── generator │ ├── llava_json_to_evaluation_file.py │ ├── evaluate_llava.sh │ ├── vqa │ │ ├── eval_llava_vqa.sh │ │ ├── inference_llava_vqa.sh │ │ └── train_llava_vqa.sh │ ├── knn_index_to_evaluation_file.py │ ├── knn_ideal.py │ ├── convert_json_or_jsonl.py │ ├── inference_llava.sh │ ├── train_llava.sh │ ├── build_nonrag_dataset.py │ ├── build_rag_dataset.py │ └── knn.py └── evaluation.py ├── assets └── overview.png ├── data ├── chexpert │ └── test.json ├── mimic │ ├── test.json │ ├── train.json │ └── valid.json ├── factual_mining │ ├── build_pos_train │ │ ├── gen_similarity.sh │ │ ├── merge_topk_pos.sh │ │ ├── gen_topk_oracle_train.sh │ │ ├── gen_topk_pos.sh │ │ ├── merge_topk_pos.py │ │ ├── gen_similarity.py │ │ ├── utils.py │ │ └── gen_topk_pos.py │ └── build_pos_valid │ │ ├── gen_similarity.sh │ │ ├── gen_topk_pos.sh │ │ ├── gen_similarity.py │ │ ├── utils.py │ │ └── gen_topk_pos.py ├── parse.py └── label.py ├── install_llava.sh ├── LICENSE ├── requirements.txt ├── .gitignore └── README.md /src/retriever/ANCE/embedding/embedding: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retriever/ANCE/output/ance.best.pt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retriever/DPR/embedding/embedding: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retriever/DPR/output/dpr.best.pt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retriever/checkpoint/model.best.pt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retriever/ANCE/retrieval/retrieved_reports.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retriever/DPR/retrieval/retrieved_reports.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxcscmu/FactMM-RAG/HEAD/assets/overview.png -------------------------------------------------------------------------------- /data/chexpert/test.json: 
-------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "image":["path/to/radiology/frontal.jpg","path/to/radiology/lateral.jpg"], 4 | "finding":"Frontal...", 5 | "impression":"No acute...." 6 | } 7 | ] -------------------------------------------------------------------------------- /data/mimic/test.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "image":["path/to/radiology/frontal.jpg","path/to/radiology/lateral.jpg"], 4 | "finding":"Frontal...", 5 | "impression":"No acute...." 6 | } 7 | ] -------------------------------------------------------------------------------- /data/mimic/train.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "image":["path/to/radiology/frontal.jpg","path/to/radiology/lateral.jpg"], 4 | "finding":"Frontal...", 5 | "impression":"No acute...." 6 | } 7 | ] -------------------------------------------------------------------------------- /data/mimic/valid.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "image":["path/to/radiology/frontal.jpg","path/to/radiology/lateral.jpg"], 4 | "finding":"Frontal...", 5 | "impression":"No acute...." 6 | } 7 | ] -------------------------------------------------------------------------------- /src/evaluation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default values 4 | REF_PATH="/FactMM-RAG/data/mimic/test.json" 5 | PRED_PATH="/FactMM-RAG/data/mimic/test_generated.json" 6 | DEVICE="cuda" 7 | RADGRAPH_LEVEL="partial" 8 | BERT_MODEL="distilbert-base-uncased" 9 | 10 | # Run Python evaluation script 11 | python evaluation.py --ref_path $REF_PATH \ 12 | --pred_path $PRED_PATH \ 13 | --device $DEVICE \ 14 | --radgraph_level $RADGRAPH_LEVEL \ 15 | --bert_model $BERT_MODEL 16 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/gen_similarity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name="gen_sim_train" 3 | #SBATCH -o %x-%a.out 4 | #SBATCH -e %x-%a.err 5 | #SBATCH --partition=general 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=1 9 | #SBATCH --mem=16G 10 | #SBATCH --time=03:00:00 11 | #SBATCH --array=0-63 12 | 13 | python3 gen_similarity.py \ 14 | --train_data_file "/FactMM-RAG/data/mimic/train_labeled.json" \ 15 | --output_folder "/FactMM-RAG/data/mimic/scoring_chunks_train" \ 16 | --num_chunks 64 \ 17 | --chunk_id $SLURM_ARRAY_TASK_ID 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/generator/llava_json_to_evaluation_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--file", type=str) 7 | args = vars(parser.parse_args()) 8 | file = args["file"] 9 | base, ext = os.path.splitext(file) 10 | ext = ext.lstrip(".") 11 | assert ext == "json" 12 | write_file = f"{base}_inference.{ext}" 13 | 14 | with open(file, "r") as f: 15 | obj = json.load(f) 16 | with open(write_file, "w") as f: 17 | out = [ 18 | { 19 | "retrieved_finding": [v["text"]] 20 | } for v in obj 21 | ] 22 | json.dump(out, f) --------------------------------------------------------------------------------
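For reference, here is a minimal sketch (not a file in the repository) of the transformation llava_json_to_evaluation_file.py performs: each merged LLaVA answer record carries a "text" field, and the helper wraps it into a one-element "retrieved_finding" list so that src/evaluation.py can pair the predictions, in order, with the reference findings in data/mimic/test.json. The sample records below are invented for illustration.

import json

llava_answers = [
    {"question_id": 0, "text": "No acute cardiopulmonary process."},
    {"question_id": 1, "text": "Mild cardiomegaly without pulmonary edema."},
]

# Same per-record wrapping that llava_json_to_evaluation_file.py applies.
evaluation_records = [{"retrieved_finding": [a["text"]]} for a in llava_answers]
print(json.dumps(evaluation_records, indent=2))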
/data/factual_mining/build_pos_valid/gen_similarity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name="gen_sim_valid" 3 | #SBATCH -o %x-%a.out 4 | #SBATCH -e %x-%a.err 5 | #SBATCH --partition=general 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=1 9 | #SBATCH --mem=16G 10 | #SBATCH --time=03:00:00 11 | #SBATCH --array=0-3 12 | 13 | python3 gen_similarity.py \ 14 | --query_data_file "FactMM-RAG/data/mimic/valid_labeled.json" \ 15 | --corpus_data_file "FactMM-RAG/data/mimic/train_labeled.json" \ 16 | --output_folder "FactMM-RAG/data/mimic/scoring_chunks_valid" \ 17 | --num_chunks 4 \ 18 | --chunk_id $SLURM_ARRAY_TASK_ID 19 | 20 | 21 | -------------------------------------------------------------------------------- /src/retriever/DPR/retrieve.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set paths 4 | TEST_IMAGE_EMBEDDING_PATH="/FactMM-RAG/DPR/embedding/test_embedding_image.pkl" 5 | TRAIN_JSON_PATH="/FactMM-RAG/data/mimic/train.json" 6 | TRAIN_EMBEDDING_PATH="/FactMM-RAG/DPR/embedding/train_embedding_finding.pkl" 7 | OUTPUT_JSON_PATH="/FactMM-RAG/src/retriever/DPR/retrieval/retrieved_reports.json" 8 | 9 | # Run Python script 10 | python retrieval.py --test_image_embedding_path $TEST_IMAGE_EMBEDDING_PATH \ 11 | --train_json_path $TRAIN_JSON_PATH \ 12 | --train_embedding_path $TRAIN_EMBEDDING_PATH \ 13 | --output_json_path $OUTPUT_JSON_PATH 14 | -------------------------------------------------------------------------------- /install_llava.sh: -------------------------------------------------------------------------------- 1 | cd FactMM-RAG 2 | git clone https://github.com/haotian-liu/LLaVA.git 3 | cd LLaVA 4 | 5 | # Follow install instructions from 6 | # https://github.com/haotian-liu/LLaVA.git 7 | # at the current time, commit hash c121f0432da27facab705978f83c4ada465e46fd 8 | conda create -n llava python=3.10 -y 9 | conda activate llava 10 | pip install --upgrade pip 11 | pip install -e .
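# Optional sanity check (not part of the upstream instructions): confirm the
# editable install is importable before continuing, e.g. `python -c "import llava"`.
# This is the same package the repository's inference scripts later invoke via
# `python -m llava.eval.model_vqa_loader`.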
12 | 13 | pip install -e ".[train]" 14 | pip install flash-attn --no-build-isolation 15 | # If flash-attn gets 404 error, set environment variable export `HUGGINGFACE_CO_TIMEOUT=60` 16 | 17 | 18 | # some additional upgrades 19 | pip install transformers[torch] 20 | pip install peft==0.10.0 transformers==4.36.2 accelerate==0.21.0 tokenizers==0.15.1 -------------------------------------------------------------------------------- /src/generator/evaluate_llava.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default values 4 | REF_PATH="./data/mimic/test.json" 5 | PRED_PATH="./data/rag/llava_output/test/merge_test_eval.json" 6 | PRED_PATH_INF="./data/rag/llava_output/test/merge_test_eval_inference.json" && 7 | DEVICE="cuda" 8 | RADGRAPH_LEVEL="partial" 9 | BERT_MODEL="distilbert-base-uncased" 10 | 11 | # Run Python evaluation script 12 | python src/generator/convert_json_or_jsonl.py --file ${PRED_PATH}l --overwrite && 13 | python src/generator/llava_json_to_evaluation_file.py --file ${PRED_PATH} && 14 | python src/evaluation.py --ref_path $REF_PATH \ 15 | --pred_path $PRED_PATH_INF \ 16 | --device $DEVICE \ 17 | --radgraph_level $RADGRAPH_LEVEL \ 18 | --bert_model $BERT_MODEL 19 | -------------------------------------------------------------------------------- /src/generator/vqa/eval_llava_vqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate FactMM-RAG 4 | 5 | # Set default values 6 | REF_PATH="./data/mimic/test.json" && 7 | PRED_PATH="./data/rag/vqa/llava_output/test/merge_test_eval.json" && 8 | PRED_PATH_INF="./data/rag/vqa/llava_output/test/merge_test_eval_inference.json" && 9 | DEVICE="cuda" && 10 | RADGRAPH_LEVEL="partial" && 11 | BERT_MODEL="distilbert-base-uncased" && 12 | python src/generator/convert_json_or_jsonl.py --file ${PRED_PATH}l --overwrite && 13 | python src/generator/llava_json_to_evaluation_file.py --file ${PRED_PATH} && 14 | python src/evaluation.py --ref_path $REF_PATH \ 15 | --pred_path $PRED_PATH_INF \ 16 | --device $DEVICE \ 17 | --radgraph_level $RADGRAPH_LEVEL \ 18 | --bert_model $BERT_MODEL -------------------------------------------------------------------------------- /src/retriever/ANCE/train.sh: -------------------------------------------------------------------------------- 1 | 2 | chex_thresh=1.0 3 | top_k=3 4 | radg_thresh=0.4 5 | 6 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 7 | 8 | python train.py --out_path /FactMM-RAG/src/retriever/ANCE/output/ance.best.pt \ 9 | --train_path /FactMM-RAG/data/mimic/train.json \ 10 | --valid_path /FactMM-RAG/data/mimic/valid.json \ 11 | --train_pos_path /FactMM-RAG/data/mimic/scoring_chunks_train/$expname/reduction.pkl \ 12 | --valid_pos_path /FactMM-RAG/data/mimic/scoring_chunks_valid/$expname/positive_list.pkl \ 13 | --train_neg_path /FactMM-RAG/src/retriever/ANCE/train_hard_negatives.pkl \ 14 | --valid_neg_path /FactMM-RAG/src/retriever/ANCE/valid_hard_negatives.pkl \ 15 | --pretrained_model_path /FactMM-RAG/src/retriever/DPR/output/dpr.best.pt \ 16 | --wandb_name "Finetuning with positives from exhaustive with reranked top_$top_k chexbert $chex_thresh radgraph $radg_thresh ance" 17 | # --freeze_vision_model 18 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/merge_topk_pos.sh: -------------------------------------------------------------------------------- 1 
| #!/bin/bash 2 | #SBATCH --job-name="merge_pos_train.sh" 3 | #SBATCH -o %x.out 4 | #SBATCH -e %x.err 5 | #SBATCH --partition=general 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=4 9 | #SBATCH --mem=24G 10 | #SBATCH --time=02:00:00 11 | 12 | 13 | top_k=3 14 | tr_chunks=64 15 | chex_thresh=1.0 16 | radg_thresh=0.4 17 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 18 | 19 | echo "=== Reduce ${expname} (tr) ===" 20 | echo "$SLURM_NODELIST" && 21 | 22 | start=$(date +%s) 23 | 24 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 25 | echo "=== topk: $topk, radg_thresh: $radg_thresh ===" 26 | 27 | python merge_topk_pos.py \ 28 | --from_folder "/FactMM-RAG/data/mimic/scoring_chunks_train/$expname" \ 29 | --file_prefix "chunk_{i}.pkl" \ 30 | --num_chunks $tr_chunks 31 | 32 | end=$(date +%s) && 33 | runtime=$((end-start)) && 34 | echo "Time Taken: $runtime s" -------------------------------------------------------------------------------- /src/retriever/DPR/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=healthcare 3 | #SBATCH --output=healthcare_1.0_0.4_clueweb.out 4 | #SBATCH --error=healthcare_1.0_0.4_clueweb.err 5 | #SBATCH --partition=general 6 | #SBATCH --cpus-per-task=8 7 | #SBATCH --gres=gpu:A6000:1 8 | #SBATCH --mem=60G 9 | #SBATCH --time=1-00:00:00 10 | 11 | 12 | chex_thresh=1.0 13 | top_k=3 14 | radg_thresh=0.4 15 | 16 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 17 | 18 | python train.py --out_path /FactMM-RAG/src/retriever/output/dpr.best.pt \ 19 | --train_path /FactMM-RAG/data/mimic/train.json \ 20 | --valid_path /FactMM-RAG/data/mimic/valid.json \ 21 | --train_pos_path /FactMM-RAG/data/mimic/scoring_chunks_train/$expname/reduction.pkl \ 22 | --valid_pos_path /FactMM-RAG/data/mimic/scoring_chunks_valid/$expname/positive_list.pkl \ 23 | --wandb_name "Finetuning with positives from exhaustive with reranked top_$top_k chexbert $chex_thresh radgraph $radg_thresh" \ 24 | --pretrained_model_path /FactMM-RAG/src/retriever/checkpoint/model.best.pt 25 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_valid/gen_topk_pos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name="gen_merge_pos_valid.sh" 3 | #SBATCH -o %x.out 4 | #SBATCH -e %x.err 5 | #SBATCH --partition=general 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=4 9 | #SBATCH --mem=16G 10 | #SBATCH --time=02:00:00 11 | 12 | va_chunks=4 13 | chex_thresh=1.0 14 | top_k=3 15 | radg_thresh=0.4 16 | 17 | echo "$SLURM_NODELIST" && 18 | 19 | start=$(date +%s) 20 | 21 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 22 | echo "=== topk: $topk, radg_thresh: $radg_thresh ===" 23 | python gen_topk_pos.py \ 24 | --from_folder "FactMM-RAG/data/mimic/scoring_chunks_valid" \ 25 | --do_chex \ 26 | --do_radg \ 27 | --num_chunks $va_chunks \ 28 | --n 991 \ 29 | --output_folder "FactMM-RAG/data/mimic/scoring_chunks_valid/$expname" \ 30 | --pre_mask \ 31 | --pre_mask_chex $chex_thresh \ 32 | --pre_mask_radg $radg_thresh \ 33 | --top_k $top_k 34 | 35 | end=$(date +%s) && 36 | runtime=$((end-start)) && 37 | echo "Time Taken: $runtime s" 38 | -------------------------------------------------------------------------------- /src/generator/knn_index_to_evaluation_file.py: -------------------------------------------------------------------------------- 1 | import argparse 
2 | import json 3 | import pickle 4 | import os 5 | import numpy 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--train_data_file", type=str, default="./data/mimic/train.json") 9 | parser.add_argument("--knn_index", type=str, required=True) 10 | parser.add_argument("--output_file", type=str, required=True) 11 | parser.add_argument("--is_test", action="store_true") 12 | args = vars(parser.parse_args()) 13 | 14 | train_data_file = args["train_data_file"] 15 | knn_index, output_file = args["knn_index"], args["output_file"] 16 | is_test = args["is_test"] 17 | 18 | with open(train_data_file, "r") as tr, open(knn_index, "rb") as knn: 19 | train_data = json.load(tr) 20 | knn_data = pickle.load(knn) 21 | 22 | output = [ 23 | { 24 | "retrieved_finding": [train_data[knn_data_i["knn_index"][0]]["finding"]] 25 | } for knn_data_i in knn_data 26 | ] 27 | 28 | 29 | with open(output_file, "w") as f: 30 | json.dump(output, f, indent=2) 31 | print(f"Writing {len(output)} to {output_file}") -------------------------------------------------------------------------------- /src/retriever/DPR/gen_hard_negatives.sh: -------------------------------------------------------------------------------- 1 | python3 gen_hard_negatives.py --query_embed_path /FactMM-RAG/DPR/embedding/train_embedding_image.pkl\ 2 | --txt_embed_path /FactMM-RAG/DPR/embedding/train_embedding_finding.pkl \ 3 | --result_path /FactMM-RAG/src/retriever/ANCE/train_hard_negatives.pkl \ 4 | --query_path /home/liwens/healthcare/Lightning-Pretrain/chest/data/mimic/train_labeled.json \ 5 | --corpus_path /home/liwens/healthcare/Lightning-Pretrain/chest/data/mimic/train_labeled.json \ 6 | --chexbert_threshold 1 \ 7 | --radgraph_threshold 0.4 \ 8 | --topN 100 9 | 10 | python3 gen_hard_negatives.py --query_embed_path /FactMM-RAG/DPR/embedding/valid_embedding_image.pkl \ 11 | --txt_embed_path /FactMM-RAG/DPR/embedding/train_embedding_finding.pkl \ 12 | --result_path /FactMM-RAG/src/retriever/ANCE/valid_hard_negatives.pkl \ 13 | --query_path /home/liwens/healthcare/Lightning-Pretrain/chest/data/mimic/valid_labeled.json \ 14 | --corpus_path /home/liwens/healthcare/Lightning-Pretrain/chest/data/mimic/train_labeled.json \ 15 | --chexbert_threshold 1 \ 16 | --radgraph_threshold 0.4 \ 17 | --topN 100 -------------------------------------------------------------------------------- /src/generator/knn_ideal.py: -------------------------------------------------------------------------------- 1 | import argparse, json, os, pickle, numpy as np 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument( 5 | "--folder", type=str, default="./data/mimic/scoring_chunks_train/top_30_c0.0_r0.0" 6 | ) 7 | parser.add_argument( 8 | "--output_file", 9 | type=str, 10 | default="./data/rag/2025_03_09_end_to_end/knn_tr2tr_ideal.pkl", 11 | ) 12 | parser.add_argument("--num_chunks", type=int, default=64) 13 | args = vars(parser.parse_args()) 14 | 15 | folder, output_file = args["folder"], args["output_file"] 16 | num_chunks = args["num_chunks"] 17 | 18 | out = [] 19 | 20 | i = 0 21 | for chunk in range(num_chunks): 22 | file = os.path.join(folder, f"chunk_{chunk}.pkl") 23 | with open(file, "rb") as f: 24 | obj = pickle.load(f) 25 | for row in obj["positive_list"]: 26 | if len(row) == 0: 27 | print(f"Missing on {i=}") 28 | out.append({"key": i, "knn_index": row}) 29 | i += 1 30 | print(len(out)) 31 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 32 | with open(output_file, "wb") as f: 33 | pickle.dump(out, f) 34 | 
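# Illustrative note, not part of the original script: the pickle written above holds
# one {"key": ..., "knn_index": [...]} record per training query, which matches the
# shape expected by knn_index_to_evaluation_file.py, e.g. (output path invented):
#   python src/generator/knn_index_to_evaluation_file.py \
#       --train_data_file ./data/mimic/train.json \
#       --knn_index ./data/rag/2025_03_09_end_to_end/knn_tr2tr_ideal.pkl \
#       --output_file ./data/rag/knn_ideal_inference.json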
-------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/gen_topk_oracle_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name="gen_topk_oracle_train.sh" 3 | #SBATCH -o ./tr_reports/%x-%a.out 4 | #SBATCH -e ./tr_reports/%x-%a.err 5 | #SBATCH --partition=general 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --mem=12G 10 | #SBATCH --time=04:00:00 11 | #SBATCH --array=0-63%8 12 | 13 | # Note thresholds are 0.0 for exhaustive training-time search 14 | tr_chunks=64 15 | chex_thresh=0.0 16 | top_k=30 17 | radg_thresh=0.0 18 | 19 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 20 | echo "=== topk: $top_k, radg_thresh: $radg_thresh ===" 21 | python ./data/factual_mining/build_pos_train/gen_topk_pos.py \ 22 | --from_folder "./data/mimic/scoring_chunks_train" \ 23 | --do_chex \ 24 | --do_radg \ 25 | --num_chunks $tr_chunks \ 26 | --n 125417 \ 27 | --chunk_id $SLURM_ARRAY_TASK_ID \ 28 | --output_file "./data/mimic/scoring_chunks_train/$expname/chunk_$SLURM_ARRAY_TASK_ID.pkl" \ 29 | --pre_mask \ 30 | --pre_mask_chex $chex_thresh \ 31 | --pre_mask_radg $radg_thresh \ 32 | --top_k $top_k -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 cxcscmu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/gen_topk_pos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name="gen_pos_train.sh" 3 | #SBATCH -o ./tr_reports/%x-%a.out 4 | #SBATCH -e ./tr_reports/%x-%a.err 5 | #SBATCH --partition=general 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --mem=12G 10 | #SBATCH --time=04:00:00 11 | #SBATCH --array=0-63%8 12 | 13 | tr_chunks=64 14 | chex_thresh=1.0 15 | top_k=3 16 | radg_thresh=0.4 17 | 18 | start=$(date +%s) 19 | expname=top_${top_k}_c${chex_thresh}_r${radg_thresh} 20 | echo "=== topk: $top_k, radg_thresh: $radg_thresh ===" 21 | python gen_topk_pos.py \ 22 | --from_folder "./data/mimic/scoring_chunks_train" \ 23 | --do_chex \ 24 | --do_radg \ 25 | --skip_bad_sample \ 26 | --num_chunks $tr_chunks \ 27 | --n 125417 \ 28 | --chunk_id $SLURM_ARRAY_TASK_ID \ 29 | --output_file "./data/mimic/scoring_chunks_train/$expname/chunk_$SLURM_ARRAY_TASK_ID.pkl" \ 30 | --pre_mask \ 31 | --pre_mask_chex $chex_thresh \ 32 | --pre_mask_radg $radg_thresh \ 33 | --top_k $top_k 34 | 35 | end=$(date +%s) && 36 | runtime=$((end-start)) && 37 | echo "Time Taken: $runtime s" -------------------------------------------------------------------------------- /src/generator/convert_json_or_jsonl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--file", type=str) 7 | parser.add_argument("--overwrite", action="store_true") 8 | args = vars(parser.parse_args()) 9 | 10 | file = args["file"] 11 | overwrite = args["overwrite"] 12 | base, ext = os.path.splitext(file) 13 | ext = ext.lstrip(".") 14 | assert ext in ("json", "jsonl"), f"{base=} {ext=} Extension is not 'json' or 'jsonl'" 15 | 16 | alt_ext = "jsonl" if ext == "json" else "json" 17 | alt_path = f"{base}.{alt_ext}" 18 | if not overwrite and os.path.exists(alt_path): 19 | raise Exception(f"Path {alt_path} already exists! 
Overwrite with --overwrite flag") 20 | 21 | with open(file, "r") as f: 22 | print(f"Converting .{ext} to .{alt_ext}") 23 | if ext == "json": 24 | obj = json.load(f) 25 | with open(alt_path, "w") as g: 26 | g.write("\n".join([json.dumps(v) for v in obj])) 27 | else: 28 | obj = [json.loads(l) for line in f.readlines() if (l := line.strip())] 29 | with open(alt_path, "w") as g: 30 | json.dump(obj, g, indent=2) -------------------------------------------------------------------------------- /src/generator/inference_llava.sh: -------------------------------------------------------------------------------- 1 | cd LLaVA 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate llava 4 | 5 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 6 | IFS=',' read -ra GPULIST <<< "$gpu_list" 7 | TEST_PATH=../data/rag 8 | CKPT=llava_output 9 | SPLIT="test" 10 | CHUNKS=${#GPULIST[@]} 11 | 12 | Q_FILE=../data/rag/llava_data_te.jsonl 13 | OUTPUT_FILE=$TEST_PATH/$CKPT/$SPLIT/merge_test_eval.jsonl 14 | 15 | for IDX in $(seq 0 $((CHUNKS-1))); do 16 | mkdir -p $TEST_PATH/$CKPT/$SPLIT 17 | touch $TEST_PATH/$CKPT/$SPLIT/${CHUNKS}_${IDX}.jsonl 18 | done 19 | 20 | for IDX in $(seq 0 $((CHUNKS-1))); do 21 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 22 | --model-path $TEST_PATH/$CKPT \ 23 | --question-file $Q_FILE \ 24 | --image-folder $IMAGE_FOLDER \ 25 | --answers-file $TEST_PATH/$CKPT/$SPLIT/$CHUNKS_$IDX.jsonl \ 26 | --num-chunks $CHUNKS \ 27 | --chunk-idx $IDX \ 28 | --temperature 0 \ 29 | --conv-mode vicuna_v1 & 30 | done 31 | wait 32 | 33 | # Clear out the output file if it exists. 34 | > "$OUTPUT_FILE" 35 | 36 | # Loop through the indices and concatenate each file. 37 | for IDX in $(seq 0 $((CHUNKS-1))); do 38 | cat $TEST_PATH/$CKPT/$SPLIT/$CHUNKS_$IDX.jsonl >> "$OUTPUT_FILE" 39 | done -------------------------------------------------------------------------------- /src/generator/vqa/inference_llava_vqa.sh: -------------------------------------------------------------------------------- 1 | cd LLaVA 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate llava 4 | 5 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 6 | IFS=',' read -ra GPULIST <<< "$gpu_list" 7 | TEST_PATH=../data/rag/vqa 8 | CKPT=llava_output 9 | SPLIT="test" 10 | CHUNKS=${#GPULIST[@]} 11 | 12 | Q_FILE=../data/rag/vqa/llava_vqa_test.jsonl 13 | OUTPUT_FILE=$TEST_PATH/$CKPT/$SPLIT/merge_test_eval.jsonl 14 | 15 | for IDX in $(seq 0 $((CHUNKS-1))); do 16 | mkdir -p $TEST_PATH/$CKPT/$SPLIT 17 | touch $TEST_PATH/$CKPT/$SPLIT/${CHUNKS}_${IDX}.jsonl 18 | done 19 | 20 | for IDX in $(seq 0 $((CHUNKS-1))); do 21 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 22 | --model-path $TEST_PATH/$CKPT \ 23 | --question-file $Q_FILE \ 24 | --image-folder $IMAGE_FOLDER \ 25 | --answers-file $TEST_PATH/$CKPT/$SPLIT/$CHUNKS_$IDX.jsonl \ 26 | --num-chunks $CHUNKS \ 27 | --chunk-idx $IDX \ 28 | --temperature 0 \ 29 | --conv-mode vicuna_v1 & 30 | done 31 | wait 32 | 33 | # Clear out the output file if it exists. 34 | > "$OUTPUT_FILE" 35 | 36 | # Loop through the indices and concatenate each file. 
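# (Shards are appended in ascending chunk index; assuming model_vqa_loader assigns
# contiguous slices of $Q_FILE to each chunk, the merged file stays line-aligned with
# the question order, which the downstream conversion and evaluation scripts rely on.)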
37 | for IDX in $(seq 0 $((CHUNKS-1))); do 38 | cat $TEST_PATH/$CKPT/$SPLIT/$CHUNKS_$IDX.jsonl >> "$OUTPUT_FILE" 39 | done -------------------------------------------------------------------------------- /src/retriever/DPR/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer, T5ForConditionalGeneration, CLIPProcessor, T5Model,CLIPVisionModel 2 | from multi_model import MultiModal 3 | import numpy as np 4 | DEFAULT_IMAGE_PATCH_TOKEN = "" 5 | DEFAULT_IM_START_TOKEN = "" 6 | DEFAULT_IM_END_TOKEN = "" 7 | 8 | 9 | def load_model(args,device): 10 | clip_model_name=args.clip_model_name 11 | t5_model_name=args.t5_model_name 12 | t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name) 13 | t5_model = T5Model.from_pretrained(t5_model_name) 14 | t5_tokenizer.add_special_tokens( 15 | {"additional_special_tokens": [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]}) 16 | t5_model.resize_token_embeddings(len(t5_tokenizer)) 17 | image_processor = CLIPProcessor.from_pretrained(clip_model_name) 18 | model = MultiModal(clip_model_name, t5_model, t5_tokenizer) 19 | model = model.to(device) 20 | return t5_tokenizer, model, image_processor 21 | 22 | def get_img_patch_token_size(clip_model_name): 23 | clip_model = CLIPVisionModel.from_pretrained(clip_model_name) 24 | image_size=clip_model.config.image_size 25 | patch_size=clip_model.config.patch_size 26 | img_patch_token_size=int(image_size/patch_size)**2 27 | return img_patch_token_size 28 | -------------------------------------------------------------------------------- /src/generator/train_llava.sh: -------------------------------------------------------------------------------- 1 | cd LLaVA && 2 | source ~/miniconda3/etc/profile.d/conda.sh && 3 | conda activate llava && 4 | deepspeed --master_port 11234 llava/train/train_mem.py \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path lmsys/vicuna-7b-v1.5 \ 7 | --version v1 \ 8 | --data_path ../data/rag/llava_data_tr.json \ 9 | --image_folder $IMAGE_FOLDER \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --pretrain_mm_mlp_adapter $PROJECTOR_PATH \ 12 | --mm_projector_type mlp2x_gelu \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --image_aspect_ratio pad \ 17 | --group_by_modality_length False \ 18 | --bf16 True \ 19 | --output_dir ../data/rag/llava_output \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 4 \ 22 | --per_device_eval_batch_size 8 \ 23 | --gradient_accumulation_steps 4 \ 24 | --save_strategy "steps" \ 25 | --save_steps 250 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. 
\ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True -------------------------------------------------------------------------------- /src/retriever/DPR/gen_embedings.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default values for arguments 4 | T5_MODEL_NAME="OpenMatch/t5-ance" 5 | CLIP_MODEL_NAME="openai/clip-vit-base-patch32" 6 | SAVED_CKPT="/FactMM-RAG/src/retriever/output/dpr.best.pt" 7 | 8 | TRAIN_PATH="/FactMM-RAG/data/mimic/train.json" 9 | VALID_PATH="/FactMM-RAG/data/mimic/valid.json" 10 | TEST_PATH="/FactMM-RAG/data/mimic/test.json" 11 | 12 | OUTPUT_TRAIN_IMAGE_PATH="/FactMM-RAG/DPR/embedding/train_embedding_image.pkl" 13 | OUTPUT_TRAIN_FINDING_PATH="/FactMM-RAG/DPR/embedding/train_embedding_finding.pkl" 14 | OUTPUT_VALID_IMAGE_PATH="/FactMM-RAG/DPR/embedding/valid_embedding_image.pkl" 15 | OUTPUT_TEST_IMAGE_PATH="/FactMM-RAG/DPR/embedding/test_embedding_image.pkl" 16 | 17 | # Run Python script with arguments 18 | python gen_embeddings.py --t5_model_name $T5_MODEL_NAME \ 19 | --clip_model_name $CLIP_MODEL_NAME \ 20 | --saved_ckpt $SAVED_CKPT \ 21 | --train_path $TRAIN_PATH \ 22 | --valid_path $VALID_PATH \ 23 | --test_path $TEST_PATH \ 24 | --output_train_image_path $OUTPUT_TRAIN_IMAGE_PATH \ 25 | --output_train_finding_path $OUTPUT_TRAIN_FINDING_PATH \ 26 | --output_valid_image_path $OUTPUT_VALID_IMAGE_PATH \ 27 | --output_test_image_path $OUTPUT_TEST_IMAGE_PATH 28 | -------------------------------------------------------------------------------- /src/generator/vqa/train_llava_vqa.sh: -------------------------------------------------------------------------------- 1 | cd LLaVA && 2 | source ~/miniconda3/etc/profile.d/conda.sh && 3 | conda activate llava && 4 | deepspeed --master_port 11234 llava/train/train_mem.py \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path lmsys/vicuna-7b-v1.5 \ 7 | --version v1 \ 8 | --data_path ../data/rag/vqa/llava_vqa_train.json \ 9 | --image_folder $IMAGE_FOLDER \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --pretrain_mm_mlp_adapter $PROJECTOR_PATH \ 12 | --mm_projector_type mlp2x_gelu \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --image_aspect_ratio pad \ 17 | --group_by_modality_length False \ 18 | --bf16 True \ 19 | --output_dir ../data/rag/vqa/llava_output \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 4 \ 22 | --per_device_eval_batch_size 8 \ 23 | --gradient_accumulation_steps 4 \ 24 | --save_strategy "steps" \ 25 | --save_steps 99999 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. 
\ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True -------------------------------------------------------------------------------- /src/retriever/ANCE/gen_embedings.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default values for arguments 4 | T5_MODEL_NAME="OpenMatch/t5-ance" 5 | CLIP_MODEL_NAME="openai/clip-vit-base-patch32" 6 | SAVED_CKPT="/FactMM-RAG/src/retriever/output/ance.best.pt" 7 | 8 | TRAIN_PATH="/FactMM-RAG/data/mimic/train.json" 9 | VALID_PATH="/FactMM-RAG/data/mimic/valid.json" 10 | TEST_PATH="/FactMM-RAG/data/mimic/test.json" 11 | 12 | OUTPUT_TRAIN_IMAGE_PATH="/FactMM-RAG/ANCE/embedding/train_embedding_image.pkl" 13 | OUTPUT_TRAIN_FINDING_PATH="/FactMM-RAG/ANCE/embedding/train_embedding_finding.pkl" 14 | OUTPUT_VALID_IMAGE_PATH="/FactMM-RAG/ANCE/embedding/valid_embedding_image.pkl" 15 | OUTPUT_TEST_IMAGE_PATH="/FactMM-RAG/ANCE/embedding/test_embedding_image.pkl" 16 | 17 | # Run Python script with arguments 18 | python gen_embeddings.py --t5_model_name $T5_MODEL_NAME \ 19 | --clip_model_name $CLIP_MODEL_NAME \ 20 | --saved_ckpt $SAVED_CKPT \ 21 | --train_path $TRAIN_PATH \ 22 | --valid_path $VALID_PATH \ 23 | --test_path $TEST_PATH \ 24 | --image_folder $IMAGE_FOLDER \ 25 | --output_train_image_path $OUTPUT_TRAIN_IMAGE_PATH \ 26 | --output_train_finding_path $OUTPUT_TRAIN_FINDING_PATH \ 27 | --output_valid_image_path $OUTPUT_VALID_IMAGE_PATH \ 28 | --output_test_image_path $OUTPUT_TEST_IMAGE_PATH 29 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/merge_topk_pos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.ma as ma 3 | import argparse 4 | import json 5 | import math 6 | import tqdm 7 | import os 8 | import time 9 | import pickle 10 | from collections import defaultdict 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--from_folder", type=str, required=True) 14 | parser.add_argument("--file_prefix", type=str, required=True) 15 | parser.add_argument("--num_chunks", type=int, required=True) 16 | parser.add_argument("--output_file", type=str, default=None) 17 | parser.add_argument("--add_back_self", action="store_true") 18 | args = parser.parse_args() 19 | print(vars(args), flush=True) 20 | 21 | statistics = defaultdict(float) 22 | positive_list = [] 23 | bad_list = [] 24 | 25 | for chunk_idx in tqdm.tqdm(range(args.num_chunks)): 26 | to_load = os.path.join(args.from_folder, args.file_prefix.format(i=chunk_idx)) 27 | with open(to_load, "rb") as f: 28 | obj = pickle.load(f) 29 | for k, v in obj["statistics"].items(): 30 | statistics[k] += v 31 | positive_list.extend(obj["positive_list"]) 32 | bad_list.extend(obj["bad_list"]) 33 | print(json.dumps(statistics, indent=4)) 34 | print(f"{len(positive_list)=}, {positive_list[:5]=}") 35 | 36 | output_file = args.output_file 37 | if output_file is None: 38 | output_file = os.path.join(args.from_folder, "reduction.pkl") 39 | print(f"Write to {output_file}", flush=True) 40 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 41 | with open(output_file, "wb") as f: 42 | pickle.dump(positive_list, f) -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | accelerate==0.18.0 3 | boto3==1.28.62 4 | botocore==1.31.62 5 | charset-normalizer==3.1.0 6 | click==8.1.7 7 | clip==0.1.0 8 | cramjam==2.7.0 9 | cycler==0.11.0 10 | dpcpp-cpp-rt==2023.2.0 11 | faiss-cpu==1.10.0 12 | fastparquet==0.8.1 13 | filelock==3.12.0 14 | fonttools==4.38.0 15 | fsspec==2023.1.0 16 | ftfy==6.1.1 17 | huggingface-hub==0.14.1 18 | idna==3.4 19 | importlib-metadata==6.6.0 20 | intel-cmplr-lib-rt==2023.2.0 21 | intel-cmplr-lic-rt==2023.2.0 22 | intel-opencl-rt==2023.2.0 23 | intel-openmp==2023.2.0 24 | jmespath==1.0.1 25 | joblib==1.3.2 26 | kiwisolver==1.4.5 27 | matplotlib==3.5.3 28 | mkl==2023.2.0 29 | mkl-fft==1.3.11 30 | mkl-service==2.4.0 31 | nvidia-cublas-cu11==11.10.3.66 32 | nvidia-cuda-nvrtc-cu11==11.7.99 33 | nvidia-cuda-runtime-cu11==11.7.99 34 | nvidia-cudnn-cu11==8.5.0.96 35 | packaging==23.1 36 | Pillow==9.0.1 37 | protobuf==4.24.2 38 | psutil==5.9.5 39 | pyarrow==12.0.1 40 | pybase64==1.2.3 41 | pyparsing==3.1.1 42 | python-dateutil==2.8.2 43 | pytorch-transformers==1.0.0 44 | pytrec-eval==0.5 45 | pytz==2023.3.post1 46 | PyYAML==6.0 47 | regex==2022.10.31 48 | requests==2.29.0 49 | s3transfer==0.7.0 50 | sacremoses==0.0.53 51 | scikit-learn==1.0.2 52 | sentencepiece==0.1.99 53 | six==1.16.0 54 | tbb==2021.10.0 55 | tensorboardX==2.6.2.2 56 | threadpoolctl==3.1.0 57 | tokenizers==0.12.1 58 | torch==1.13.1 59 | torchaudio==0.13.1 60 | torchvision==0.14.1 61 | tqdm==4.65.0 62 | transformers==4.23.1 63 | typing_extensions==4.7.1 64 | urllib3==1.26.15 65 | wcwidth==0.2.6 66 | zipp==3.15.0 67 | radgraph==0.0.9 68 | wandb==0.16.4 69 | f1chexbert==0.0.2 -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/gen_similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import json 4 | import math 5 | import tqdm 6 | import os 7 | import time 8 | 9 | from utils import chexbert_similarity, radgraph_similarity 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--train_data_file", type=str, 13 | default="FactMM-RAG/data/mimic/train_labeled.json") 14 | parser.add_argument("--output_folder", type=str, 15 | default="FactMM-RAG/data/mimic/scoring_chunks_train") 16 | parser.add_argument("--num_chunks", type=int, default=64) 17 | parser.add_argument("--chunk_id", type=int, required=True) 18 | args = parser.parse_args() 19 | 20 | with open(args.train_data_file, "r") as f: 21 | obj = json.load(f) 22 | os.makedirs(args.output_folder, exist_ok=True) 23 | n = len(obj) 24 | chunk_size = (n + args.num_chunks - 1) // args.num_chunks 25 | start = chunk_size * args.chunk_id 26 | end = min(chunk_size * (args.chunk_id + 1), n) 27 | print( 28 | f"Chunking({n=}) [{args.chunk_id}] ({args.chunk_id + 1}/{args.num_chunks}): [{start}, {end})", flush=True) 29 | 30 | out_chexbert_sims = np.zeros((end - start, n)) 31 | out_radgraph_sims = np.zeros((end - start, n)) 32 | a = time.time() 33 | for i in tqdm.tqdm(range(start, end), desc="outer"): 34 | query_i = obj[i] 35 | for j in range(n): 36 | doc_j = obj[j] 37 | chex_sim = chexbert_similarity(query_i, doc_j) 38 | radg_sim = radgraph_similarity(query_i, doc_j) 39 | out_chexbert_sims[i - start, j] = chex_sim 40 | out_radgraph_sims[i - start, j] = radg_sim 41 | b = time.time() 42 | print(f"Time Taken: {(b-a):0.4f} s", flush=True) 43 | print(f"Output: {out_chexbert_sims.shape=}, 
{out_radgraph_sims.shape=}") 44 | output_file = os.path.join(args.output_folder) 45 | print(f"Saving to: {args.output_folder}", flush=True) 46 | np.save(os.path.join(args.output_folder, 47 | f"chex_{args.chunk_id}"), out_chexbert_sims) 48 | np.save(os.path.join(args.output_folder, 49 | f"radg_{args.chunk_id}"), out_radgraph_sims) -------------------------------------------------------------------------------- /data/parse.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | def process_split(image_paths_file, findings_file, impressions_file, output_json_file): 5 | # Read image paths and findings from the input files 6 | with open(image_paths_file, 'r') as f: 7 | image_path_lines = f.read().splitlines() 8 | 9 | with open(findings_file, 'r') as f: 10 | findings = f.read().splitlines() 11 | 12 | with open(impressions_file, 'r') as f: 13 | imps = f.read().splitlines() 14 | 15 | # Ensure that the number of image path lines and findings match 16 | if len(image_path_lines) != len(findings) or len(image_path_lines) != len(imps): 17 | print("Error: The number of image path lines, findings, and impressions doesn't match.") 18 | exit(1) 19 | 20 | all_entries = [] 21 | 22 | for i in range(len(image_path_lines)): 23 | image_paths = image_path_lines[i].split(',') # Split multiple image paths 24 | image_path = image_paths[0] 25 | entries = {"image": image_path.strip(), "finding": findings[i], "impression": imps[i]} 26 | all_entries.append(entries) 27 | 28 | # Write all entries to the single JSON file 29 | with open(output_json_file, 'w') as json_file: 30 | json.dump(all_entries, json_file, indent=4) 31 | 32 | print(f"Created {output_json_file} with all entries.") 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser(description="Process radiology reports into JSON format.") 37 | parser.add_argument("--image_paths_file", type=str, required=True, help="Path to the image paths file.") 38 | parser.add_argument("--findings_file", type=str, required=True, help="Path to the findings file.") 39 | parser.add_argument("--impressions_file", type=str, required=True, help="Path to the impressions file.") 40 | parser.add_argument("--output_json_file", type=str, required=True, help="Output JSON file path.") 41 | 42 | args = parser.parse_args() 43 | 44 | process_split(args.image_paths_file, args.findings_file, args.impressions_file, args.output_json_file) 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /data/label.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | from radgraph import RadGraph 6 | from f1chexbert import F1CheXbert 7 | 8 | def process_labels(input_path, output_path, device="cuda"): 9 | # Load data 10 | with open(input_path, 'r') as f: 11 | data = json.load(f) 12 | 13 | radgraph = RadGraph() 14 | chexbert = F1CheXbert(device=device) 15 | 16 | # Defining classes 17 | target_names = [ 18 | "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion", "Edema", 19 | "Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion", "Pleural Other", 20 | "Fracture", "Support Devices", "No Finding" 21 | ] 22 | target_names_5 = ["Cardiomegaly", "Edema", "Consolidation", "Atelectasis", "Pleural Effusion"] 23 | target_names_5_index = np.where(np.isin(target_names, target_names_5))[0] 24 | 25 | 
labeled_data = [] 26 | for instance in tqdm(data, desc="Processing reports"): 27 | report = instance["finding"] 28 | annotations = radgraph([report]) 29 | label = chexbert.get_label(report) 30 | report = { 31 | "text": report 32 | } 33 | report["entities"] = annotations["0"]["entities"] 34 | report["label"] = (np.array(label)[target_names_5_index]).tolist() 35 | labeled_data.append(report) 36 | 37 | # Save labeled data 38 | with open(output_path, 'w') as json_file: 39 | json.dump(labeled_data, json_file, indent=4) 40 | 41 | print(f"Labeled data saved to {output_path}") 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser(description="Annotate radiology reports with RadGraph and CheXbert labels.") 46 | parser.add_argument("--input_path", type=str, required=True, help="Path to input JSON file.") 47 | parser.add_argument("--output_path", type=str, required=True, help="Path to save labeled JSON file.") 48 | parser.add_argument("--device", type=str, default="cuda", help="Device to use for CheXbert inference (default: cuda).") 49 | 50 | args = parser.parse_args() 51 | 52 | process_labels(args.input_path, args.output_path, args.device) 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_valid/gen_similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import json 4 | import math 5 | import tqdm 6 | import os 7 | import time 8 | 9 | from utils import chexbert_similarity, radgraph_similarity 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--query_data_file", type=str, 13 | default="FactMM-RAG/data/mimic/valid_labeled.json") 14 | parser.add_argument("--corpus_data_file", type=str, 15 | default="FactMM-RAG/data/mimic/train_labeled.json") 16 | parser.add_argument("--output_folder", type=str, 17 | default="FactMM-RAG/data/mimic/scoring_chunks_valid") 18 | parser.add_argument("--num_chunks", type=int, default=16) 19 | parser.add_argument("--chunk_id", type=int, required=True) 20 | args = parser.parse_args() 21 | 22 | with open(args.query_data_file, "r") as q, open(args.corpus_data_file, "r") as c: 23 | obj_query = json.load(q) 24 | obj_corpus = json.load(c) 25 | os.makedirs(args.output_folder, exist_ok=True) 26 | n = len(obj_query) 27 | n_corpus = len(obj_corpus) 28 | chunk_size = (n + args.num_chunks - 1) // args.num_chunks 29 | start = chunk_size * args.chunk_id 30 | end = min(chunk_size * (args.chunk_id + 1), n) 31 | print( 32 | f"Chunking({n=}) [{args.chunk_id}] ({args.chunk_id + 1}/{args.num_chunks}): [{start}, {end})", flush=True) 33 | 34 | out_chexbert_sims = np.zeros((end - start, n_corpus)) 35 | out_radgraph_sims = np.zeros((end - start, n_corpus)) 36 | a = time.time() 37 | for i in tqdm.tqdm(range(start, end), desc="outer"): 38 | query_i = obj_query[i]["finding"] 39 | for j in range(n_corpus): 40 | doc_j = obj_corpus[j]["finding"] 41 | chex_sim = chexbert_similarity(query_i, doc_j) 42 | radg_sim = radgraph_similarity(query_i, doc_j) 43 | out_chexbert_sims[i - start, j] = chex_sim 44 | out_radgraph_sims[i - start, j] = radg_sim 45 | b = time.time() 46 | print(f"Time Taken: {(b-a):0.4f} s", flush=True) 47 | print(f"Output: {out_chexbert_sims.shape=}, {out_radgraph_sims.shape=}") 48 | output_file = os.path.join(args.output_folder) 49 | print(f"Saving to: {args.output_folder}", flush=True) 50 | np.save(os.path.join(args.output_folder, 51 | f"chex_{args.chunk_id}"), 
out_chexbert_sims) 52 | np.save(os.path.join(args.output_folder, 53 | f"radg_{args.chunk_id}"), out_radgraph_sims) -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def exact_entity_token_if_rel_exists_reward( 3 | hypothesis_annotation_list, reference_annotation_list 4 | ): 5 | candidates = [] 6 | for annotation_list in [hypothesis_annotation_list, reference_annotation_list]: 7 | candidate = [] 8 | for entity in annotation_list.values(): 9 | if not entity["relations"]: 10 | candidate.append((entity["tokens"], entity["label"])) 11 | if entity["relations"]: 12 | candidate.append((entity["tokens"], entity["label"], True)) 13 | candidate = set(candidate) 14 | candidates.append(candidate) 15 | hypothesis_relation_token_list, reference_relation_token_list = candidates 16 | precision = ( 17 | sum( 18 | [ 19 | 1 20 | for x in hypothesis_relation_token_list 21 | if (x in reference_relation_token_list) 22 | ] 23 | ) 24 | / len(hypothesis_relation_token_list) 25 | if len(hypothesis_relation_token_list) > 0 26 | else 0.0 27 | ) 28 | recall = ( 29 | sum( 30 | [ 31 | 1 32 | for x in reference_relation_token_list 33 | if (x in hypothesis_relation_token_list) 34 | ] 35 | ) 36 | / len(reference_relation_token_list) 37 | if len(reference_relation_token_list) > 0 38 | else 0.0 39 | ) 40 | f1_score = ( 41 | (2 * precision * recall / (precision + recall)) 42 | if (precision + recall) > 0 43 | else 0.0 44 | ) 45 | return f1_score 46 | 47 | 48 | def chexbert_similarity(report, ret_report): 49 | report_label = report["label"] 50 | ret_report_label = ret_report["label"] 51 | # distance = manhattan_distance(report_label, ret_report_label) 52 | # # Calculate the similarity as the inverse of the distance plus one 53 | # # to prevent division by zero when the distance is zero. 
54 | # return int(report_label == ret_report_label) 55 | return sum(1 for true, pred in zip(report_label, ret_report_label) if true == pred) / len(report_label) 56 | 57 | 58 | def radgraph_similarity(report, ret_report): 59 | report_entities = report["entities"] 60 | ret_report_entities = ret_report["entities"] 61 | partial_reward = exact_entity_token_if_rel_exists_reward( 62 | ret_report_entities, report_entities) 63 | return partial_reward -------------------------------------------------------------------------------- /data/factual_mining/build_pos_valid/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def exact_entity_token_if_rel_exists_reward( 3 | hypothesis_annotation_list, reference_annotation_list 4 | ): 5 | candidates = [] 6 | for annotation_list in [hypothesis_annotation_list, reference_annotation_list]: 7 | candidate = [] 8 | for entity in annotation_list.values(): 9 | if not entity["relations"]: 10 | candidate.append((entity["tokens"], entity["label"])) 11 | if entity["relations"]: 12 | candidate.append((entity["tokens"], entity["label"], True)) 13 | candidate = set(candidate) 14 | candidates.append(candidate) 15 | hypothesis_relation_token_list, reference_relation_token_list = candidates 16 | precision = ( 17 | sum( 18 | [ 19 | 1 20 | for x in hypothesis_relation_token_list 21 | if (x in reference_relation_token_list) 22 | ] 23 | ) 24 | / len(hypothesis_relation_token_list) 25 | if len(hypothesis_relation_token_list) > 0 26 | else 0.0 27 | ) 28 | recall = ( 29 | sum( 30 | [ 31 | 1 32 | for x in reference_relation_token_list 33 | if (x in hypothesis_relation_token_list) 34 | ] 35 | ) 36 | / len(reference_relation_token_list) 37 | if len(reference_relation_token_list) > 0 38 | else 0.0 39 | ) 40 | f1_score = ( 41 | (2 * precision * recall / (precision + recall)) 42 | if (precision + recall) > 0 43 | else 0.0 44 | ) 45 | return f1_score 46 | 47 | 48 | def chexbert_similarity(report, ret_report): 49 | report_label = report["label"] 50 | ret_report_label = ret_report["label"] 51 | # distance = manhattan_distance(report_label, ret_report_label) 52 | # # Calculate the similarity as the inverse of the distance plus one 53 | # # to prevent division by zero when the distance is zero. 54 | # return int(report_label == ret_report_label) 55 | return sum(1 for true, pred in zip(report_label, ret_report_label) if true == pred) / len(report_label) 56 | 57 | 58 | def radgraph_similarity(report, ret_report): 59 | report_entities = report["entities"] 60 | ret_report_entities = ret_report["entities"] 61 | partial_reward = exact_entity_token_if_rel_exists_reward( 62 | ret_report_entities, report_entities) 63 | return partial_reward -------------------------------------------------------------------------------- /src/generator/build_nonrag_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import json 5 | import argparse 6 | import sys 7 | import tqdm 8 | 9 | parser = argparse.ArgumentParser( 10 | description="Takes in an index file, and produces a llava-formatted dataset." 11 | ) 12 | # pickle file of embeddings (output from marvel ance). 
obj[1] should be the array of embeddings 13 | parser.add_argument("--queries_data_path", type=str, default="./data/mimic/test.json") 14 | parser.add_argument( 15 | "--is_conversational", 16 | action="store_true", 17 | help="""Set to true if used for llava training / evaluation, Set to False if used for inference""", 18 | ) 19 | parser.add_argument( 20 | "--output_data_mode", 21 | type=str, 22 | default="finding", 23 | help="the data type that is used for ground truth", 24 | ) 25 | parser.add_argument("--output_path", type=str, default="./data/rag/vqa/vqa_test.json") 26 | args = parser.parse_args() 27 | print(f"{vars(args)=}") 28 | 29 | queries_data_path = args.queries_data_path 30 | is_conversational = args.is_conversational 31 | output_path = args.output_path 32 | output_data_mode = args.output_data_mode 33 | 34 | with open(queries_data_path, "r") as f_que: 35 | query_data = json.load(f_que) 36 | print(f"{len(query_data)=}", flush=True) 37 | 38 | output = [] 39 | 40 | for i, knn_sample_i in tqdm.tqdm(enumerate(query_data), desc="queries..."): 41 | query_data_i = query_data[ 42 | i 43 | ] # correspondent sample. Dict with keys "image", "finding", "impression" 44 | query_data_image_path = query_data_i["image"] 45 | 46 | if is_conversational: 47 | obj = { 48 | "id": i, 49 | "image": query_data_image_path, 50 | "conversations": [ 51 | { 52 | "from": "human", 53 | "value": f"Generate a radiology report from this image:", 54 | }, 55 | {"from": "gpt", "value": f"{query_data_i[output_data_mode]}"}, 56 | ], 57 | } 58 | else: 59 | obj = { 60 | "question_id": i, 61 | "image": query_data_image_path, 62 | "text": f"\nGenerate a radiology report from this image:", 63 | } 64 | output.append(obj) 65 | 66 | print(f"Saving to: {output_path}") 67 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 68 | with open(os.path.join(output_path), "w") as f: 69 | json.dump(output, f, indent=2) 70 | -------------------------------------------------------------------------------- /src/retriever/DPR/retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pickle 4 | import numpy as np 5 | import faiss 6 | from tqdm import tqdm 7 | 8 | def main(args): 9 | # Load test image embeddings 10 | with open(args.test_image_embedding_path, 'rb') as fin_test: 11 | test_image_embeddings = pickle.load(fin_test) 12 | 13 | # Load actual training data (JSON) 14 | with open(args.train_json_path, 'r') as fin_train_actual: 15 | train_actual = json.load(fin_train_actual) 16 | 17 | query_idx = list(range(len(test_image_embeddings))) 18 | test_image_embeddings = np.array(test_image_embeddings, dtype=np.float32) 19 | 20 | # Load train embeddings 21 | with open(args.train_embedding_path, 'rb') as fin_train: 22 | train_embedding = pickle.load(fin_train) 23 | 24 | # FAISS Index 25 | size = test_image_embeddings.shape[1] 26 | cpu_index = faiss.IndexFlatIP(size) 27 | cpu_index.add(np.array(train_embedding, dtype=np.float32)) 28 | 29 | # Search for nearest neighbors 30 | D, I = cpu_index.search(test_image_embeddings, 3) 31 | 32 | # Retrieve findings & impressions 33 | ctx_findings_impressions = [] 34 | for step, qid in enumerate(tqdm(query_idx)): 35 | cur = {"associated_impression": [], "retrieved_finding": []} 36 | for idx in I[step]: 37 | cur["associated_impression"].append(train_actual[idx]['impression']) 38 | cur["retrieved_finding"].append(train_actual[idx]['finding']) 39 | ctx_findings_impressions.append(cur) 40 | 41 | # Write all entries to JSON 42 | 
with open(args.output_json_path, 'w') as json_file: 43 | json.dump(ctx_findings_impressions, json_file, indent=4) 44 | 45 | print(f"Saved retrieved findings & impressions to {args.output_json_path}") 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | 50 | parser.add_argument("--test_image_embedding_path", type=str, required=True, 51 | help="Path to test image embeddings (pkl file)") 52 | parser.add_argument("--train_json_path", type=str, required=True, 53 | help="Path to training dataset JSON file") 54 | parser.add_argument("--train_embedding_path", type=str, required=True, 55 | help="Path to training embeddings (pkl file)") 56 | parser.add_argument("--output_json_path", type=str, required=True, 57 | help="Path to output JSON file") 58 | 59 | args = parser.parse_args() 60 | main(args) 61 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from radgraph import F1RadGraph 4 | from f1chexbert import F1CheXbert 5 | from rouge import Rouge 6 | import evaluate 7 | from bert_score import BERTScorer 8 | 9 | def main(args): 10 | with open(args.ref_path, "r") as f: 11 | ref_dict = json.load(f) 12 | 13 | with open(args.pred_path, "r") as f: 14 | pred_dict = json.load(f) 15 | print(f"References size: {len(ref_dict)}, Preditions size: {len(pred_dict)}") 16 | hyps = [] 17 | refs = [] 18 | for ref, pred in zip(ref_dict, pred_dict): 19 | hyp = pred['retrieved_finding'][0] 20 | check = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0] 21 | if len(check) > 0: 22 | hyps.append(pred['retrieved_finding'][0]) 23 | refs.append(ref['finding']) 24 | 25 | assert len(hyps) == len(refs) 26 | print("The number of testing dataset:", len(hyps)) 27 | 28 | # F1RadGraph Evaluation 29 | f1radgraph = F1RadGraph(reward_level=args.radgraph_level, model_type="radgraph") 30 | score, _, _, _ = f1radgraph(hyps=hyps, refs=refs) 31 | print("F1RadGraph:", score) 32 | 33 | # F1CheXbert Evaluation 34 | f1chexbert = F1CheXbert(device=args.device) 35 | _, _, _, class_report_5 = f1chexbert(hyps=hyps, refs=refs) 36 | print("F1CheXpert:", class_report_5["micro avg"]["f1-score"]) 37 | 38 | # ROUGE Evaluation 39 | rouge = Rouge() 40 | scores = rouge.get_scores(hyps, refs, avg=True) 41 | print("Rouge-L:", scores['rouge-l']['f']) 42 | 43 | # BLEU Evaluation 44 | bleu = evaluate.load("bleu") 45 | results = bleu.compute(predictions=hyps, references=[[ref] for ref in refs]) 46 | print("BLEU-4 score:", results['precisions'][3]) 47 | 48 | # BERTScore Evaluation 49 | bert_scorer = BERTScorer(model_type=args.bert_model, 50 | num_layers=5, 51 | batch_size=64, 52 | nthreads=4, 53 | all_layers=False, 54 | idf=False, 55 | device=args.device, 56 | lang='en', 57 | rescale_with_baseline=True, 58 | baseline_path=None) 59 | _, _, f = bert_scorer.score(cands=hyps, refs=refs) 60 | print("BERTScore:", f.mean().item()) 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--ref_path", type=str, required=True, help="Path to reference JSON file") 65 | parser.add_argument("--pred_path", type=str, required=True, help="Path to predicted JSON file") 66 | parser.add_argument("--device", type=str, default="cuda", help="Computation device (cpu or cuda)") 67 | parser.add_argument("--radgraph_level", type=str, default="partial", help="RadGraph reward level") 68 | parser.add_argument("--bert_model", type=str, 
default="distilbert-base-uncased", help="BERTScore model type") 69 | 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /src/retriever/ANCE/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer, T5ForConditionalGeneration, CLIPProcessor, T5Model,CLIPVisionModel 2 | from multi_model import MultiModal 3 | import numpy as np 4 | DEFAULT_IMAGE_PATCH_TOKEN = "" 5 | DEFAULT_IM_START_TOKEN = "" 6 | DEFAULT_IM_END_TOKEN = "" 7 | 8 | 9 | def load_model(args,device): 10 | clip_model_name=args.clip_model_name 11 | t5_model_name=args.t5_model_name 12 | t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name) 13 | t5_model = T5Model.from_pretrained(t5_model_name) 14 | t5_tokenizer.add_special_tokens( 15 | {"additional_special_tokens": [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]}) 16 | t5_model.resize_token_embeddings(len(t5_tokenizer)) 17 | image_processor = CLIPProcessor.from_pretrained(clip_model_name) 18 | model = MultiModal(clip_model_name, t5_model, t5_tokenizer) 19 | model = model.to(device) 20 | return t5_tokenizer, model, image_processor 21 | 22 | def get_img_patch_token_size(clip_model_name): 23 | clip_model = CLIPVisionModel.from_pretrained(clip_model_name) 24 | image_size=clip_model.config.image_size 25 | patch_size=clip_model.config.patch_size 26 | img_patch_token_size=int(image_size/patch_size)**2 27 | return img_patch_token_size 28 | 29 | def manhattan_distance(u, v): 30 | return np.sum(np.abs(u - v)) 31 | 32 | def exact_entity_token_if_rel_exists_reward( 33 | hypothesis_annotation_list, reference_annotation_list 34 | ): 35 | candidates = [] 36 | for annotation_list in [hypothesis_annotation_list, reference_annotation_list]: 37 | candidate = [] 38 | for entity in annotation_list.values(): 39 | if not entity["relations"]: 40 | candidate.append((entity["tokens"], entity["label"])) 41 | if entity["relations"]: 42 | candidate.append((entity["tokens"], entity["label"], True)) 43 | 44 | candidate = set(candidate) 45 | candidates.append(candidate) 46 | 47 | hypothesis_relation_token_list, reference_relation_token_list = candidates 48 | 49 | precision = ( 50 | sum( 51 | [ 52 | 1 53 | for x in hypothesis_relation_token_list 54 | if (x in reference_relation_token_list) 55 | ] 56 | ) 57 | / len(hypothesis_relation_token_list) 58 | if len(hypothesis_relation_token_list) > 0 59 | else 0.0 60 | ) 61 | recall = ( 62 | sum( 63 | [ 64 | 1 65 | for x in reference_relation_token_list 66 | if (x in hypothesis_relation_token_list) 67 | ] 68 | ) 69 | / len(reference_relation_token_list) 70 | if len(reference_relation_token_list) > 0 71 | else 0.0 72 | ) 73 | f1_score = ( 74 | (2 * precision * recall / (precision + recall)) 75 | if (precision + recall) > 0 76 | else 0.0 77 | ) 78 | 79 | return f1_score 80 | 81 | -------------------------------------------------------------------------------- /src/retriever/DPR/gen_hard_negatives.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import pickle 3 | import faiss 4 | import pytrec_eval 5 | import argparse 6 | import json 7 | import numpy as np 8 | import pandas as pd 9 | import numpy as np 10 | import pickle 11 | from IPython import embed 12 | import os 13 | from utils import exact_entity_token_if_rel_exists_reward 14 | from collections import defaultdict 15 | def chexbert_similarity(report,ret_report): 16 | 
report_label = report['label'] 17 | ret_report_label = ret_report['label'] 18 | # distance = manhattan_distance(report_label, ret_report_label) 19 | # # Calculate the similarity as the inverse of the distance plus one 20 | # # to prevent division by zero when the distance is zero. 21 | return sum(1 for true, pred in zip(report_label, ret_report_label) if true == pred)/len(report_label) 22 | 23 | 24 | def radgraph_similarity(report,ret_report): 25 | report_entities = report['entities'] 26 | ret_report_entities = ret_report['entities'] 27 | partial_reward = exact_entity_token_if_rel_exists_reward(ret_report_entities,report_entities) 28 | return partial_reward 29 | 30 | 31 | 32 | 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser("") 37 | parser.add_argument("--query_path") 38 | parser.add_argument("--corpus_path") 39 | parser.add_argument("--query_embed_path") 40 | parser.add_argument("--txt_embed_path") 41 | parser.add_argument("--chexbert_threshold",type=float,default=1) 42 | parser.add_argument("--radgraph_threshold",type=float,default=0.4) 43 | parser.add_argument("--result_path") 44 | parser.add_argument("--topN",type=int,default=100) 45 | parser.add_argument("--num_top_neg",type=int,default=2) 46 | 47 | 48 | args = parser.parse_args() 49 | faiss.omp_set_num_threads(16) 50 | 51 | 52 | 53 | with open(args.query_embed_path, 'rb') as fin: 54 | query_embeds = pickle.load(fin) 55 | query_embeds = np.array(query_embeds, np.float32) 56 | 57 | cpu_index = faiss.IndexFlatIP(query_embeds.shape[1]) 58 | 59 | 60 | print("load data from {}".format(args.txt_embed_path)) 61 | with open(args.txt_embed_path, 'rb') as fin: 62 | txt_embeds = pickle.load(fin) 63 | cpu_index.add(np.array(txt_embeds, np.float32)) 64 | 65 | 66 | with open(args.query_path,"r") as f: 67 | query_data = json.load(f) 68 | with open(args.corpus_path,"r") as f: 69 | corpus_data = json.load(f) 70 | 71 | 72 | D, I = cpu_index.search(query_embeds, args.topN) 73 | 74 | query_hard_negative_id = [] 75 | for qid, query_results in tqdm(enumerate(I)): 76 | cur_query_hard_negative_id = [] 77 | for ret_id in query_results: 78 | if ret_id != qid : 79 | if chexbert_similarity(query_data[qid],corpus_data[ret_id]) < args.chexbert_threshold and radgraph_similarity(query_data[qid],corpus_data[ret_id]) < args.radgraph_threshold: 80 | cur_query_hard_negative_id.append((chexbert_similarity(query_data[qid],corpus_data[ret_id])+radgraph_similarity(query_data[qid],corpus_data[ret_id]),ret_id)) 81 | cur_query_hard_negative_id.sort(key = lambda x:x[0]) 82 | query_hard_negative_id.append([ ret_idx for _,ret_idx in cur_query_hard_negative_id[:args.num_top_neg]]) 83 | del cpu_index 84 | 85 | print("Save file!") 86 | pickle.dump(query_hard_negative_id,open(args.result_path,'wb')) 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /src/retriever/DPR/data.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from PIL import Image 4 | import io 5 | import numpy as np 6 | import torch 7 | from PIL import ImageFile 8 | from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler 9 | import random 10 | from tqdm import tqdm 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | DEFAULT_IMAGE_PATCH_TOKEN = "" 14 | DEFAULT_IM_START_TOKEN = "" 15 | DEFAULT_IM_END_TOKEN = "" 16 | 17 | class MedDataset(Dataset): 18 | def __init__(self, preprocess_fn, tokenizer, data,data_pos,img_special_len=49,valid_data = None): 
19 | self.preprocess_fn = preprocess_fn 20 | self.tokenizer=tokenizer 21 | self.img_special_len=img_special_len 22 | self.data = data 23 | self.data_pos = data_pos 24 | self.valid_data = valid_data 25 | 26 | self.splited_data_pos_idx_pair = [] 27 | 28 | if self.valid_data != None: 29 | self.data_pos_idx_filtered = [] 30 | for idx,data_pos_instance in enumerate(self.data_pos): 31 | if len(data_pos_instance)!=0: 32 | self.data_pos_idx_filtered.append(idx) 33 | else: 34 | #Splitted training data 35 | self.splited_data_pos_idx_pair = [(qid, pos_id) for qid, pos_id_list in enumerate(self.data_pos) for pos_id in pos_id_list] 36 | 37 | 38 | 39 | 40 | 41 | def __len__(self): 42 | if self.valid_data!= None: 43 | return len( self.data_pos_idx_filtered ) 44 | else: 45 | return len(self.splited_data_pos_idx_pair ) 46 | 47 | 48 | def encode_img(self,img,report = None): 49 | img = self.preprocess_fn(images=Image.open(img), return_tensors="pt")["pixel_values"][0] 50 | if report != None: 51 | pre_token= DEFAULT_IM_START_TOKEN+" "+ DEFAULT_IMAGE_PATCH_TOKEN * self.img_special_len + DEFAULT_IM_END_TOKEN 52 | cap=pre_token+" "+report 53 | return {'pos_image': img, 'pos_report':cap} 54 | else: 55 | return {'image': img} 56 | 57 | def Collector(self, batch): 58 | query_image_inputs = [] 59 | pos_image_inputs = [] 60 | pos_report_inputs = [] 61 | 62 | processed_batch = {} 63 | for qid, example in enumerate(batch): 64 | query_image_inputs.append(example['image']) 65 | pos_image_inputs.append(example['pos_image']) 66 | pos_report_inputs.append(example['pos_report']) 67 | 68 | processed_batch['query_image_inputs'] = torch.stack(query_image_inputs, dim=0) 69 | processed_batch['pos_image_inputs'] = torch.stack(pos_image_inputs, dim=0) 70 | processed_batch['pos_report_inputs'] = self.tokenizer(pos_report_inputs, return_tensors='pt',padding=True,truncation=True) 71 | 72 | return processed_batch 73 | 74 | def __getitem__(self, index): 75 | if self.valid_data!=None: 76 | example = self.valid_data[self.data_pos_idx_filtered[index]] 77 | example_pos_id = random.choice(self.data_pos[self.data_pos_idx_filtered[index]]) 78 | else: 79 | example_id,example_pos_id = self.splited_data_pos_idx_pair[index] 80 | example = self.data[example_id] 81 | 82 | image = example['image'][0] 83 | instance = self.encode_img(image) 84 | instance_pos = self.encode_img(self.data[example_pos_id]['image'][0],self.data[example_pos_id]['finding']+" "+self.data[example_pos_id]['impression']) 85 | instance.update(instance_pos) 86 | 87 | return instance 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /src/retriever/ANCE/multi_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPModel, T5ForConditionalGeneration, T5Tokenizer, CLIPVisionModel 4 | import numpy as np 5 | 6 | class MultiModal(nn.Module): 7 | def __init__(self, clip_model_name, t5_model, t5_tokenizer): 8 | super(MultiModal, self).__init__() 9 | # vision model 10 | self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name) 11 | # language model 12 | self.t5_model = t5_model 13 | self.t5_tokenizer = t5_tokenizer 14 | # projector 15 | self.image_dims = self.clip_model.config.hidden_size 16 | self.text_dims = self.t5_model.config.hidden_size 17 | self.projector = nn.Linear(self.image_dims, self.text_dims) 18 | # logit_scale 19 | clip = CLIPModel.from_pretrained(clip_model_name) 20 | self.logit_scale = clip.logit_scale 21 | 
22 | def encode_images_only(self, images): 23 | 24 | image_embeddings = self.clip_model(images, output_hidden_states=True) 25 | # get the patch image representations (except the cls token) 26 | image_embeddings = image_embeddings.last_hidden_state[:, 1:, :] 27 | image_embeddings = self.projector(image_embeddings) 28 | return image_embeddings 29 | 30 | def get_text_inputs_embeds(self, text_inputs, device): 31 | input_ids = text_inputs["input_ids"].to(device) 32 | input_embeddings = self.t5_model.get_input_embeddings() 33 | text_inputs_embeds = input_embeddings(input_ids) 34 | return text_inputs_embeds 35 | 36 | def get_images_with_caption_inputs_embeds(self, images, img_caps, device): 37 | image_embeddings = self.encode_images_only(images) 38 | img_caps_input_embs = self.get_text_inputs_embeds(img_caps, device) 39 | img_special_token_size = image_embeddings.size(1) 40 | merge_input_embs = torch.cat((img_caps_input_embs[:, 0:1, :], image_embeddings, img_caps_input_embs[:, img_special_token_size+1:, :]), 41 | dim=1) 42 | return merge_input_embs 43 | 44 | def get_rep(self, inputs_embeds, input, device): 45 | if input == None: 46 | attention_mask = None 47 | else: 48 | attention_mask = input['attention_mask'].to(device) 49 | decoder_input_ids = torch.zeros((inputs_embeds.shape[0], 1), dtype=torch.long) 50 | decoder_input_ids = decoder_input_ids.to(device) 51 | outputs = self.t5_model( 52 | input_ids=None, 53 | attention_mask=attention_mask, 54 | inputs_embeds=inputs_embeds, 55 | decoder_input_ids=decoder_input_ids, 56 | return_dict=True 57 | ) 58 | hidden = outputs.last_hidden_state 59 | rep = hidden[:, 0, :] 60 | return rep, hidden 61 | 62 | def forward(self, images=None, text_inputs=None, device=None): 63 | if images != None and text_inputs != None: 64 | merge_embs = self.get_images_with_caption_inputs_embeds(images, text_inputs, device) 65 | merge_imgs_rep, _ = self.get_rep(merge_embs, text_inputs, device) 66 | return merge_imgs_rep 67 | elif images != None and text_inputs == None: 68 | image_embs = self.encode_images_only(images) 69 | image_rep, _ = self.get_rep(image_embs, text_inputs, device) 70 | return image_rep 71 | elif images == None and text_inputs != None: 72 | text_embs = self.get_text_inputs_embeds(text_inputs, device) 73 | text_rep, _ = self.get_rep(text_embs, text_inputs, device) 74 | return text_rep 75 | else: 76 | raise ValueError("the input is error! 
") 77 | -------------------------------------------------------------------------------- /src/retriever/DPR/multi_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPModel, T5ForConditionalGeneration, T5Tokenizer, CLIPVisionModel 4 | import numpy as np 5 | 6 | class MultiModal(nn.Module): 7 | def __init__(self, clip_model_name, t5_model, t5_tokenizer): 8 | super(MultiModal, self).__init__() 9 | # vision model 10 | self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name) 11 | # language model 12 | self.t5_model = t5_model 13 | self.t5_tokenizer = t5_tokenizer 14 | # projector 15 | self.image_dims = self.clip_model.config.hidden_size 16 | self.text_dims = self.t5_model.config.hidden_size 17 | self.projector = nn.Linear(self.image_dims, self.text_dims) 18 | # logit_scale 19 | clip = CLIPModel.from_pretrained(clip_model_name) 20 | self.logit_scale = clip.logit_scale 21 | 22 | def encode_images_only(self, images): 23 | 24 | image_embeddings = self.clip_model(images, output_hidden_states=True) 25 | # get the patch image representations (except the cls token) 26 | image_embeddings = image_embeddings.last_hidden_state[:, 1:, :] 27 | image_embeddings = self.projector(image_embeddings) 28 | return image_embeddings 29 | 30 | def get_text_inputs_embeds(self, text_inputs, device): 31 | input_ids = text_inputs["input_ids"].to(device) 32 | input_embeddings = self.t5_model.get_input_embeddings() 33 | text_inputs_embeds = input_embeddings(input_ids) 34 | return text_inputs_embeds 35 | 36 | def get_images_with_caption_inputs_embeds(self, images, img_caps, device): 37 | image_embeddings = self.encode_images_only(images) 38 | img_caps_input_embs = self.get_text_inputs_embeds(img_caps, device) 39 | img_special_token_size = image_embeddings.size(1) 40 | merge_input_embs = torch.cat((img_caps_input_embs[:, 0:1, :], image_embeddings, img_caps_input_embs[:, img_special_token_size+1:, :]), 41 | dim=1) 42 | return merge_input_embs 43 | 44 | def get_rep(self, inputs_embeds, input, device): 45 | if input == None: 46 | attention_mask = None 47 | else: 48 | attention_mask = input['attention_mask'].to(device) 49 | decoder_input_ids = torch.zeros((inputs_embeds.shape[0], 1), dtype=torch.long) 50 | decoder_input_ids = decoder_input_ids.to(device) 51 | outputs = self.t5_model( 52 | input_ids=None, 53 | attention_mask=attention_mask, 54 | inputs_embeds=inputs_embeds, 55 | decoder_input_ids=decoder_input_ids, 56 | return_dict=True 57 | ) 58 | hidden = outputs.last_hidden_state 59 | rep = hidden[:, 0, :] 60 | return rep, hidden 61 | 62 | def forward(self, images=None, text_inputs=None, device=None): 63 | if images != None and text_inputs != None: 64 | merge_embs = self.get_images_with_caption_inputs_embeds(images, text_inputs, device) 65 | merge_imgs_rep, _ = self.get_rep(merge_embs, text_inputs, device) 66 | return merge_imgs_rep 67 | elif images != None and text_inputs == None: 68 | image_embs = self.encode_images_only(images) 69 | image_rep, _ = self.get_rep(image_embs, text_inputs, device) 70 | return image_rep 71 | elif images == None and text_inputs != None: 72 | text_embs = self.get_text_inputs_embeds(text_inputs, device) 73 | text_rep, _ = self.get_rep(text_embs, text_inputs, device) 74 | return text_rep 75 | else: 76 | raise ValueError("the input is error! 
") 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | # FactMM-RAG 177 | data/rag/* 178 | data/mimic/* 179 | data/ANCE/* 180 | LLaVA/ 181 | 182 | ANCE/ 183 | cmd.sh 184 | src/generator/gitignore/ 185 | .vscode/ -------------------------------------------------------------------------------- /src/retriever/ANCE/data.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from PIL import Image 4 | import io 5 | import numpy as np 6 | import torch 7 | from PIL import ImageFile 8 | from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler 9 | import random 10 | from tqdm import tqdm 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | DEFAULT_IMAGE_PATCH_TOKEN = "" 14 | DEFAULT_IM_START_TOKEN = "" 15 | DEFAULT_IM_END_TOKEN = "" 16 | 17 | class MedDataset(Dataset): 18 | def __init__(self, preprocess_fn, tokenizer, data,data_pos,data_neg,img_special_len=49,valid_data = None): 19 | self.preprocess_fn = preprocess_fn 20 | self.tokenizer=tokenizer 21 | self.img_special_len=img_special_len 22 | self.data = data 23 | self.data_pos = data_pos 24 | self.data_neg = data_neg 25 | self.valid_data = valid_data 26 | 27 | self.splited_data_pos_idx_pair = [] 28 | 29 | if self.valid_data != None: 30 | self.data_pos_idx_filtered = [] 31 | for idx,data_pos_instance in enumerate(self.data_pos): 32 | if len(data_pos_instance)!=0: 33 | self.data_pos_idx_filtered.append(idx) 34 | else: 35 | #Splitted training data 36 | self.splited_data_pos_idx_pair = [(qid, pos_id) for qid, pos_id_list in enumerate(self.data_pos) for pos_id in pos_id_list] 37 | 38 | 39 | 40 | 41 | 42 | def __len__(self): 43 | if self.valid_data!= None: 44 | return len( self.data_pos_idx_filtered ) 45 | else: 46 | return len(self.splited_data_pos_idx_pair ) 47 | 48 | 49 | def encode_img(self,img,report = None,neg = False): 50 | img = self.preprocess_fn(images=Image.open(img), return_tensors="pt")["pixel_values"][0] 51 | if report != None: 52 | pre_token= DEFAULT_IM_START_TOKEN+" "+ DEFAULT_IMAGE_PATCH_TOKEN * self.img_special_len + DEFAULT_IM_END_TOKEN 53 | cap=pre_token+" "+report 54 | if not neg: 55 | return {'pos_image': img, 'pos_report':cap} 56 | else: 57 | return {'neg_image': img, 'neg_report':cap} 58 | else: 59 | return {'image': img} 60 | 61 | 62 | 63 | def Collector(self, batch): 
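        # Collates one contrastive training batch: stacks the query images, flattens each
        # query's positive (image, report) pair followed by its mined hard negatives, and
        # records in `targets` the offset of each query's positive within that flattened list.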
64 | query_image_inputs = [] 65 | pos_neg_image_inputs = [] 66 | pos_neg_report_inputs = [] 67 | 68 | 69 | processed_batch = {} 70 | label_offset = 0 71 | label_list = [] 72 | for qid, example in enumerate(batch): 73 | query_image_inputs.append(example['image']) 74 | pos_neg_image_inputs.append(example['pos_image']) 75 | pos_neg_report_inputs.append(example['pos_report']) 76 | for example_neg in example["neg_inputs"]: 77 | pos_neg_image_inputs.append(example_neg['neg_image']) 78 | pos_neg_report_inputs.append(example_neg['neg_report']) 79 | label_list.append(label_offset) 80 | label_offset += 1+len(example["neg_inputs"]) 81 | 82 | assert len(pos_neg_image_inputs) == len(pos_neg_report_inputs) 83 | 84 | processed_batch['query_image_inputs'] = torch.stack(query_image_inputs, dim=0) 85 | processed_batch['pos_neg_image_inputs'] = torch.stack(pos_neg_image_inputs, dim=0) 86 | processed_batch['pos_neg_report_inputs'] = self.tokenizer(pos_neg_report_inputs, return_tensors='pt',padding=True,truncation=True) 87 | processed_batch['targets'] = torch.tensor(label_list,dtype=torch.long) 88 | 89 | return processed_batch 90 | 91 | def __getitem__(self, index): 92 | if self.valid_data!=None: 93 | example_id = self.data_pos_idx_filtered[index] 94 | example_pos_id = random.choice(self.data_pos[example_id]) 95 | example = self.valid_data[example_id] 96 | else: 97 | example_id,example_pos_id = self.splited_data_pos_idx_pair[index] 98 | example = self.data[example_id] 99 | 100 | example_neg_id_list = self.data_neg[example_id] 101 | # random.shuffle(example_neg_id_list) 102 | # example_neg_id_list = example_neg_id_list[:3] 103 | 104 | image = example['image'][0] 105 | instance = self.encode_img(image) 106 | instance_pos = self.encode_img(self.data[example_pos_id]['image'][0],self.data[example_pos_id]['finding']+" "+self.data[example_pos_id]['impression']) 107 | instance.update(instance_pos) 108 | instance['neg_inputs'] = [] 109 | 110 | for example_neg_id in example_neg_id_list: 111 | instance['neg_inputs'].append(self.encode_img(self.data[example_neg_id]['image'][0],self.data[example_neg_id]['finding']+" "+self.data[example_neg_id]['impression'],True)) 112 | 113 | return instance 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_valid/gen_topk_pos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.ma as ma 3 | import argparse 4 | import json 5 | import math 6 | import tqdm 7 | import os 8 | import time 9 | import pickle 10 | 11 | # Does not support removing self 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--from_folder", type=str, required=True) 14 | parser.add_argument("--do_chex", action="store_true") 15 | parser.add_argument("--do_radg", action="store_true") 16 | parser.add_argument("--skip_bad_sample", action="store_true", 17 | help="Activate to return empty sample list, in case radgraph score with self is 0 (usually indicates bad data sample)") 18 | parser.add_argument("--n", type=int, required=True) 19 | parser.add_argument("--num_chunks", type=int, required=True) 20 | parser.add_argument("--output_folder", type=str, required=True) 21 | parser.add_argument("--pre_mask", action="store_true", 22 | help="Filter all rows by chexbert or radgraph value, before doing top-k operation") 23 | parser.add_argument("--pre_mask_chex", type=float, 24 | help="Lower bound for chexbert filtering", default=0) 25 | 
parser.add_argument("--pre_mask_radg", type=float, 26 | help="Lower bound for radgraph filtering", default=0) 27 | parser.add_argument("--top_k", type=int) 28 | np.random.seed(42) 29 | 30 | args = parser.parse_args() 31 | print(vars(args), flush=True) 32 | if not args.do_chex and not args.do_radg: 33 | raise Exception("Must have at least radgraph or chexbert mode") 34 | 35 | k = args.top_k 36 | positive_list = [] 37 | statistics = { 38 | "Self is retrieved": 0, 39 | "A score of 2 is achieved": 0, 40 | "Diagonal Radgraph is not 1": 0, 41 | "DropSelf - didn't use self": 0, 42 | "DropSelf - used self": 0 43 | } 44 | bad_list = [] 45 | 46 | n = args.n 47 | chunk_size = (n + args.num_chunks - 1) // args.num_chunks 48 | processed = 0 49 | 50 | for chunk_id in range(args.num_chunks): 51 | start = chunk_size * chunk_id 52 | end = min(chunk_size * (chunk_id + 1), n) 53 | 54 | # Obtain Statistics 55 | tensor = None 56 | chexbert = None 57 | radgraph = None 58 | if args.do_chex: 59 | chexbert = np.load(os.path.join(args.from_folder, f"chex_{chunk_id}.npy")) 60 | tensor = chexbert 61 | if chunk_id == 0: 62 | print("Loading Chexbert Tensor", flush=True) 63 | if args.do_radg: 64 | radgraph = np.load(os.path.join( 65 | args.from_folder, f"radg_{chunk_id}.npy")) 66 | if tensor is None: 67 | tensor = radgraph 68 | if chunk_id == 0: 69 | print("Loading Radgraph Tensor", flush=True) 70 | else: 71 | assert tensor.shape == radgraph.shape 72 | tensor = tensor + radgraph 73 | if chunk_id == 0: 74 | print("Adding Radgraph Tensor", flush=True) 75 | 76 | # Masking Operation before top-k 77 | ind = None # index of top-k elements 78 | ind_mask = None # if pre_mask, then this defines a mask of valid elements 79 | if not args.pre_mask: 80 | ind = np.argpartition(tensor, -k, axis=1)[:, -k:] 81 | ind_mask = np.ones_like(ind) 82 | else: 83 | ind = [] 84 | assert args.do_chex 85 | assert args.do_radg 86 | chex_mask = chexbert >= args.pre_mask_chex 87 | radg_mask = radgraph >= args.pre_mask_radg 88 | mask = np.logical_and(chex_mask, radg_mask) 89 | # set things that don't satisfy conditions to -1 90 | masked_tensor = np.where(mask, tensor, -1) 91 | # things that don't satisfy filter will be the smallest 92 | ind = np.argpartition(masked_tensor, -k, axis=1)[:, -k:] 93 | # get a mask of all indices that yield a (radg + chex) != -1 94 | ind_mask = np.take_along_axis(masked_tensor, ind, axis=1) != -1 95 | 96 | # Filtering Rows 97 | for row_idx, (row, row_mask) in enumerate(zip(ind, ind_mask)): 98 | if start + row_idx in row: 99 | statistics["Self is retrieved"] += 1 100 | if np.allclose(np.max(tensor[row_idx, row]), 2): 101 | statistics["A score of 2 is achieved"] += 1 102 | # if diagonal radgraph is not zero, it's invalid. 
Skip sample 103 | if args.skip_bad_sample and \ 104 | radgraph is not None and \ 105 | not np.allclose(radgraph[row_idx][start + row_idx], 1): 106 | statistics["Diagonal Radgraph is not 1"] += 1 107 | bad_list.append(start + row_idx) 108 | positive_list.append([]) 109 | continue 110 | # extract only the elements that are allowed by indicies mask 111 | extracted_row = row[np.nonzero(row_mask)[0]] 112 | positive_list.append(extracted_row.tolist()) 113 | processed += len(ind) 114 | 115 | statistics["n_set"] = n 116 | statistics["n_actual"] = processed 117 | print(json.dumps(statistics, indent=2), flush=True) 118 | print(f"Bad List: {bad_list}", flush=True) 119 | 120 | os.makedirs(args.output_folder, exist_ok=True) 121 | print(f"Saving to: {args.output_folder}", flush=True) 122 | 123 | mapped = { 124 | "statistics": statistics, 125 | "bad_list": bad_list 126 | } 127 | with open(os.path.join(args.output_folder, "positive_list.pkl"), "wb") as f: 128 | pickle.dump(positive_list, f) 129 | with open(os.path.join(args.output_folder, "extraneous_data.pkl"), "wb") as f: 130 | pickle.dump(mapped, f) -------------------------------------------------------------------------------- /src/generator/build_rag_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import json 5 | import argparse 6 | import sys 7 | import tqdm 8 | 9 | parser = argparse.ArgumentParser( 10 | description="Takes in an index file, and produces a llava-formatted dataset.") 11 | # pickle file of embeddings (output from marvel ance). obj[1] should be the array of embeddings 12 | parser.add_argument('--faiss_knn_path', type=str) 13 | parser.add_argument('--queries_data_path', type=str) 14 | parser.add_argument('--corpus_data_path', type=str) 15 | parser.add_argument('--rag_data_mode', type=str, default="finding", 16 | help="the data type that is plugged into the input's prompt") 17 | parser.add_argument('--output_data_mode', type=str, default="finding", 18 | help="the data type that is used for ground truth") 19 | parser.add_argument('--test_short', action="store_true", 20 | help="Skip KNN on things that are short (< 5 words)") 21 | parser.add_argument('--is_conversational', action="store_true", 22 | help="""Set to true if used for llava training / evaluation, Set to False if used for inference""") 23 | parser.add_argument('--output_path', type=str) 24 | args = parser.parse_args() 25 | print(f"{vars(args)=}") 26 | 27 | faiss_knn_path = args.faiss_knn_path 28 | queries_data_path, corpus_data_path = args.queries_data_path, args.corpus_data_path 29 | rag_data_mode, output_data_mode = args.rag_data_mode, args.output_data_mode 30 | test_short = args.test_short 31 | is_conversational = args.is_conversational 32 | output_path = args.output_path 33 | 34 | with \ 35 | open(queries_data_path, "r") as f_que, \ 36 | open(corpus_data_path, "r") as f_cor, \ 37 | open(faiss_knn_path, "rb") as f_knn: 38 | corpus_data = json.load(f_cor) 39 | query_data = json.load(f_que) 40 | faiss_knn_data = pickle.load(f_knn) 41 | print(f"{len(corpus_data)=} {len(query_data)=}") 42 | sys.stdout.flush() 43 | 44 | 45 | def extract_paths(image_path): 46 | patient, study, file = image_path.split("/")[-3:] 47 | return patient, study, file 48 | 49 | def indexes(data_json): 50 | patient_index, study_index, _ = zip(*[extract_paths(v["image"]) for v in data_json]) 51 | return patient_index, study_index 52 | 53 | # These indexes are arrays, where arr[i] contains the patient id of 
patient_i 54 | query_patient_index, query_study_index = indexes(query_data) 55 | corpus_patient_index, corpus_study_index = indexes(corpus_data) 56 | 57 | 58 | output = [] 59 | 60 | # Track how many times filters fail for top-1 result 61 | ct_self_study = 0 62 | ct_self_patient = 0 63 | ct_short = 0 64 | anomalies = [] 65 | 66 | for i, knn_sample_i in tqdm.tqdm(enumerate(faiss_knn_data), desc="queries..."): 67 | query_idx = knn_sample_i["key"] # idx of query 68 | query_knn_rankings = knn_sample_i["knn_index"] 69 | query_data_i = query_data[query_idx] # correspondent sample. Dict with keys "image", "finding", "impression" 70 | query_data_image_path = query_data_i["image"] 71 | patient_id, study_id, _ = extract_paths(query_data_image_path) 72 | 73 | chosen_document_idx = None 74 | for rank, document_idx in enumerate(query_knn_rankings): 75 | # make sure aren't self-study retrieving 76 | if corpus_study_index[document_idx] == study_id: 77 | # these things just check how often the nearest-neighbor fails a certain filter 78 | ct_self_study += rank == 0 79 | 80 | # make sure aren't self-patient retrieving 81 | elif corpus_patient_index[document_idx] == patient_id: 82 | ct_self_patient += rank == 0 83 | 84 | # findings sometimes have corrupted data, e.x "a.m." or "___" 85 | # test_short shouldn't be active for impression data mode, since impressions can be short normally 86 | elif test_short and len(corpus_data[document_idx][rag_data_mode].split()) < 5: 87 | ct_short += rank == 0 88 | 89 | # if all filters pass, yield the index 90 | else: 91 | chosen_document_idx = document_idx 92 | break 93 | 94 | # Choose first document if none pass filters 95 | if chosen_document_idx is None: 96 | anomalies.append(query_idx) 97 | chosen_document_idx = query_knn_rankings[0] 98 | 99 | retrieved_doc = corpus_data[chosen_document_idx][rag_data_mode] 100 | if is_conversational: 101 | obj = { 102 | "id": i, 103 | "image": query_data_image_path, 104 | "conversations": [ 105 | { 106 | "from": "human", 107 | "value": f"Here is a report of a related patient: \"{retrieved_doc}\"\nGenerate a radiology report from this image:" 108 | }, 109 | { 110 | "from": "gpt", 111 | "value": f"{query_data_i[output_data_mode]}" 112 | } 113 | ] 114 | } 115 | else: 116 | obj = { 117 | "question_id": i, 118 | "image": query_data_image_path, 119 | "text": f"Here is a report of a related patient: \"{retrieved_doc}\"\nGenerate a radiology report from this image:" 120 | } 121 | output.append(obj) 122 | 123 | # Anomalies are data samples that weren't able to yield valid results 124 | print(f"Data Quality: ") 125 | print(f"{ct_self_study=}, {ct_self_patient=}, {ct_short=}") 126 | print(f"n={len(anomalies)} anomalies, occuring at query id's: {anomalies}") 127 | 128 | 129 | print(f"Saving to: {output_path}") 130 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 131 | with open(os.path.join(output_path), "w") as f: 132 | json.dump(output, f, indent=2) 133 | -------------------------------------------------------------------------------- /src/generator/knn.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import pickle 3 | import argparse 4 | import numpy as np 5 | import json 6 | import sys 7 | import os 8 | 9 | faiss.omp_set_num_threads(16) 10 | 11 | parser = argparse.ArgumentParser( 12 | description=""" 13 | Constructs a knn-index file. For every query embedding, 14 | obtains the related knn indices from the training corpus. 
15 | """) 16 | # pickle file of embeddings (output from marvel ance). obj[1] should be the array of embeddings 17 | parser.add_argument('--query_embedding_file', type=str, 18 | help=""" 19 | A pickle file containing embeddings. In format 20 | { 21 | idx: { 22 | "image_embedding": nd.array(d), 23 | "finding_embedding": nd.array(d), 24 | "impression_embedding": nd.array(d) 25 | } 26 | } 27 | 28 | Also can be a tensor of size: (n, d) 29 | """) 30 | parser.add_argument('--corpus_embedding_file', type=str) 31 | parser.add_argument('--query_data_file', type=str, 32 | help=""" 33 | A path to a json file containing data in format, likely train.json: 34 | [ 35 | { 36 | "image": str, 37 | "finding": str, 38 | "impression": str 39 | } 40 | ]""", default="./data/mimic/train.json") 41 | parser.add_argument('--corpus_data_file', type=str, 42 | help="The json path to corpus data (probably train.json)", default="./data/mimic/test.json") 43 | parser.add_argument("--query_key", type=str, default=None, 44 | help="The key to use to lookup in query_embedding_file. Ignored if query is a numpy array") 45 | parser.add_argument("--corpus_key", type=str, default=None, 46 | help="The key to use to lookup in corpus_embedding_file. Ignored if query is a numpy array") 47 | parser.add_argument('--nlist', type=int, default=100, 48 | help="faiss nlist value. If -1, will do dense indexing") 49 | parser.add_argument("--nprobe", type=int, default=10, help="faiss nprobe value") 50 | parser.add_argument("--k", type=int, default=20, 51 | help="faiss 'k'-nearest-neighbor value") 52 | parser.add_argument("--results_k", type=int, default=5, 53 | help="how many (k) text results (e.x. finding / impression) from corpus to save") 54 | parser.add_argument("--data_type", type=str, default=None, 55 | help="the type of data (e.x. query's finding) that is output in the resulting file") 56 | parser.add_argument("--output_path", type=str, 57 | help="output json file path", required=True) 58 | 59 | 60 | args = parser.parse_args() 61 | 62 | print("args", vars(args)) 63 | 64 | 65 | def process_pickle(loaded_object, key_type=None): 66 | if type(loaded_object) == np.ndarray: 67 | print(f"Loaded an ndarray of shape: {loaded_object.shape}") 68 | keys = list(range(len(loaded_object))) 69 | embeddings = loaded_object 70 | else: 71 | print(f"Assuming loaded a dictionary") 72 | assert key_type 73 | keys, embeddings = zip( 74 | *[(i, v[key_type]) for i, v in loaded_object.items()]) 75 | keys, embeddings = list(keys), np.vstack(embeddings) 76 | return keys, embeddings 77 | 78 | 79 | with open(args.query_embedding_file, "rb") as f_query_embedding_file, \ 80 | open(args.corpus_embedding_file, "rb") as f_corpus_embedding_file, \ 81 | open(args.query_data_file, "r") as f_query_data_file, \ 82 | open(args.corpus_data_file, "r") as f_corpus_data_file: 83 | query_embeddings = pickle.load(f_query_embedding_file) 84 | corpus_embeddings = pickle.load(f_corpus_embedding_file) 85 | corpus_keys, corpus_embeddings = process_pickle( 86 | corpus_embeddings, args.corpus_key) 87 | query_keys, query_embeddings = process_pickle( 88 | query_embeddings, args.query_key) 89 | query_data = json.load(f_query_data_file) 90 | corpus_data = json.load(f_corpus_data_file) 91 | 92 | print( 93 | f"Query Keys Sequential({len(query_keys)})? {query_keys == list(range(len(query_keys)))}") 94 | print( 95 | f"Corpus Keys Sequential({len(corpus_keys)})? {corpus_keys == list(range(len(corpus_keys)))}") 96 | print( 97 | f"Query Embedding Keys match range? 
{set(query_keys) == set(range(len(query_keys)))}") 98 | print( 99 | f"Corpus Embedding Keys match range? {set(corpus_keys) == set(range(len(corpus_keys)))}") 100 | print(f"{query_embeddings.shape=} {corpus_embeddings.shape=}") 101 | 102 | d = query_embeddings.shape[1] 103 | if args.nlist == -1: 104 | print(f"Using Exhaustive Search, {args.k=}") 105 | cpu_index = faiss.IndexFlatIP(d) 106 | cpu_index.add(corpus_embeddings) 107 | D, I = cpu_index.search(query_embeddings, args.k) 108 | else: 109 | print(f"Using IVFFlat, {args.nlist=}, {args.nprobe=}, {args.k=}") 110 | quantizer = faiss.IndexFlatIP(d) 111 | cpu_index = faiss.IndexIVFFlat(quantizer, d, args.nlist) 112 | cpu_index.nprobe = args.nprobe 113 | cpu_index.train(corpus_embeddings) 114 | cpu_index.add(corpus_embeddings) 115 | D, I = cpu_index.search(query_embeddings, args.k) 116 | 117 | output = [] 118 | for query_number, query_key in enumerate(query_keys): 119 | obj = { 120 | "key": query_key, 121 | "knn_index": I[query_number], 122 | "similarities": D[query_number], 123 | } 124 | if args.data_type: 125 | obj[args.data_type] = query_data[query_key][args.data_type], 126 | if args.results_k > 0 and args.data_type: 127 | qty = min(args.k, args.results_k) 128 | obj["results"] = [corpus_data[corpus_keys[corpus_index]][args.data_type] 129 | for corpus_index in I[query_number][:qty]] 130 | output.append(obj) 131 | 132 | print(f"Saving to: {args.output_path}. Items: {len(output)}") 133 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 134 | with open(args.output_path, "wb") as f: 135 | pickle.dump(output, f) 136 | -------------------------------------------------------------------------------- /data/factual_mining/build_pos_train/gen_topk_pos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.ma as ma 3 | import argparse 4 | import json 5 | import math 6 | import tqdm 7 | import os 8 | import time 9 | import pickle 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--from_folder", type=str, required=True) 13 | parser.add_argument("--do_chex", action="store_true") 14 | parser.add_argument("--do_radg", action="store_true") 15 | parser.add_argument("--ignore_self", action="store_true") 16 | parser.add_argument( 17 | "--skip_bad_sample", 18 | action="store_true", 19 | help="Activate to return empty sample list, in case radgraph score with self is 0 (usually indicates bad data sample)", 20 | ) 21 | parser.add_argument("--num_chunks", type=int, required=True) 22 | parser.add_argument("--n", type=int, required=True) 23 | parser.add_argument("--chunk_id", type=int, required=True) 24 | parser.add_argument("--output_file", type=str, required=True) 25 | parser.add_argument( 26 | "--pre_mask", 27 | action="store_true", 28 | help="Filter all rows by chexbert or radgraph value, before doing top-k operation", 29 | ) 30 | parser.add_argument( 31 | "--pre_mask_chex", type=float, help="Lower bound for chexbert filtering", default=0 32 | ) 33 | parser.add_argument( 34 | "--pre_mask_radg", type=float, help="Lower bound for radgraph filtering", default=0 35 | ) 36 | parser.add_argument("--top_k", type=int) 37 | np.random.seed(42) 38 | args = parser.parse_args() 39 | print(vars(args), flush=True) 40 | if not args.do_chex and not args.do_radg: 41 | raise Exception("Must have at least radgraph or chexbert mode") 42 | 43 | print(args.output_file) 44 | 45 | k = args.top_k 46 | chunk_id = args.chunk_id 47 | positive_list = [] 48 | statistics = { 49 | "Self is 
retrieved": 0, 50 | "A score of 2 is achieved": 0, 51 | "Diagonal Radgraph is not 1": 0, 52 | "DropSelf - didn't use self": 0, 53 | "DropSelf - used self": 0, 54 | } 55 | bad_list = [] 56 | 57 | n = args.n 58 | chunk_size = (n + args.num_chunks - 1) // args.num_chunks 59 | start = chunk_size * args.chunk_id 60 | end = min(chunk_size * (args.chunk_id + 1), n) 61 | print(f"Chunk: [{start=} -> {end=}] / [0:{n-1}]") 62 | 63 | # Obtain Statistics 64 | tensor = None 65 | chexbert = None 66 | radgraph = None 67 | if args.do_chex: 68 | chexbert = np.load(os.path.join(args.from_folder, f"chex_{chunk_id}.npy")) 69 | tensor = chexbert 70 | if chunk_id == 0: 71 | print("Loading Chexbert Tensor", flush=True) 72 | if args.do_radg: 73 | radgraph = np.load(os.path.join(args.from_folder, f"radg_{chunk_id}.npy")) 74 | if tensor is None: 75 | tensor = radgraph 76 | if chunk_id == 0: 77 | print("Loading Radgraph Tensor", flush=True) 78 | else: 79 | assert tensor.shape == radgraph.shape 80 | tensor = tensor + radgraph 81 | if chunk_id == 0: 82 | print("Adding Radgraph Tensor", flush=True) 83 | 84 | # Masking Operation before top-k 85 | # In case of ignore-self and we retrieve self, we want an extra item to take 86 | actual_k = k + 1 if args.ignore_self else k 87 | ind = [] # index of top-k elements 88 | ind_mask = None # if pre_mask, then this defines a mask of valid elements 89 | assert args.do_chex 90 | assert args.do_radg 91 | chex_mask = chexbert >= args.pre_mask_chex 92 | radg_mask = radgraph >= args.pre_mask_radg 93 | mask = np.logical_and(chex_mask, radg_mask) 94 | # set things that don't satisfy conditions to -1 95 | masked_tensor = np.where(mask, tensor, -1) 96 | # things that don't satisfy filter will be the smallest 97 | ind = np.argpartition(masked_tensor, -actual_k, axis=1)[:, -actual_k:] 98 | # get a mask of all indices that yield a (radg + chex) != -1 99 | ind_mask = np.take_along_axis(masked_tensor, ind, axis=1) != -1 100 | 101 | # Filtering Rows 102 | for row_idx, (row, row_mask) in enumerate(zip(ind, ind_mask)): 103 | if start + row_idx in row: 104 | statistics["Self is retrieved"] += 1 105 | if np.allclose(np.max(tensor[row_idx, row]), 2): 106 | statistics["A score of 2 is achieved"] += 1 107 | # if diagonal radgraph is not zero, it's invalid (findings is meaningless). 
Skip sample 108 | if ( 109 | args.skip_bad_sample 110 | and radgraph is not None 111 | and not np.allclose(radgraph[row_idx][start + row_idx], 1) 112 | ): 113 | statistics["Diagonal Radgraph is not 1"] += 1 114 | bad_list.append(start + row_idx) 115 | positive_list.append([]) 116 | continue 117 | if args.ignore_self: 118 | remaining_row_items = None 119 | remaining_row_mask = None 120 | if start + row_idx in row: 121 | # drop self -> k items 122 | idx_self = np.where(row == start + row_idx)[0][0] 123 | remaining_row_items = np.delete(row, idx_self) 124 | remaining_row_mask = np.delete(row_mask, idx_self) 125 | statistics["DropSelf - used self"] += 1 126 | else: 127 | # just choose the highest ones (right of partition, always >=) 128 | remaining_row_items = row[1:] 129 | remaining_row_mask = row_mask[1:] 130 | statistics["DropSelf - didn't use self"] += 1 131 | assert len(remaining_row_items) == k 132 | assert len(remaining_row_mask) == k 133 | extracted_row = remaining_row_items[np.nonzero(remaining_row_mask)[0]] 134 | positive_list.append(extracted_row.tolist()) 135 | else: 136 | # extract only the elements that are allowed by indicies mask 137 | extracted_row = row[np.nonzero(row_mask)[0]] 138 | positive_list.append(extracted_row.tolist()) 139 | 140 | if args.ignore_self: 141 | print("Verifying if all samples don't contain self...", flush=True) 142 | assert all( 143 | [ 144 | not (row_idx + start in positive_list) 145 | for row_idx, positive_list in enumerate(positive_list) 146 | ] 147 | ) 148 | print( 149 | f"Min: {min(sum(positive_list, []))}, Max: {max(sum(positive_list, []))}", 150 | flush=True, 151 | ) 152 | 153 | statistics["n"] = len(ind) 154 | print(json.dumps(statistics, indent=2), flush=True) 155 | print(f"Bad List: {bad_list}", flush=True) 156 | 157 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 158 | print(f"Saving to: {args.output_file}", flush=True) 159 | 160 | mapped = { 161 | "statistics": statistics, 162 | "positive_list": positive_list, 163 | "bad_list": bad_list, 164 | } 165 | with open(args.output_file, "wb") as f: 166 | pickle.dump(mapped, f) 167 | -------------------------------------------------------------------------------- /src/retriever/DPR/gen_embeddings.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import argparse 6 | import os.path as op 7 | import time 8 | import pickle 9 | import os 10 | from PIL import Image 11 | import io 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler 14 | 15 | from IPython import embed 16 | from PIL import ImageFile 17 | import pyarrow as pa 18 | import pandas as pd 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | from utils import load_model,get_img_patch_token_size 21 | 22 | 23 | 24 | def gen_img_embeddings(model, valid_reader, outpath): 25 | model.eval() 26 | all_img_embeddings = [] 27 | 28 | for step, batch in tqdm(enumerate(valid_reader)): 29 | with torch.no_grad(): 30 | 31 | embeddings = model(batch['img_inputs'].cuda(), None, device) 32 | embeddings = F.normalize(embeddings, dim=-1).cpu() 33 | all_img_embeddings.append(embeddings) 34 | 35 | all_img_embeddings = torch.cat(all_img_embeddings, dim=0).numpy() 36 | with open(outpath, 'wb') as fout: 37 | pickle.dump(all_img_embeddings, fout) 38 | 39 | def gen_txt_embeddings(model, valid_reader, findings_outpath): 40 | model.eval() 41 | all_findings_embeddings = [] 42 | for step, batch in 
tqdm(enumerate(valid_reader)): 43 | with torch.no_grad(): 44 | findings_embeddings = model(None, batch["findings_inputs"], device) 45 | findings_embeddings = F.normalize(findings_embeddings, dim=-1).cpu() 46 | all_findings_embeddings.append(findings_embeddings) 47 | 48 | all_findings_embeddings = torch.cat(all_findings_embeddings, dim=0).numpy() 49 | 50 | with open(findings_outpath, 'wb') as fout: 51 | pickle.dump(all_findings_embeddings, fout) 52 | 53 | 54 | class MimicImgDataset(Dataset): 55 | def __init__(self, args, img_path, preprocess, tokenizer): 56 | 57 | self.preprocess_fn = preprocess 58 | self.tokenizer = tokenizer 59 | self.img_paths = [] 60 | 61 | with open(img_path, "r") as fin: 62 | images = json.load(fin) 63 | for i in range(len(images)): 64 | self.img_paths.append(images[i]['image'][0]) 65 | 66 | def __len__(self): 67 | return len(self.img_paths) 68 | 69 | def encode_img(self, img, idx): 70 | img = self.preprocess_fn(images=Image.open(img), return_tensors="pt")["pixel_values"][0] 71 | return {'img': img} 72 | 73 | def Collector(self, batch): 74 | img_inputs = [] 75 | 76 | for example in batch: 77 | img_inputs.append(example['img_inputs']) 78 | 79 | processed_batch = {} 80 | processed_batch['img_inputs'] = torch.stack(img_inputs, dim=0) 81 | 82 | return processed_batch 83 | 84 | def __getitem__(self, index): 85 | img_inputs = self.encode_img(self.img_paths[index], index) 86 | instance = { 87 | 'img_inputs': img_inputs['img'] 88 | } 89 | 90 | 91 | return instance 92 | 93 | 94 | class MimicTxtDataset(Dataset): 95 | def __init__(self, data_path, tokenizer): 96 | self.tokenizer = tokenizer 97 | self.findings = [] 98 | 99 | with open(data_path, "r") as fin: 100 | text_data = json.load(fin) 101 | for instance in text_data: 102 | self.findings.append(instance["finding"]) 103 | 104 | 105 | def __len__(self): 106 | return len(self.findings) 107 | 108 | def Collector(self, batch): 109 | processed_batch = { 110 | 'findings_inputs': self.tokenizer(batch, return_tensors='pt', 111 | padding=True, truncation=True) 112 | } 113 | return processed_batch 114 | 115 | def __getitem__(self, index): 116 | return self.findings[index] 117 | 118 | 119 | 120 | 121 | 122 | if __name__ == '__main__': 123 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 124 | parser = argparse.ArgumentParser("") 125 | parser.add_argument("--t5_model_name", type=str, default='OpenMatch/t5-ance') 126 | parser.add_argument("--clip_model_name",type=str,default='openai/clip-vit-base-patch32') 127 | parser.add_argument("--saved_ckpt",type=str,default='/FactMM-RAG/src/retriever/output/dpr.best.pt') 128 | parser.add_argument("--train_path",type=str,default='/FactMM-RAG/data/mimic/train.json') 129 | parser.add_argument("--valid_path",type=str,default='/FactMM-RAG/data/mimic/valid.json') 130 | parser.add_argument("--test_path",type=str,default='/FactMM-RAG/data/mimic/test.json') 131 | 132 | parser.add_argument("--output_train_image_path",type=str,default='/FactMM-RAG/DPR/embedding/train_embedding_image.pkl') 133 | parser.add_argument("--output_train_finding_path",type=str,default='/FactMM-RAG/DPR/embedding/train_embedding_finding.pkl') 134 | parser.add_argument("--output_valid_image_path",type=str,default='/FactMM-RAG/DPR/embedding/valid_embedding_image.pkl') 135 | parser.add_argument("--output_test_image_path",type=str,default='/FactMM-RAG/DPR/embedding/test_embedding_image.pkl') 136 | 137 | args = parser.parse_args() 138 | 139 | t5_tokenizer, model, image_processor = load_model(args,device) 140 |
model.load_state_dict(torch.load(args.saved_ckpt,map_location='cuda:0')['model'],strict =False) 141 | model.cuda() 142 | 143 | args.img_patch_token_size=get_img_patch_token_size(args.clip_model_name) 144 | train_path = args.train_path 145 | test_path = args.test_path 146 | valid_path = args.valid_path 147 | 148 | txt_data = MimicTxtDataset(train_path, t5_tokenizer) 149 | sampler = SequentialSampler(txt_data) 150 | txt_reader = DataLoader(dataset=txt_data, sampler=sampler, num_workers=10, 151 | batch_size=32, collate_fn=txt_data.Collector) 152 | gen_txt_embeddings(model, txt_reader, args.output_train_finding_path) 153 | 154 | img_data = MimicImgDataset(args, train_path, image_processor, t5_tokenizer) 155 | sampler = SequentialSampler(img_data) 156 | img_reader = DataLoader(dataset=img_data, sampler=sampler, num_workers=10, 157 | batch_size=32, collate_fn=img_data.Collector) 158 | gen_img_embeddings(model, img_reader, args.output_train_image_path) 159 | 160 | img_data = MimicImgDataset(args, valid_path, image_processor, t5_tokenizer) 161 | sampler = SequentialSampler(img_data) 162 | img_reader = DataLoader(dataset=img_data, sampler=sampler, num_workers=10, 163 | batch_size=32, collate_fn=img_data.Collector) 164 | gen_img_embeddings(model, img_reader, args.output_valid_image_path) 165 | 166 | 167 | 168 | img_data = MimicImgDataset(args, test_path, image_processor, t5_tokenizer) 169 | sampler = SequentialSampler(img_data) 170 | img_reader = DataLoader(dataset=img_data, sampler=sampler, num_workers=10, 171 | batch_size=32, collate_fn=img_data.Collector) 172 | gen_img_embeddings(model, img_reader, args.output_test_image_path) 173 | 174 | 175 | -------------------------------------------------------------------------------- /src/retriever/ANCE/gen_embeddings.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import argparse 6 | import os.path as op 7 | import time 8 | import pickle 9 | import os 10 | from PIL import Image 11 | import io 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler 14 | 15 | from IPython import embed 16 | from PIL import ImageFile 17 | import pyarrow as pa 18 | import pandas as pd 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | from utils import load_model,get_img_patch_token_size 21 | 22 | 23 | 24 | def gen_img_embeddings(model, valid_reader, outpath): 25 | model.eval() 26 | all_img_embeddings = [] 27 | 28 | for step, batch in tqdm(enumerate(valid_reader)): 29 | with torch.no_grad(): 30 | 31 | embeddings = model(batch['img_inputs'].cuda(), None, device) 32 | embeddings = F.normalize(embeddings, dim=-1).cpu() 33 | all_img_embeddings.append(embeddings) 34 | 35 | all_img_embeddings = torch.cat(all_img_embeddings, dim=0).numpy() 36 | with open(outpath, 'wb') as fout: 37 | pickle.dump(all_img_embeddings, fout) 38 | 39 | def gen_txt_embeddings(model, valid_reader, findings_outpath): 40 | model.eval() 41 | all_findings_embeddings = [] 42 | for step, batch in tqdm(enumerate(valid_reader)): 43 | with torch.no_grad(): 44 | findings_embeddings = model(None, batch["findings_inputs"], device) 45 | findings_embeddings = F.normalize(findings_embeddings, dim=-1).cpu() 46 | all_findings_embeddings.append(findings_embeddings) 47 | 48 | all_findings_embeddings = torch.cat(all_findings_embeddings, dim=0).numpy() 49 | 50 | with open(findings_outpath, 'wb') as fout: 51 | pickle.dump(all_findings_embeddings, fout) 52 | 53 | 
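# The two datasets below are used when generating embeddings with the trained retriever:
# MimicImgDataset loads each entry's image from a json split (optionally prefixed with
# --image_folder) and returns CLIP-preprocessed pixel values, while MimicTxtDataset collects
# each entry's "finding" text and tokenizes it with the T5 tokenizer.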
54 | class MimicImgDataset(Dataset): 55 | def __init__(self, args, img_path, preprocess, tokenizer, image_folder=""): 56 | 57 | self.preprocess_fn = preprocess 58 | self.tokenizer = tokenizer 59 | self.image_folder = image_folder 60 | self.img_paths = [] 61 | 62 | with open(img_path, "r") as fin: 63 | images = json.load(fin) 64 | for i in range(len(images)): 65 | self.img_paths.append(images[i]['image']) 66 | 67 | def __len__(self): 68 | return len(self.img_paths) 69 | 70 | def encode_img(self, img, idx): 71 | prefix = "" if self.image_folder == "" else self.image_folder + "/" 72 | img = self.preprocess_fn(images=Image.open(prefix + img), return_tensors="pt")["pixel_values"][0] 73 | return {'img': img} 74 | 75 | def Collector(self, batch): 76 | img_inputs = [] 77 | 78 | for example in batch: 79 | img_inputs.append(example['img_inputs']) 80 | 81 | processed_batch = {} 82 | processed_batch['img_inputs'] = torch.stack(img_inputs, dim=0) 83 | 84 | return processed_batch 85 | 86 | def __getitem__(self, index): 87 | img_inputs = self.encode_img(self.img_paths[index], index) 88 | instance = { 89 | 'img_inputs': img_inputs['img'] 90 | } 91 | 92 | 93 | return instance 94 | 95 | 96 | class MimicTxtDataset(Dataset): 97 | def __init__(self, data_path, tokenizer): 98 | self.tokenizer = tokenizer 99 | self.findings = [] 100 | 101 | with open(data_path, "r") as fin: 102 | text_data = json.load(fin) 103 | for instance in text_data: 104 | self.findings.append(instance["finding"]) 105 | 106 | 107 | def __len__(self): 108 | return len(self.findings) 109 | 110 | def Collector(self, batch): 111 | processed_batch = { 112 | 'findings_inputs': self.tokenizer(batch, return_tensors='pt', 113 | padding=True, truncation=True) 114 | } 115 | return processed_batch 116 | 117 | def __getitem__(self, index): 118 | return self.findings[index] 119 | 120 | 121 | 122 | 123 | 124 | if __name__ == '__main__': 125 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 126 | parser = argparse.ArgumentParser("") 127 | parser.add_argument("--t5_model_name", type=str, default='OpenMatch/t5-ance') 128 | parser.add_argument("--clip_model_name",type=str,default='openai/clip-vit-base-patch32') 129 | parser.add_argument("--saved_ckpt",type=str,default='/FactMM-RAG/src/retriever/output/dpr.best.pt') 130 | parser.add_argument("--train_path",type=str,default='/FactMM-RAG/data/mimic/valid.json') 131 | parser.add_argument("--valid_path",type=str,default='/FactMM-RAG/data/mimic/valid.json') 132 | parser.add_argument("--test_path",type=str,default='/FactMM-RAG/data/mimic/test.json') 133 | parser.add_argument("--image_folder",type=str,default="") 134 | 135 | parser.add_argument("--output_train_image_path",type=str,default='/FactMM-RAG/DPR/embedding/train_embedding_image.pkl') 136 | parser.add_argument("--output_train_finding_path",type=str,default='/FactMM-RAG/DPR/embedding/train_embedding_finding.pkl') 137 | parser.add_argument("--output_valid_image_path",type=str,default='/FactMM-RAG/DPR/embedding/valid_embedding_image.pkl') 138 | parser.add_argument("--output_test_image_path",type=str,default='/FactMM-RAG/DPR/embedding/test_embedding_image.pkl') 139 | args = parser.parse_args() 140 | 141 | t5_tokenizer, model, image_processor = load_model(args,device) 142 | model.load_state_dict(torch.load(args.saved_ckpt,map_location='cuda:0')['model'],strict =False) 143 | model.cuda() 144 | 145 | args.img_patch_token_size=get_img_patch_token_size(args.clip_model_name) 146 | train_path = args.train_path 147 | valid_path = args.valid_path 148 
| test_path = args.test_path 149 | image_folder = args.image_folder 150 | 151 | txt_data = MimicTxtDataset(train_path, t5_tokenizer) 152 | sampler = SequentialSampler(txt_data) 153 | txt_reader = DataLoader(dataset=txt_data, sampler=sampler, num_workers=10, 154 | batch_size=32, collate_fn=txt_data.Collector) 155 | gen_txt_embeddings(model, txt_reader, args.output_train_finding_path) 156 | 157 | img_data = MimicImgDataset(args, train_path, image_processor, t5_tokenizer, image_folder) 158 | sampler = SequentialSampler(img_data) 159 | img_reader = DataLoader(dataset=img_data, sampler=sampler, num_workers=10, 160 | batch_size=32, collate_fn=img_data.Collector) 161 | gen_img_embeddings(model, img_reader, args.output_train_image_path) 162 | 163 | img_data = MimicImgDataset(args, valid_path, image_processor, t5_tokenizer, image_folder) 164 | sampler = SequentialSampler(img_data) 165 | img_reader = DataLoader(dataset=img_data, sampler=sampler, num_workers=10, 166 | batch_size=32, collate_fn=img_data.Collector) 167 | gen_img_embeddings(model, img_reader, args.output_valid_image_path) 168 | 169 | 170 | img_data = MimicImgDataset(args, test_path, image_processor, t5_tokenizer, image_folder) 171 | sampler = SequentialSampler(img_data) 172 | img_reader = DataLoader(dataset=img_data, sampler=sampler, num_workers=10, 173 | batch_size=32, collate_fn=img_data.Collector) 174 | gen_img_embeddings(model, img_reader, args.output_test_image_path) 175 | 176 | 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [NAACL 2025] FactMM-RAG: Fact-Aware Multimodal Retrieval Augmentation for Accurate Medical Radiology Report Generation 2 | In this work, we present FactMM-RAG, a fact-aware multimodal retrieval-augmented pipeline for generating accurate radiology reports. [[Paper Link](https://arxiv.org/abs/2407.15268)] 3 | 4 | ![Pipeline](assets/overview.png) 5 | 6 | ## 📅 Schedule 7 | 8 | - [x] Release the data preprocessing code 9 | - [x] Release the factual report pair mining code 10 | - [x] Release the retriever training code 11 | - [ ] Release the generator training code 12 | 13 | 14 | ## 📦 Requirements 15 | 1. Clone this repository and navigate to the FactMM-RAG folder 16 | ```bash 17 | git clone https://github.com/cxcscmu/FactMM-RAG.git 18 | cd FactMM-RAG 19 | ``` 20 | 21 | 2. Install packages: create a conda environment 22 | 23 | ```Shell 24 | conda create -n FactMM-RAG python=3.10 -y 25 | conda activate FactMM-RAG 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | 3. Download the required datasets and checkpoint 30 | - Datasets: [MIMIC-CXR](https://vilmedic.app/papers/acl2023/) and [CheXpert](https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2) 31 | - Checkpoint: [MARVEL](https://huggingface.co/OpenMatch/marvel-ance-clueweb/tree/main) 32 | 33 | ## 📖 Data Preprocessing 34 | 1. Place the downloaded datasets in `./data/mimic` and `./data/chexpert`. We follow the official splits and parse them into train, valid, and test files. To process the radiology dataset and generate the output JSON file, run the following command (e.g., train file parsing): 35 | ```sh 36 | python ./data/parse.py --image_paths_file ./data/mimic/train.image.tok \ 37 | --findings_file ./data/mimic/train.findings.tok \ 38 | --impressions_file ./data/mimic/train.impression.tok \ 39 | --output_json_file ./data/mimic/train.json 40 | ``` 41 | 2. 
Annotate reports with radiological entities, clinical relations, and diagnostic labels using RadGraph and CheXbert: 42 | ```sh 43 | python ./data/label.py --input_path ./data/mimic/train.json \ 44 | --output_path ./data/mimic/train_labeled.json \ 45 | --device cuda 46 | ``` 47 | 48 | ## 📖 Factual Report Pair Mining 49 | 1. Generate factual similarity scores using annotations from RadGraph and CheXbert. Before running the scripts, ensure that you update the data paths accordingly. Since the training corpus is large, we use SLURM array jobs to parallelize this step. Run the following commands: 50 | ```bash 51 | #Query: training reports | Corpus: training reports 52 | cd ./data/factual_mining/build_pos_train/ 53 | sbatch gen_similarity.sh 54 | #Query: validation reports | Corpus: training reports 55 | cd ./data/factual_mining/build_pos_valid/ 56 | sbatch gen_similarity.sh 57 | ``` 58 | 2. Construct query and top-K reference report pairs based on factual similarity thresholds. Run the following commands: 59 | ```bash 60 | cd ./data/factual_mining/build_pos_train/ 61 | sbatch gen_topk_pos.sh 62 | sh merge_topk_pos.sh 63 | 64 | cd ./data/factual_mining/build_pos_valid/ 65 | sh gen_topk_pos.sh 66 | ``` 67 | 68 | ## 🚀 Training 69 | 70 | 1. Place the downloaded MARVEL checkpoint into `./src/retriever/checkpoint/`. Train the multimodal retriever using the constructed query-image and reference-report pairs with in-batch negative sampling. Additionally, an optional ANCE training stage with hard negatives can be included to further enhance performance. Run the following commands: 71 | ```bash 72 | cd ./src/retriever/DPR 73 | sh train.sh 74 | sh gen_embeddings.sh 75 | 76 | #Optional ANCE Training 77 | sh gen_hard_negatives.sh 78 | cd ./src/retriever/ANCE 79 | sh train.sh 80 | sh gen_embeddings.sh 81 | ``` 82 | ## 🚀 RAG 83 | 84 | Checkpoint: https://drive.google.com/file/d/1qV-atZdKX-PwBSWzEocf5I63vuF9Ri-g/view?usp=sharing 85 | 86 | To perform RAG, generate the kNN indices using: 87 | 88 | ```bash 89 | python src/generator/knn.py \ 90 | --query_embedding_file test_embedding.pkl \ 91 | --corpus_embedding_file train_embedding.pkl \ 92 | --output_path ./data/rag/knn_te2tr.pkl 93 | 94 | python src/generator/knn.py \ 95 | --query_embedding_file train_embedding.pkl \ 96 | --corpus_embedding_file train_embedding.pkl \ 97 | --output_path ./data/rag/knn_tr2tr.pkl 98 | ``` 99 | 100 | Optionally, use the mined top-k positives as oracle retrieval results for the training data. This step is computationally much 101 | more expensive, but it can be parallelized and yields much stronger downstream RAG performance. 102 | 103 | ```bash 104 | sh ./data/factual_mining/build_pos_train/gen_topk_oracle_train.sh 105 | python src/generator/knn_ideal.py 106 | ``` 107 | 108 | 109 | Once the retriever is trained, follow the instructions in `install_llava.sh` and set the following environment variables:
110 | 111 | ```bash 112 | export IMAGE_FOLDER="path_to_image_folder" 113 | export PROJECTOR_PATH="path_to_llava_projector" 114 | ``` 115 | 116 | Build the RAG training and test datasets using the generated query and document embeddings: 117 | 118 | ```bash 119 | python ./src/generator/build_rag_dataset.py \ 120 | --faiss_knn_path ./data/rag/knn_te2tr.pkl \ 121 | --queries_data_path ./data/mimic/test.json \ 122 | --corpus_data_path ./data/mimic/train.json \ 123 | --rag_data_mode finding \ 124 | --output_data_mode finding \ 125 | --test_short \ 126 | --output_path ./data/rag/llava_data_te.json 127 | 128 | # Optionally, replace --faiss_knn_path with the training-time oracle top-1 129 | python ./src/generator/build_rag_dataset.py \ 130 | --faiss_knn_path ./data/rag/knn_tr2tr.pkl \ 131 | --queries_data_path ./data/mimic/train.json \ 132 | --corpus_data_path ./data/mimic/train.json \ 133 | --rag_data_mode finding \ 134 | --output_data_mode finding \ 135 | --test_short \ 136 | --is_conversational \ 137 | --output_path ./data/rag/llava_data_tr.json 138 | 139 | python ./src/generator/convert_json_or_jsonl.py \ 140 | --file ./data/rag/llava_data_te.json \ 141 | --overwrite 142 | ``` 143 | 144 | Then, train a LLaVA model and run inference & scoring: 145 | 146 | ```bash 147 | ./src/generator/train_llava.sh 148 | ./src/generator/inference_llava.sh 149 | python src/generator/inference_jsonl_to_json.py ./data/rag/llava_output/test/merge_test_eval.jsonl 150 | ./src/generator/evaluate_llava.sh 151 | ``` 152 | 153 | ### Non-RAG VQA 154 | 155 | To train a non-RAG VQA model, first build the datasets: 156 | 157 | ```bash 158 | python ./src/generator/build_nonrag_dataset.py \ 159 | --queries_data_path ./data/mimic/test.json \ 160 | --output_path ./data/rag/vqa/llava_vqa_test.json 161 | 162 | python ./src/generator/build_nonrag_dataset.py \ 163 | --queries_data_path ./data/mimic/train.json \ 164 | --is_conversational \ 165 | --output_path ./data/rag/vqa/llava_vqa_train.json 166 | 167 | python ./src/generator/convert_json_or_jsonl.py \ 168 | --file ./data/rag/vqa/llava_vqa_test.json \ 169 | --overwrite 170 | ``` 171 | Then, train the LLaVA VQA model and run inference and evaluation: 172 | 173 | ```bash 174 | ./src/generator/vqa/train_llava_vqa.sh 175 | ./src/generator/vqa/inference_llava_vqa.sh 176 | ./src/generator/vqa/eval_llava_vqa.sh 177 | ``` 178 | 179 | ## 📚 Citation 180 | ```bibtex 181 | @misc{sun2025factawaremultimodalretrievalaugmentation, 182 | title={Fact-Aware Multimodal Retrieval Augmentation for Accurate Medical Radiology Report Generation}, 183 | author={Liwen Sun and James Zhao and Megan Han and Chenyan Xiong}, 184 | year={2025}, 185 | eprint={2407.15268}, 186 | archivePrefix={arXiv}, 187 | primaryClass={cs.CL}, 188 | url={https://arxiv.org/abs/2407.15268}, 189 | } 190 | ``` 191 | 192 | ## 🙏 Acknowledgement 193 | We use code from [LLaVA](https://github.com/haotian-liu/LLaVA) and [MARVEL](https://github.com/OpenMatch/MARVEL). We thank the authors for releasing their code.
194 | -------------------------------------------------------------------------------- /src/retriever/DPR/evaluate_retriever.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import pickle 3 | import faiss 4 | import pytrec_eval 5 | import argparse 6 | import json 7 | import numpy as np 8 | import pandas as pd 9 | import numpy as np 10 | import pickle 11 | from IPython import embed 12 | import os 13 | 14 | import re 15 | 16 | def compute_mrr(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, MaxMRRRank=10): 17 | MRR = 0 18 | ranking = [] 19 | for qid in tqdm(qids_to_ranked_candidate_passages,desc=f"Evaluate MRR@{MaxMRRRank}"): 20 | if qid in qids_to_relevant_passageids: 21 | ranking.append(0) 22 | target_pid = qids_to_relevant_passageids[qid] 23 | candidate_pid = qids_to_ranked_candidate_passages[qid] 24 | for i in range(0, MaxMRRRank): 25 | if candidate_pid[i] in target_pid: 26 | MRR += 1 / (i + 1) 27 | ranking.pop() 28 | ranking.append(i + 1) 29 | break 30 | if len(ranking) == 0: 31 | raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?") 32 | 33 | MRR = MRR / len(qids_to_relevant_passageids) 34 | return MRR 35 | 36 | 37 | def convert_to_string_id(result_dict): 38 | string_id_dict = {} 39 | 40 | # format [string, dict[string, val]] 41 | for k, v in result_dict.items(): 42 | _temp_v = {} 43 | for inner_k, inner_v in v.items(): 44 | _temp_v[str(inner_k)] = inner_v 45 | 46 | string_id_dict[str(k)] = _temp_v 47 | 48 | return string_id_dict 49 | 50 | def EvalDevQuery(query_positive_id, ctx_idxs): 51 | prediction = {} # [qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2) 52 | 53 | 54 | qids_to_ranked_candidate_passages = {} 55 | for query_id, top_pid in tqdm(ctx_idxs.items(), total=len(ctx_idxs),desc="Convert prediction results"): 56 | prediction[query_id] = {} 57 | rank = 0 58 | 59 | tmp = [0] * 1000 60 | qids_to_ranked_candidate_passages[query_id] = tmp 61 | 62 | for idx in top_pid: 63 | pred_pid = idx 64 | qids_to_ranked_candidate_passages[query_id][rank] = pred_pid 65 | rank += 1 66 | prediction[query_id][pred_pid] = -rank 67 | 68 | 69 | # use out of the box evaluation script 70 | evaluator = pytrec_eval.RelevanceEvaluator( 71 | convert_to_string_id(query_positive_id), {'ndcg_cut', 'recall'}) 72 | 73 | eval_query_cnt = 0 74 | result = evaluator.evaluate(convert_to_string_id(prediction)) 75 | 76 | qids_to_relevant_passageids = {} 77 | for qid in tqdm(query_positive_id,desc="Convert ground truth results"): 78 | qids_to_relevant_passageids[qid] = [] 79 | for pid in query_positive_id[qid]: 80 | qids_to_relevant_passageids[qid].append(pid) 81 | 82 | 83 | 84 | # Initialize MRR values 85 | mrr_100 = compute_mrr(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, 100) 86 | mrr_200 = compute_mrr(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, 200) 87 | mrr_500 = compute_mrr(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, 500) 88 | mrr_1000 = compute_mrr(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, 1000) 89 | 90 | # Initialize NDCG and Recall values 91 | recall_100 = 0 92 | recall_200 = 0 93 | recall_500 = 0 94 | recall_1000 = 0 95 | 96 | ndcg_100 = 0 97 | ndcg_200 = 0 98 | ndcg_500 = 0 99 | ndcg_1000 = 0 100 | 101 | for k in tqdm(result.keys(), desc="Report results"): 102 | eval_query_cnt += 1 103 | 104 | recall_100 += result[k]["recall_100"] 105 | recall_200 += 
result[k]["recall_200"] 106 | recall_500 += result[k]["recall_500"] 107 | recall_1000 += result[k]["recall_1000"] 108 | 109 | ndcg_100 += result[k]["ndcg_cut_100"] 110 | ndcg_200 += result[k]["ndcg_cut_200"] 111 | ndcg_500 += result[k]["ndcg_cut_500"] 112 | ndcg_1000 += result[k]["ndcg_cut_1000"] 113 | 114 | # Calculate average values 115 | recall_100 /= eval_query_cnt 116 | recall_200 /= eval_query_cnt 117 | recall_500 /= eval_query_cnt 118 | recall_1000 /= eval_query_cnt 119 | 120 | ndcg_100 /= eval_query_cnt 121 | ndcg_200 /= eval_query_cnt 122 | ndcg_500 /= eval_query_cnt 123 | ndcg_1000 /= eval_query_cnt 124 | 125 | return recall_100, recall_200, recall_500, recall_1000, mrr_100, mrr_200, mrr_500, mrr_1000, ndcg_100, ndcg_200, ndcg_500, ndcg_1000 126 | 127 | 128 | 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser("") 133 | parser.add_argument("--query_embed_path") 134 | parser.add_argument("--txt_embed_path") 135 | parser.add_argument("--query_positive_matrix_path") 136 | parser.add_argument("--chexbert_threshold",type=float,default=1) 137 | parser.add_argument("--radgraph_threshold",type=float,default=0.4) 138 | parser.add_argument("--result_path") 139 | parser.add_argument("--topN",type=int,default=1000) 140 | 141 | 142 | 143 | args = parser.parse_args() 144 | faiss.omp_set_num_threads(16) 145 | 146 | 147 | 148 | with open(args.query_embed_path, 'rb') as fin: 149 | query_embeds = pickle.load(fin) 150 | query_embeds = np.array(query_embeds, np.float32) 151 | 152 | cpu_index = faiss.IndexFlatIP(query_embeds.shape[1]) 153 | 154 | if args.txt_embed_path: 155 | print("load data from {}".format(args.txt_embed_path)) 156 | with open(args.txt_embed_path, 'rb') as fin: 157 | txt_embeds = pickle.load(fin) 158 | cpu_index.add(np.array(txt_embeds, np.float32)) 159 | model_name = "image2finding" 160 | 161 | 162 | 163 | query_positive_matrix = pickle.load(open(args.query_positive_matrix_path,"rb")) 164 | 165 | 166 | D, I = cpu_index.search(query_embeds, args.topN) 167 | ctx_idxs = {} 168 | query_positive_id = {} 169 | for qid, np_query_results in enumerate(I): 170 | query_results = np_query_results.tolist() 171 | 172 | ctx_idxs[qid] = query_results 173 | 174 | del cpu_index 175 | 176 | 177 | 178 | for qid, query_positive_results in enumerate(query_positive_matrix): 179 | query_positive_id.setdefault(qid, {}) 180 | for ret_id in query_positive_results: 181 | query_positive_id[qid][ret_id] = 1 182 | 183 | if len(query_positive_id[qid]) == 0: 184 | del query_positive_id[qid] 185 | del ctx_idxs[qid] 186 | 187 | 188 | 189 | 190 | result = EvalDevQuery(query_positive_id, ctx_idxs) 191 | recall_100, recall_200,recall_500,recall_1000, mrr_100, mrr_200,mrr_500,mrr_1000, ndcg_100, ndcg_200,ndcg_500,ndcg_1000 = result 192 | 193 | 194 | result_path = os.path.join(args.result_path,f"chexbert_{args.chexbert_threshold}_radgraph_{args.radgraph_threshold}_top1000.txt") 195 | 196 | if not os.path.exists(result_path): 197 | with open(result_path, 'w') as fout: 198 | fout.write("Model Name\tRecall@100\tRecall@200\tRecall@500\tRecall@1000\tMRR@100\tMRR@200\tMRR@500\tMRR@1000\tNDCG@100\tNDCG@200\tNDCG@500\tNDCG@1000\n") 199 | with open(result_path, 'a') as fout: 200 | fout.write(f"{model_name}\t" 201 | f"{recall_100 * 100:.2f}\t" 202 | f"{recall_200 * 100:.2f}\t" 203 | f"{recall_500 * 100:.2f}\t" 204 | f"{recall_1000 * 100:.2f}\t" 205 | f"{mrr_100 * 100:.2f}\t" 206 | f"{mrr_200 * 100:.2f}\t" 207 | f"{mrr_500 * 100:.2f}\t" 208 | f"{mrr_1000 * 100:.2f}\t" 209 | f"{ndcg_100 * 100:.2f}\t" 
210 | f"{ndcg_200 * 100:.2f}\t" 211 | f"{ndcg_500 * 100:.2f}\t" 212 | f"{ndcg_1000 * 100:.2f}\n") 213 | 214 | print(mrr_100) -------------------------------------------------------------------------------- /src/retriever/DPR/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm,trange 5 | import torch 6 | import argparse 7 | from torch import optim 8 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 9 | from data import MedDataset 10 | import torch.nn.functional as F 11 | import wandb 12 | from transformers import get_cosine_schedule_with_warmup 13 | import random 14 | from utils import load_model,get_img_patch_token_size 15 | import pickle 16 | 17 | def set_seed(args): 18 | seed = args.seed 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | 23 | def eval_loss(model, loss_function, valid_reader, device): 24 | model.eval() 25 | total_loss = 0.0 26 | total_corr = 0.0 27 | counter = 0.0 28 | for step, batch in tqdm(enumerate(valid_reader)): 29 | with torch.no_grad(): 30 | batch_size=batch['query_image_inputs'].size(0) 31 | 32 | query_embeddings = model(batch['query_image_inputs'].cuda(),None,device) 33 | candidate_embeddings = model(batch['pos_image_inputs'].cuda(),batch['pos_report_inputs'],device) 34 | 35 | query_embeddings = F.normalize(query_embeddings, dim=-1) 36 | candidate_embeddings = F.normalize(candidate_embeddings, dim=-1) 37 | logit_scale = model.logit_scale.exp() 38 | score = torch.matmul(query_embeddings, candidate_embeddings.t())* logit_scale 39 | target = torch.arange(batch_size, dtype=torch.long).cuda() 40 | loss = loss_function(score, target) 41 | max_score, max_idxs = torch.max(score, 1) 42 | 43 | correct_predictions_count = (max_idxs == target).sum() / batch_size 44 | total_corr += correct_predictions_count.item() 45 | total_loss += loss.item() 46 | counter += 1 47 | 48 | if counter == 0: 49 | return 0.0, 0.0 50 | return total_loss / counter, total_corr / counter 51 | 52 | def train(train_reader, valid_reader, model, device): 53 | t_total = len(train_reader) // args.gradient_accumulation_steps * args.num_train_epochs 54 | eval_step = t_total//args.num_train_epochs 55 | 56 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 57 | include = lambda n, p: not exclude(n, p) 58 | 59 | named_parameters = list(model.named_parameters()) 60 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 61 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 62 | 63 | optimizer = optim.AdamW( 64 | [ 65 | {"params": gain_or_bias_params, "weight_decay": 0.}, 66 | {"params": rest_params, "weight_decay": 0.2}, 67 | ], 68 | lr=args.learning_rate, 69 | betas=(0.9, 0.98), 70 | eps=1.0e-6, 71 | ) 72 | scheduler = get_cosine_schedule_with_warmup( 73 | optimizer, num_warmup_steps=int(args.warmup_steps*t_total), num_training_steps=t_total 74 | ) 75 | loss_function = torch.nn.CrossEntropyLoss() 76 | tag, global_step, global_loss, best_acc = 0, 0, 0.0, 0.0 77 | model.zero_grad() 78 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch") 79 | for epoch in train_iterator: 80 | epoch_iterator = tqdm(train_reader) 81 | for step, batch in enumerate(epoch_iterator): 82 | model.train() 83 | batch_size=batch['query_image_inputs'].size(0) 84 | 85 | query_embeddings = 
model(batch['query_image_inputs'].cuda(),None,device) 86 | candidate_embeddings = model(batch['pos_image_inputs'].cuda(),batch['pos_report_inputs'],device) 87 | 88 | query_embeddings = F.normalize(query_embeddings, dim=-1) 89 | candidate_embeddings = F.normalize(candidate_embeddings, dim=-1) 90 | logit_scale = model.logit_scale.exp() 91 | score = torch.matmul(query_embeddings, candidate_embeddings.t())* logit_scale 92 | target = torch.arange(batch_size, dtype=torch.long).cuda() 93 | loss = loss_function(score, target) 94 | max_score, max_idxs = torch.max(score, 1) 95 | correct_predictions_acc = (max_idxs == target).sum() / batch_size 96 | 97 | if args.gradient_accumulation_steps > 1: 98 | loss = loss / args.gradient_accumulation_steps 99 | loss.backward() 100 | 101 | global_loss += loss.item() 102 | 103 | if (step + 1) % args.gradient_accumulation_steps == 0: 104 | global_step += 1 105 | optimizer.step() 106 | scheduler.step() 107 | model.zero_grad() 108 | wandb.log( 109 | { 110 | "training_loss":global_loss / global_step, 111 | "training_acc":correct_predictions_acc, 112 | "learning_rate": optimizer.param_groups[0]["lr"] 113 | } 114 | ) 115 | epoch_iterator.set_description(f"Loss:{global_loss / global_step}") 116 | 117 | if global_step % eval_step == 0 and global_step > 0: 118 | dev_loss, dev_acc = eval_loss(model, loss_function, valid_reader, device) 119 | print( 120 | "Evaluation at global step {}, average dev loss: {:.4f},average dev acc: {:.4f}".format( 121 | global_step, dev_loss, dev_acc)) 122 | wandb.log( 123 | { 124 | "dev_loss":dev_loss, 125 | "dev_acc":dev_acc, 126 | } 127 | ) 128 | if best_acc <= dev_acc: 129 | best_acc = dev_acc 130 | saved_path = args.out_path+f"_{epoch}" 131 | torch.save({'epoch': epoch, 132 | 'model': model.state_dict()}, args.out_path) 133 | print("Saved best epoch {0}, best acc {1}".format(epoch, best_acc)) 134 | tag = 0 135 | else: 136 | tag += 1 137 | if tag >= args.early_stop: 138 | print('*********early stop**********') 139 | exit() 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser("") 144 | parser.add_argument("--out_path", type=str) 145 | parser.add_argument("--train_path", type=str) 146 | parser.add_argument("--valid_path", type=str) 147 | parser.add_argument("--train_pos_path", type=str) 148 | parser.add_argument("--valid_pos_path", type=str) 149 | parser.add_argument("--wandb_name", type=str) 150 | 151 | parser.add_argument("--t5_model_name", type=str, default='OpenMatch/t5-ance') 152 | parser.add_argument("--clip_model_name",type=str,default='openai/clip-vit-base-patch32') 153 | parser.add_argument("--pretrained_model_path", type=str) 154 | 155 | 156 | parser.add_argument("--num_workers", type=int, default=8) 157 | parser.add_argument("--seed", type=int, default=42) 158 | parser.add_argument("--early_stop", type=int, default=5) 159 | parser.add_argument("--train_batch_size", type=int, default=32) 160 | parser.add_argument("--valid_batch_size", type=int, default=32) 161 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 162 | parser.add_argument("--num_train_epochs", type=int, default=15) 163 | parser.add_argument("--learning_rate", type=float, default=5e-6) 164 | parser.add_argument("--warmup_steps", type=float, default=0.1) 165 | args = parser.parse_args() 166 | 167 | 168 | set_seed(args) 169 | 170 | 171 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 172 | 173 | 174 | wandb.init(name=args.wandb_name, 175 | sync_tensorboard=True) 176 | 177 |
wandb.config.update(args) 178 | 179 | 180 | train_data = json.load(open(args.train_path,"r")) 181 | valid_data = json.load(open(args.valid_path,"r")) 182 | train_pos_data = pickle.load(open(args.train_pos_path,'rb')) 183 | valid_pos_data = pickle.load(open(args.valid_pos_path,'rb')) 184 | 185 | 186 | tokenizer, model, image_processor = load_model(args,device) 187 | model.to(device) 188 | 189 | 190 | img_patch_token_size=get_img_patch_token_size(args.clip_model_name) 191 | 192 | train_dataset = MedDataset(image_processor, tokenizer, train_data,train_pos_data,img_patch_token_size) 193 | valid_dataset = MedDataset(image_processor, tokenizer, train_data,valid_pos_data,img_patch_token_size,valid_data) 194 | 195 | 196 | 197 | train_sampler = RandomSampler(train_dataset) 198 | valid_sampler = SequentialSampler(valid_dataset) 199 | 200 | traindata_reader = DataLoader(dataset=train_dataset, sampler=train_sampler, num_workers=args.num_workers, 201 | batch_size=args.train_batch_size, collate_fn=train_dataset.Collector, drop_last=True) 202 | validdata_reader = DataLoader(dataset=valid_dataset, sampler=valid_sampler, num_workers=args.num_workers, 203 | batch_size=args.valid_batch_size, collate_fn=valid_dataset.Collector, drop_last=False) 204 | if args.pretrained_model_path != None: 205 | model.load_state_dict(torch.load(args.pretrained_model_path)['model'],strict=False) 206 | model.cuda() 207 | train(traindata_reader, validdata_reader, model, device) 208 | -------------------------------------------------------------------------------- /src/retriever/ANCE/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm,trange 5 | import torch 6 | import argparse 7 | from torch import optim 8 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 9 | from data import MedDataset 10 | import torch.nn.functional as F 11 | import wandb 12 | from transformers import get_cosine_schedule_with_warmup 13 | import random 14 | from utils import load_model,get_img_patch_token_size 15 | import pickle 16 | from IPython import embed 17 | 18 | 19 | def set_seed(args): 20 | seed = args.seed 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | 25 | def eval_loss(model, loss_function, valid_reader, device): 26 | model.eval() 27 | total_loss = 0.0 28 | total_corr = 0.0 29 | counter = 0.0 30 | for step, batch in tqdm(enumerate(valid_reader)): 31 | with torch.no_grad(): 32 | batch_size=batch['query_image_inputs'].size(0) 33 | 34 | query_embeddings = model(batch['query_image_inputs'].cuda(),None,device) 35 | candidate_embeddings = model(batch['pos_neg_image_inputs'].cuda(),batch['pos_neg_report_inputs'],device) 36 | 37 | query_embeddings = F.normalize(query_embeddings, dim=-1) 38 | candidate_embeddings = F.normalize(candidate_embeddings, dim=-1) 39 | logit_scale = model.logit_scale.exp() 40 | score = torch.matmul(query_embeddings, candidate_embeddings.t())* logit_scale 41 | target = batch['targets'].cuda() 42 | 43 | loss = loss_function(score, target) 44 | max_score, max_idxs = torch.max(score, 1) 45 | 46 | correct_predictions_count = (max_idxs == target).sum() / batch_size 47 | total_corr += correct_predictions_count.item() 48 | total_loss += loss.item() 49 | counter += 1 50 | 51 | if counter == 0: 52 | return 0.0, 0.0 53 | return total_loss / counter, total_corr / counter 54 | 55 | def train(train_reader, valid_reader, model, device): 56 | t_total = len(train_reader) // 
args.gradient_accumulation_steps * args.num_train_epochs 57 | eval_step = t_total//args.num_train_epochs 58 | 59 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 60 | include = lambda n, p: not exclude(n, p) 61 | 62 | named_parameters = list(model.named_parameters()) 63 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 64 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 65 | 66 | optimizer = optim.AdamW( 67 | [ 68 | {"params": gain_or_bias_params, "weight_decay": 0.}, 69 | {"params": rest_params, "weight_decay": 0.2}, 70 | ], 71 | lr=args.learning_rate, 72 | betas=(0.9, 0.98), 73 | eps=1.0e-6, 74 | ) 75 | scheduler = get_cosine_schedule_with_warmup( 76 | optimizer, num_warmup_steps=int(args.warmup_steps*t_total), num_training_steps=t_total 77 | ) 78 | loss_function = torch.nn.CrossEntropyLoss() 79 | tag, global_step, global_loss, best_acc = 0, 0, 0.0, 0.0 80 | model.zero_grad() 81 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch") 82 | for epoch in train_iterator: 83 | epoch_iterator = tqdm(train_reader) 84 | for step, batch in enumerate(epoch_iterator): 85 | model.train() 86 | batch_size=batch['query_image_inputs'].size(0) 87 | 88 | query_embeddings = model(batch['query_image_inputs'].cuda(),None,device) 89 | candidate_embeddings = model(batch['pos_neg_image_inputs'].cuda(),batch['pos_neg_report_inputs'],device) 90 | 91 | query_embeddings = F.normalize(query_embeddings, dim=-1) 92 | candidate_embeddings = F.normalize(candidate_embeddings, dim=-1) 93 | logit_scale = model.logit_scale.exp() 94 | score = torch.matmul(query_embeddings, candidate_embeddings.t())* logit_scale 95 | target = batch['targets'].cuda() 96 | 97 | loss = loss_function(score, target) 98 | max_score, max_idxs = torch.max(score, 1) 99 | correct_predictions_acc = (max_idxs == target).sum() / batch_size 100 | 101 | if args.gradient_accumulation_steps > 1: 102 | loss = loss / args.gradient_accumulation_steps 103 | loss.backward() 104 | 105 | global_loss += loss.item() 106 | 107 | if (step + 1) % args.gradient_accumulation_steps == 0: 108 | global_step += 1 109 | optimizer.step() 110 | scheduler.step() 111 | model.zero_grad() 112 | wandb.log( 113 | { 114 | "training_loss":global_loss / global_step, 115 | "training_acc":correct_predictions_acc, 116 | "learning_rate": optimizer.param_groups[0]["lr"] 117 | } 118 | ) 119 | epoch_iterator.set_description(f"Loss:{global_loss / global_step}") 120 | 121 | if global_step % eval_step == 0 and global_step > 0: 122 | dev_loss, dev_acc = eval_loss(model, loss_function, valid_reader, device) 123 | print( 124 | "Evaluation at global step {}, average dev loss: {:.4f},average dev acc: {:.4f}".format( 125 | global_step, dev_loss, dev_acc)) 126 | wandb.log( 127 | { 128 | "dev_loss":dev_loss, 129 | "dev_acc":dev_acc, 130 | } 131 | ) 132 | 133 | if best_acc <= dev_acc: 134 | best_acc = dev_acc 135 | torch.save({'epoch': epoch, 136 | 'model': model.state_dict()}, args.out_path) 137 | print("Saved best epoch {0}, best acc {1}".format(epoch, best_acc)) 138 | tag = 0 139 | else: 140 | tag += 1 141 | if tag >= args.early_stop: 142 | print('*********early stop**********') 143 | exit() 144 | 145 | 146 | if __name__ == '__main__': 147 | parser = argparse.ArgumentParser("") 148 | parser.add_argument("--out_path", type=str) 149 | parser.add_argument("--train_path", type=str) 150 | parser.add_argument("--valid_path", type=str) 151 | 
parser.add_argument("--train_pos_path", type=str) 152 | parser.add_argument("--train_neg_path", type=str) 153 | parser.add_argument("--valid_pos_path", type=str) 154 | parser.add_argument("--valid_neg_path", type=str) 155 | 156 | parser.add_argument("--wandb_name", type=str) 157 | 158 | parser.add_argument("--t5_model_name", type=str, default='OpenMatch/t5-ance') 159 | parser.add_argument("--clip_model_name",type=str,default='openai/clip-vit-base-patch32') 160 | parser.add_argument("--pretrained_model_path", type=str) 161 | 162 | parser.add_argument("--num_workers", type=int, default=8) 163 | parser.add_argument("--seed", type=int, default=42) 164 | parser.add_argument("--early_stop", type=int, default=3) 165 | parser.add_argument("--train_batch_size", type=int, default=8) 166 | parser.add_argument("--valid_batch_size", type=int, default=8) 167 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 168 | parser.add_argument("--num_train_epochs", type=int, default=3) 169 | parser.add_argument("--learning_rate", type=float, default=1e-6) 170 | parser.add_argument("--warmup_steps", type=int, default=0.1) 171 | 172 | 173 | args = parser.parse_args() 174 | 175 | 176 | set_seed(args) 177 | 178 | 179 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 180 | 181 | 182 | wandb.init( 183 | name=args.wandb_name, 184 | sync_tensorboard=True) 185 | 186 | wandb.config.update(args) 187 | 188 | 189 | train_data = json.load(open(args.train_path,"r")) 190 | valid_data = json.load(open(args.valid_path,"r")) 191 | train_pos_data = pickle.load(open(args.train_pos_path,'rb')) 192 | train_neg_data = pickle.load(open(args.train_neg_path,'rb')) 193 | valid_pos_data = pickle.load(open(args.valid_pos_path,'rb')) 194 | valid_neg_data = pickle.load(open(args.valid_neg_path,'rb')) 195 | 196 | tokenizer, model, image_processor = load_model(args,device) 197 | model.to(device) 198 | 199 | 200 | img_patch_token_size=get_img_patch_token_size(args.clip_model_name) 201 | 202 | train_dataset = MedDataset(image_processor, tokenizer, train_data,train_pos_data,train_neg_data,img_patch_token_size) 203 | valid_dataset = MedDataset(image_processor, tokenizer, train_data,valid_pos_data,valid_neg_data,img_patch_token_size,valid_data) 204 | 205 | 206 | 207 | train_sampler = RandomSampler(train_dataset) 208 | valid_sampler = SequentialSampler(valid_dataset) 209 | 210 | traindata_reader = DataLoader(dataset=train_dataset, sampler=train_sampler, num_workers=args.num_workers, 211 | batch_size=args.train_batch_size, collate_fn=train_dataset.Collector, drop_last=True) 212 | validdata_reader = DataLoader(dataset=valid_dataset, sampler=valid_sampler, num_workers=args.num_workers, 213 | batch_size=args.valid_batch_size, collate_fn=valid_dataset.Collector, drop_last=False) 214 | if args.pretrained_model_path != None: 215 | model.load_state_dict(torch.load(args.pretrained_model_path)['model'],strict=False) 216 | model.cuda() 217 | train(traindata_reader, validdata_reader, model, device) 218 | --------------------------------------------------------------------------------