├── .gitignore ├── 0.download_data.sh ├── 1.data_process_test&dev.sh ├── 2.data_process_train.sh ├── 3.single_node_train_gemma.sh ├── 4.eval.sh ├── LICENSE.txt ├── README.md ├── assets ├── apollo_medium_final.png ├── dataset.png ├── final.png └── result.png ├── metadata ├── dev.json ├── dev │ ├── ar.json │ ├── en.json │ ├── es.json │ ├── fr.json │ ├── hi.json │ └── zh.json ├── merge_json_train.py ├── test.json └── test │ ├── ar.json │ ├── en.json │ ├── es.json │ ├── fr.json │ ├── hi.json │ └── zh.json ├── requirements.txt ├── scripts ├── 3.multinode_train_gema7B_rank0.sh ├── 3.multinode_train_gema7B_rank1.sh └── 3.multinode_train_gema7B_rank2.sh ├── src ├── evaluate │ ├── cli_demo.py │ ├── eval_72b_34b.py │ ├── eval_gemma.py │ ├── eval_huatuo2.py │ ├── eval_llama2.py │ ├── eval_llama70b.py │ ├── eval_meditron.py │ ├── eval_meditron70b.py │ ├── eval_mistral.py │ ├── eval_mmedlm2.py │ ├── eval_qwen.py │ ├── eval_yi.py │ ├── eval_zephyr.py │ └── generate_score.py ├── process │ ├── openai_rewrite │ │ ├── OpenAIGPT.py │ │ ├── OpenAIGPT_datagen_multithread.py │ │ ├── gpt_key.txt │ │ ├── guidelines_en │ │ │ ├── 1.2.prepare_data.py │ │ │ ├── 1.prepare_data.py │ │ │ ├── 1.run_prepare_data.sh │ │ │ ├── 2.run_gpt_datagen_multithread.sh │ │ │ ├── 3.extract.py │ │ │ └── data │ │ │ │ └── 1.dev.jsonl │ │ └── patient_en │ │ │ ├── 1.2.prepare_data.py │ │ │ ├── 1.prepare_data.py │ │ │ ├── 1.run_prepare_data.sh │ │ │ ├── 2.run_gpt_datagen_multithread.sh │ │ │ ├── 3.extract.py │ │ │ └── data │ │ │ └── 1.dev.jsonl │ └── prepare │ │ ├── data_process_test_gemma.py │ │ ├── data_process_test_huatuo2.py │ │ ├── data_process_test_llama.py │ │ ├── data_process_test_meditron.py │ │ ├── data_process_test_mistral.py │ │ ├── data_process_test_qwen.py │ │ ├── data_process_test_yi.py │ │ ├── data_process_test_zephyr.py │ │ ├── data_process_train_gemma.py │ │ ├── data_process_train_qwen.py │ │ └── data_process_train_yi.py ├── proxy-tuning │ ├── eval │ │ ├── apollodata │ │ │ └── run_eval.py │ │ └── utils.py │ ├── modeling │ │ └── mexperts.py │ └── scripts │ │ └── eval │ │ └── proxy_tuning.sh └── sft │ ├── train_gemma_resume_val.py │ ├── train_qwen_resume_val.py │ ├── train_yi_resume_val.py │ └── training_config │ ├── zero.yaml │ └── zero_multi.yaml └── utils ├── check.ipynb └── kill.sh /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/.gitignore -------------------------------------------------------------------------------- /0.download_data.sh: -------------------------------------------------------------------------------- 1 | # download ApolloCorpus 2 | 3 | cd metadata 4 | wget https://huggingface.co/datasets/FreedomIntelligence/ApolloCorpus/resolve/main/ApolloCorpus.zip 5 | unzip ApolloCorpus.zip 6 | 7 | # Prepare Data for Mix training 8 | mkdir mixTrain 9 | 10 | 11 | cd train/pretrain 12 | # Mixtraining Only use QA pairs in Pretrain 13 | for file in *; do 14 | if [[ $file == *_qa.json ]]; then 15 | cp "$file" "../mixTrain/" 16 | fi 17 | done 18 | cd ../ 19 | 20 | # copy all file from sft to mix_train 21 | mv sft/* mixTrain/ 22 | 23 | # merge all the file from mix_train directory to json 24 | python merge_json_train.py 25 | cd ../ 26 | 27 | 28 | -------------------------------------------------------------------------------- /1.data_process_test&dev.sh: -------------------------------------------------------------------------------- 1 | # Take gemma as example, other models' 
python code is in ./src/process/prepare/data_process_test_{model}.py 2 | mkdir -p ./data/gemma 3 | 4 | python ./src/process/prepare/data_process_test_gemma.py \ 5 | --data_path ./metadata/test.json \ 6 | --few_shot 3 \ 7 | --save_path ./data/gemma/test.json 8 | 9 | 10 | python ./src/process/prepare/data_process_test_gemma.py \ 11 | --data_path ./metadata/dev.json \ 12 | --few_shot 3 \ 13 | --save_path ./data/gemma/dev.json 14 | -------------------------------------------------------------------------------- /2.data_process_train.sh: -------------------------------------------------------------------------------- 1 | # need change 4 place 2 | # Please set the wandb key in the python file (e.g ./src/process/prepare/data_process_train_gemma.py) 3 | 4 | mkdir wandb_logs 5 | 6 | experiment_name=Gemma_MixTrain_Data 7 | log_folder="./logs/${experiment_name}" 8 | mkdir -p $log_folder 9 | log_name=$(date +"%m-%d_%H-%M").log 10 | 11 | 12 | python ./src/process/prepare/data_process_train_gemma.py \ 13 | --data_path ./metadata/train/mixTrain.json \ 14 | --model_path /your/path/to/gemma-2b \ 15 | --wandb_log ./wandb_logs \ 16 | --experiment_name ${experiment_name} \ 17 | --save_path ./data/Gemma/mixTrain > ${log_folder}/$log_name 2>&1 & -------------------------------------------------------------------------------- /3.single_node_train_gemma.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #python *.py 3 | 4 | # Please set the wandb key in the python file (e.g ./src/sft/train_gemma_resume_val.py) 5 | process_port=29502 6 | experiment_name=Gemma2b_MixTrain_Train 7 | model_dir=/your/path/to/gemma-2b 8 | # ckpt_dir= 9 | train_data_file=./data/gemma/MixTrain 10 | dev_data_file=./data/gemma/dev.json 11 | output_dir=./ckpts 12 | log_folder="./logs/${experiment_name}" 13 | mkdir -p $log_folder 14 | log_name=$(date +"%m-%d_%H-%M").log 15 | 16 | accelerate launch \ 17 | --config_file ./src/sft/training_config/zero.yaml \ 18 | --num_processes 8 \ 19 | --num_machines 1 \ 20 | --main_process_port ${process_port} \ 21 | --num_cpu_threads_per_process 8 \ 22 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \ 23 | --model_path ${model_dir} \ 24 | --experiment_name ${experiment_name} \ 25 | --gradient_accumulation_steps 8 \ 26 | --train_data_dir ${train_data_file} \ 27 | --dev_data_dir ${dev_data_file} \ 28 | --output_dir ${output_dir} \ 29 | --log_dir ./wandb_logs \ 30 | --n_epochs 1 \ 31 | --train_bsz_per_gpu 2 \ 32 | --eval_bsz_per_gpu 2 \ 33 | --learning_rate 1e-5 \ 34 | --eval_step -1 \ 35 | --save_step -1 \ 36 | --warmup_rates 0.03 \ 37 | --max_ckpts 5 \ 38 | --gradient_checkpointing > ${log_folder}/$log_name 2>&1 & 39 | 40 | -------------------------------------------------------------------------------- /4.eval.sh: -------------------------------------------------------------------------------- 1 | experiment_name=Gemma2b_MixTrain_Test 2 | log_folder="./logs/${experiment_name}" 3 | result_folder="./results/${experiment_name}" 4 | mkdir -p $log_folder 5 | mkdir -p $result_folder 6 | log_name=$(date +"%m-%d_%H-%M").log 7 | 8 | CUDA_LAUNCH_BLOCKING=1 accelerate launch --main_process_port 23035 ./src/evaluate/eval_gemma.py \ 9 | --model_path=./ckpts/gemma-2b_MixTrain \ 10 | --input_path=./data/gemma/test.json \ 11 | --output_path=${result_folder}/model_ans.jsonl \ 12 | --score_path=${result_folder}/score.json \ 13 | --wrong_item_path=${result_folder}/wrong_item.json > ${log_folder}/$log_name 2>&1 & 14 | 
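# Note: eval_gemma.py (src/evaluate/eval_gemma.py) also defines --num_return and --batch_size with no default values;
# if the launch above fails on those arguments, passing them explicitly, e.g. --num_return=1 --batch_size=8 (illustrative values), is a reasonable fix.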
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multilingual Medicine: Model, Dataset, Benchmark, Code 2 | 3 | Covering English, Chinese, French, Spanish, Hindi, and Arabic so far 4 | 5 | 6 |

7 | 📃 Paper • 🌐 Demo • 🤗 ApolloCorpus • 🤗 XMedBench • 🌐 ApolloMoE 8 |
中文 | English 9 |

10 | 11 | ![Apollo](assets/apollo_medium_final.png) 12 | 13 | ## 🌈 Update 14 | 15 | * **[2024.10.15]** [ApolloMoE](https://github.com/FreedomIntelligence/ApolloMoE) repo released, covering 50 languages. 16 | * **[2024.04.25]** [MedJamba](https://huggingface.co/FreedomIntelligence/Apollo-MedJamba) released; training and evaluation code are in the [repo](https://github.com/FreedomIntelligence/MedJamba). 17 | * **[2024.03.07]** [Paper](https://arxiv.org/abs/2403.03640) released. 18 | * **[2024.02.12]** ApolloCorpus and XMedBench are published!🎉 19 | * **[2024.01.23]** Apollo repo is published!🎉 20 | 21 | 22 | ## Results 23 | 🤗 Apollo-0.5B • 🤗 Apollo-1.8B • 🤗 Apollo-2B • 🤗 Apollo-6B • 🤗 Apollo-7B • 🤗 Apollo-34B • 🤗 Apollo-72B 24 | 25 | 🤗 MedJamba 26 | 27 | 🤗 Apollo-0.5B-GGUF • 🤗 Apollo-2B-GGUF • 🤗 Apollo-6B-GGUF • 🤗 Apollo-7B-GGUF 28 | 29 | 30 | 31 | ![Apollo](assets/result.png) 32 | 33 | 34 | 35 | ## Usage Format 36 | 37 | - 0.5B, 1.8B, 2B, 6B, 7B: User:{query}\nAssistant:{response}<|endoftext|> 38 | - 34B, 72B: <|User|>:{query}\n<|Assistant|>:{response}<|endoftext|> 39 | 40 | ## Dataset & Evaluation 41 | 42 | - Dataset 43 | 🤗 ApolloCorpus 44 | 45 |
Click to expand 46 | 47 | ![Apollo](assets/dataset.png) 48 | 49 | - [Zip File](https://huggingface.co/datasets/FreedomIntelligence/ApolloCorpus/blob/main/ApolloCorpus.zip) 50 | - [Data category](https://huggingface.co/datasets/FreedomIntelligence/ApolloCorpus/tree/main/train) 51 | - Pretrain: 52 | - data item: 53 | - json_name: {data_source}_{language}_{data_type}.json 54 | - data_source: medicalBook, medicalGuideline, medicalPaper, medicalWeb (from online forums), medicalWiki 55 | - language: en (English), zh (Chinese), es (Spanish), fr (French), hi (Hindi) 56 | - data_type: text, qa (QA pairs generated from the text) 57 | - data_type==text: list of strings 58 | ``` 59 | [ 60 | "string1", 61 | "string2", 62 | ... 63 | ] 64 | ``` 65 | - data_type==qa: list of QA pairs (each a list of strings) 66 | ``` 67 | [ 68 | [ 69 | "q1", 70 | "a1", 71 | "q2", 72 | "a2", 73 | ... 74 | ], 75 | ... 76 | ] 77 | ``` 78 | - SFT: 79 | - json_name: {data_source}_{language}.json 80 | - data_source: code, general, math, medicalExam, medicalPatient 81 | - data item: list of QA pairs (each a list of strings) 82 | ``` 83 | [ 84 | [ 85 | "q1", 86 | "a1", 87 | "q2", 88 | "a2", 89 | ... 90 | ], 91 | ... 92 | ] 93 | ``` 94 | 95 | 96 |
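Both layouts above are plain JSON and can be loaded with the standard `json` module. A minimal sketch; the file names below only illustrate the naming conventions described above and are not guaranteed to exist verbatim in the corpus:

```
import json

# Pretrain file with data_type == text: a flat list of strings.
with open("train/pretrain/medicalGuideline_en_text.json", encoding="utf-8") as f:  # illustrative name
    paragraphs = json.load(f)   # ["string1", "string2", ...]

# Pretrain qa / SFT file: a list of dialogues, each flattened as [q1, a1, q2, a2, ...].
with open("train/sft/medicalExam_en.json", encoding="utf-8") as f:                 # illustrative name
    dialogues = json.load(f)    # [["q1", "a1", "q2", "a2", ...], ...]

for turns in dialogues:
    # Re-pair the flattened list into (question, answer) tuples.
    for q, a in zip(turns[0::2], turns[1::2]):
        print("Q:", q)
        print("A:", a)
```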
97 | 98 | - Evaluation 99 | 🤗
XMedBench 100 | 101 |
Click to expand 102 | 103 | - EN: 104 | - [MedQA-USMLE](https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options) 105 | - [MedMCQA](https://huggingface.co/datasets/medmcqa/viewer/default/test) 106 | - [PubMedQA](https://huggingface.co/datasets/pubmed_qa): Because the results fluctuated too much, it was not used in the paper. 107 | - [MMLU-Medical](https://huggingface.co/datasets/cais/mmlu) 108 | - Clinical knowledge, Medical genetics, Anatomy, Professional medicine, College biology, College medicine 109 | - ZH: 110 | - [MedQA-MCMLE](https://huggingface.co/datasets/bigbio/med_qa/viewer/med_qa_zh_4options_bigbio_qa/test) 111 | - [CMB-single](https://huggingface.co/datasets/FreedomIntelligence/CMB): Not used in the paper 112 | - Randomly sampled 2,000 single-answer multiple-choice questions. 113 | - [CMMLU-Medical](https://huggingface.co/datasets/haonan-li/cmmlu) 114 | - Anatomy, Clinical_knowledge, College_medicine, Genetics, Nutrition, Traditional_chinese_medicine, Virology 115 | - [CMExam](https://github.com/williamliujl/CMExam): Not used in the paper 116 | - Randomly sampled 2,000 multiple-choice questions 117 | 118 | 119 | - ES: [HEAD-QA](https://huggingface.co/datasets/head_qa) 120 | - FR: [FrenchMedMCQA](https://github.com/qanastek/FrenchMedMCQA) 121 | - HI: [MMLU_Hindi](https://huggingface.co/datasets/FreedomIntelligence/MMLU_Hindi) 122 | - Clinical knowledge, Medical genetics, Anatomy, Professional medicine, College biology, College medicine 123 | - AR: [MMLU_Arabic](https://huggingface.co/datasets/FreedomIntelligence/MMLU_Arabic) 124 | - Clinical knowledge, Medical genetics, Anatomy, Professional medicine, College biology, College medicine 125 | 126 | 127 |
128 | 129 | 130 | ## Results reproduction 131 |
Click to expand 132 | 133 | 134 | We take Gemma-2B as an example. 135 | 1. Download the dataset for the project: 136 | 137 | ``` 138 | bash 0.download_data.sh 139 | ``` 140 | 141 | 2. Prepare the test and dev data for a specific model: 142 | 143 | 144 | - This creates test data with the model's special tokens; you can use ./utils/check.ipynb to check a model's special tokens 145 | 146 | ``` 147 | bash 1.data_process_test&dev.sh 148 | ``` 149 | 150 | 3. Prepare the training data for a specific model (the data is tokenized in advance): 151 | 152 | 153 | - You can adjust the training data order and the number of training epochs in this step 154 | 155 | ``` 156 | bash 2.data_process_train.sh 157 | ``` 158 | 159 | 4. Train the model 160 | 161 | 162 | - If you want to train on multiple nodes, please refer to ./scripts/3.multinode_train_gema7B_rank*.sh 163 | 164 | 165 | 166 | 167 | ``` 168 | bash 3.single_node_train_gemma.sh 169 | ``` 170 | 171 | 5. (Optional) Proxy-Tuning: directly improve model capabilities without fine-tuning 172 | 173 | ``` 174 | bash src/proxy-tuning/scripts/eval/proxy_tuning.sh 175 | ``` 176 | 6. Evaluate your model: generate scores for the benchmark 177 | 178 | ``` 179 | bash 4.eval.sh 180 | ``` 181 | 182 | 7. Evaluate your model: chat with your checkpoints in the terminal (a minimal programmatic sketch follows after this list) 183 | 184 | ``` 185 | python ./src/evaluate/cli_demo.py --model_name='./ckpts/your/path/tfmr' 186 | ``` 187 | 188 |
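As a quick programmatic check of a trained checkpoint (complementing the CLI demo in step 7), the prompt format documented under "Usage Format" can be used directly with `transformers`. A minimal sketch, assuming a 0.5B-7B style checkpoint; the checkpoint path is a placeholder and the sampling settings are borrowed from cli_demo.py:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "./ckpts/your/path/tfmr"  # placeholder checkpoint path, as in cli_demo.py
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto", device_map="auto", trust_remote_code=True)

# Prompt format for the 0.5B-7B models (see "Usage Format" above).
prompt = "User:What are the common symptoms of anemia?\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, top_p=0.7, temperature=0.3)

# Print only the newly generated tokens, not the echoed prompt.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```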
189 | 190 | 191 | ## Acknowledgment 192 | 193 | - [HuatuoGPT-II](https://github.com/FreedomIntelligence/HuatuoGPT-II) 194 | - [proxy-tuning](https://github.com/alisawuffles/proxy-tuning) 195 | 196 | ## Citation 197 | Please use the following citation if you intend to use our dataset for training or evaluation: 198 | 199 | ``` 200 | @misc{wang2024apollo, 201 | title={Apollo: Lightweight Multilingual Medical LLMs towards Democratizing Medical AI to 6B People}, 202 | author={Xidong Wang and Nuo Chen and Junyin Chen and Yan Hu and Yidong Wang and Xiangbo Wu and Anningzhe Gao and Xiang Wan and Haizhou Li and Benyou Wang}, 203 | year={2024}, 204 | eprint={2403.03640}, 205 | archivePrefix={arXiv}, 206 | primaryClass={cs.CL} 207 | } 208 | ``` 209 | 210 | ## Star History 211 | 212 | 213 | 214 | 215 | 216 | Star History Chart 217 | 218 | 219 | -------------------------------------------------------------------------------- /assets/apollo_medium_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/apollo_medium_final.png -------------------------------------------------------------------------------- /assets/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/dataset.png -------------------------------------------------------------------------------- /assets/final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/final.png -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/result.png -------------------------------------------------------------------------------- /metadata/dev/fr.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "Parmi les propositions suivantes, donner celle qui est exacte. En spectrophotométrie d'adsorption moléculaire, une absorbance de 1 correspond à une absorption du faisceau incident par le composé présent dans la cuve de :", 4 | "options": "(A) 100 %.\n(B) 90 %.\n(C) 50 %.\n(D) 40 %.\n(E) 10 %.", 5 | "answer": "(B)", 6 | "source": "frenchmedmcqa" 7 | }, 8 | { 9 | "question": "L'exsanguino transfusion est indiqué dans un type d'intoxication aiguë, lequel?", 10 | "options": "(A) Méthanol.\n(B) Cyanures.\n(C) Agents méthémoglobinisants.\n(D) Métaux lourds.\n(E) Monoxyde de carbone.", 11 | "answer": "(C)", 12 | "source": "frenchmedmcqa" 13 | }, 14 | { 15 | "question": "Parmi les propositions suivantes, une seule est exacte, Laquelle? Un complexe zinc-EDTA est dosé par spectrophotométrie d'absorption atomique. Sous quelle forme?", 16 | "options": "(A) Atome de zinc.\n(B) Ion zinc.\n(C) Radical.\n(D) Hydroxyde de zinc.\n(E) Oxyde de zinc.", 17 | "answer": "(A)", 18 | "source": "frenchmedmcqa" 19 | }, 20 | { 21 | "question": "Parmi les propositions suivantes, laquelle est exacte ? 
Les hétérosides cardiotoniques exercent un effet inotrope positif:", 22 | "options": "(A) Par une stimulation des récepteurs bêta-adrénergiques cardiaques.\n(B) Par un effet parasympatholytique.\n(C) Par une inhibition de la Na+/K+ ATPase et une modification secondaire des mouvements du calcium.\n(D) Par une inhibition de l'entrée du calcium dans la cellule cardiaque.\n(E) Par un effet vasodilatateur.", 23 | "answer": "(C)", 24 | "source": "frenchmedmcqa" 25 | }, 26 | { 27 | "question": "En chromatographie en phase liquide, l'amélioration de la résolution entre 2 composés dépend des paramètres suivants sauf un. Lequel?", 28 | "options": "(A) Le nombre de plateaux théoriques.\n(B) Le facteur de capacité du composé le plus retenu.\n(C) Le facteur de sélectivité.\n(D) La vitesse de déroulement du papier.\n(E) La longueur de la colonne.", 29 | "answer": "(D)", 30 | "source": "frenchmedmcqa" 31 | }, 32 | { 33 | "question": "Lors de l'examen direct du LCR, la découverte de diplocoques à Gram négatif doit faire évoquer le germe suivant. Donner la réponse exacte.", 34 | "options": "(A) Streptococcus pneumoniae.\n(B) Escherichia coli.\n(C) Neisseria meningitidis.\n(D) Cryptococcus neoformans.\n(E) Haemophilus influenzae.", 35 | "answer": "(C)", 36 | "source": "frenchmedmcqa" 37 | }, 38 | { 39 | "question": "Parmi ces propositions concernant la sérotonine, une seule est fausse, laquelle?", 40 | "options": "(A) C'est une amine biogène endogène.\n(B) Elle franchit la barrière hémato-encéphalique.\n(C) Elle est synthétisée à partir du tryptophane alimentaire.\n(D) Elle n'est pas synthétisée dans les plaquettes.\n(E) Son métabolite principal est constitué par le 5 HlAA.", 41 | "answer": "(B)", 42 | "source": "frenchmedmcqa" 43 | }, 44 | { 45 | "question": "Parmi ces propositions, une seule est fausse, laquelle?", 46 | "options": "(A) L'émission de lumière après excitation d'une molécule par radiations s'appelle fluorescence.\n(B) La raie de diffusion Rayleigh est un artefact de la fluorescence moléculaire.\n(C) Un spectrofluorimètre possède deux monochromateurs.\n(D) la phosphorescence est l'émission de lumière de molécules contenant des atomes de phosphore.\n(E) La diffusion Raman est un artefact de la fluorescence moléculaire.", 47 | "answer": "(D)", 48 | "source": "frenchmedmcqa" 49 | }, 50 | { 51 | "question": "Parmi les propositions suivantes, indiquer celle qui correspond le mieux à la définition d'une clairance:", 52 | "options": "(A) Volume de plasma contenant une substance donnée, filtré par unité de temps.\n(B) Volume de plasma épuré en une substance donnée par unité de temps.\n(C) Volume d'urine contenant la même quantité d'une substance que 1 mL de plasma.\n(D) Volume de plasma contenant une substance donnée, passant par le rein par unité de temps.\n(E) Volume d'urine d'où une substance peut être réabsorbée par unité de temps.", 53 | "answer": "(B)", 54 | "source": "frenchmedmcqa" 55 | }, 56 | { 57 | "question": "Quel(s) élément(s) ne peut (peuvent) pas être dosé(s) par photométrie de flamme ?", 58 | "options": "(A) Aluminium.\n(B) Sodium.\n(C) Potassium.\n(D) Césium.\n(E) Lithium.", 59 | "answer": "(A)", 60 | "source": "frenchmedmcqa" 61 | }, 62 | { 63 | "question": "Parmi les propositions suivantes, quelle est celle qui s'applique à Clostridium perfringens?", 64 | "options": "(A) C'est une bactérie qui peut former des spores.\n(B) C'est une bactérie aéro-anaérobie facultative.\n(C) Il n'est jamais responsable de toxi-infections alimentaires.\n(D) C'est un agent fréquent d'infections 
urinaires.\n(E) Il est toujours sensible aux aminosides.", 65 | "answer": "(A)", 66 | "source": "frenchmedmcqa" 67 | }, 68 | { 69 | "question": "En 1998, sur 85.000 hommes de 60 à 64 ans, 300 cas d'infarctus du myocarde ont été dénombrés dont 270 nouveaux. Parmi les nouveaux cas, 135 sont décédés dans l'année. Le nombre total de décès dus à cette maladie pour l'année est de 150; Parmi les résultats ci-dessous, indiquer celui qui représente le taux d'incidence:", 70 | "options": "(A) 135/85.000.\n(B) 300/85.000.\n(C) 150/85.000.\n(D) 270 /85.000.\n(E) 135/270.", 71 | "answer": "(D)", 72 | "source": "frenchmedmcqa" 73 | }, 74 | { 75 | "question": "Parmi les unités de radioactivité suivantes, indiquer celle qui est utilisée pour exprimer l'équivalent de dose (H) :", 76 | "options": "(A) Becquerel.\n(B) Curie.\n(C) Gray.\n(D) Sievert.\n(E) Rad.", 77 | "answer": "(D)", 78 | "source": "frenchmedmcqa" 79 | }, 80 | { 81 | "question": "Parmi les propositions suivantes concernant les phénomènes pouvant entraîner une inhibition de la fluorescence. Laquelle est fausse?", 82 | "options": "(A) La variation du pH peut modifier l'intensité de la fluorescence.\n(B) L'oxygène dissous inhibe la fluorescence.\n(C) Si la température augmente, l'intensité de la fluorescence augmente.\n(D) Le solvant peut inhiber la fluorescence.\n(E) La présence d'impuretés dans la solution peut provoquer une inhibition de la fluorescence.", 83 | "answer": "(C)", 84 | "source": "frenchmedmcqa" 85 | }, 86 | { 87 | "question": "Parmi les propositions suivantes, une seule est exacte, laquelle? La caractérisation de Cryptococcus neoformans dans un LCR est réalisée par la présence à l'examen direct", 88 | "options": "(A) De blastospores.\n(B) D'une capsule polysaccharidique.\n(C) D'arthrospores.\n(D) De vrai mycélium.\n(E) De pseudomycelium.", 89 | "answer": "(B)", 90 | "source": "frenchmedmcqa" 91 | }, 92 | { 93 | "question": "Parmi les propositions suivantes, indiquer celle qui est juste. 
Chez un sujet intoxiqué par un agent méthémoglobinisant, le traitement par le bleu de méthylène IV sera inefficace si le malade:", 94 | "options": "(A) Est à jeun.\n(B) Est déficitaire en glucose-6-phosphate déshydrogénase.\n(C) Est dans le coma.\n(D) Est déficitaire en cytochrome c réductase.\n(E) A absorbé des psychotropes.", 95 | "answer": "(B)", 96 | "source": "frenchmedmcqa" 97 | } 98 | ] -------------------------------------------------------------------------------- /metadata/merge_json_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def merge_json_files(root_path): 5 | merged_data = {} 6 | for file_name in os.listdir(root_path): 7 | 8 | if file_name.endswith(".json"): 9 | file_path = os.path.join(root_path, file_name) 10 | print(file_path) 11 | # Read the content of the JSON file 12 | with open(file_path, "r", encoding="utf-8") as json_file: 13 | json_data = json.load(json_file) 14 | 15 | # Add the data to the merged dictionary using the file name as the key 16 | merged_data[file_name[:-5]] = json_data 17 | 18 | # Path for the merged file 19 | merged_file_path = "./mixTrain.json" 20 | 21 | # Write the merged JSON data to a new file 22 | with open(merged_file_path, "w", encoding="utf-8") as merged_file: 23 | json.dump(merged_data, merged_file, ensure_ascii=False, indent=2) 24 | 25 | print(f"Merged file has been saved as: {merged_file_path}") 26 | 27 | # Replace with the actual directory path 28 | merge_json_files("./train/mixTrain") 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | llama-index-llms-vllm -------------------------------------------------------------------------------- /scripts/3.multinode_train_gema7B_rank0.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #python *.py 3 | 4 | node_rank=0 5 | master_ip=yourip 6 | 7 | 8 | cd ./src/sft 9 | process_port=29503 10 | experiment_name=Gemma7B_MixTrain 11 | model_dir=/your/path/to/Gemma-7b 12 | # ckpt_dir=./ckpts/ 13 | train_data_file=./data/gemma/MixTrain 14 | dev_data_file=./data/gemma/dev.json 15 | output_dir=./ckpts 16 | log_folder="./logs/${experiment_name}" 17 | mkdir -p $log_folder 18 | log_name=$(date +"%m-%d_%H-%M").log 19 | 20 | 21 | 22 | CUDA_LAUNCH_BLOCKING=1 accelerate launch \ 23 | --config_file ./src/sft/training_config/zero_multi.yaml \ 24 | --num_processes 24 \ 25 | --num_machines 3 \ 26 | --machine_rank ${node_rank} \ 27 | --main_process_ip "${master_ip}" \ 28 | --main_process_port ${process_port} \ 29 | --num_cpu_threads_per_process 8 \ 30 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \ 31 | --model_path ${model_dir} \ 32 | --experiment_name ${experiment_name} \ 33 | --gradient_accumulation_steps 8 \ 34 | --train_data_dir ${train_data_file} \ 35 | --dev_data_dir ${dev_data_file} \ 36 | --output_dir ${output_dir} \ 37 | --log_dir ./wandb_logs \ 38 | --n_epochs 1 \ 39 | --train_bsz_per_gpu 2 \ 40 | --eval_bsz_per_gpu 2 \ 41 | --learning_rate 1e-5 \ 42 | --eval_step -1 \ 43 | --save_step -1 \ 44 | --warmup_rates 0.03 \ 45 | --max_ckpts 3 \ 46 | --gradient_checkpointing > ${log_folder}/rank${node_rank}.log 2>&1 & -------------------------------------------------------------------------------- /scripts/3.multinode_train_gema7B_rank1.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #python *.py 3 | 4 | node_rank=1 5 | master_ip=yourip 6 | 7 | 8 | cd ./src/sft 9 | process_port=29503 10 | experiment_name=Gemma7B_MixTrain 11 | model_dir=/your/path/to/Gemma-7b 12 | # ckpt_dir=./ckpts/ 13 | train_data_file=./data/gemma/MixTrain 14 | dev_data_file=./data/gemma/dev.json 15 | output_dir=./ckpts 16 | log_folder="./logs/${experiment_name}" 17 | mkdir -p $log_folder 18 | log_name=$(date +"%m-%d_%H-%M").log 19 | 20 | 21 | 22 | CUDA_LAUNCH_BLOCKING=1 accelerate launch \ 23 | --config_file ./src/sft/training_config/zero_multi.yaml \ 24 | --num_processes 24 \ 25 | --num_machines 3 \ 26 | --machine_rank ${node_rank} \ 27 | --main_process_ip "${master_ip}" \ 28 | --main_process_port ${process_port} \ 29 | --num_cpu_threads_per_process 8 \ 30 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \ 31 | --model_path ${model_dir} \ 32 | --experiment_name ${experiment_name} \ 33 | --gradient_accumulation_steps 8 \ 34 | --train_data_dir ${train_data_file} \ 35 | --dev_data_dir ${dev_data_file} \ 36 | --output_dir ${output_dir} \ 37 | --log_dir ./wandb_logs \ 38 | --n_epochs 1 \ 39 | --train_bsz_per_gpu 2 \ 40 | --eval_bsz_per_gpu 2 \ 41 | --learning_rate 1e-5 \ 42 | --eval_step -1 \ 43 | --save_step -1 \ 44 | --warmup_rates 0.03 \ 45 | --max_ckpts 3 \ 46 | --gradient_checkpointing > ${log_folder}/rank${node_rank}.log 2>&1 & -------------------------------------------------------------------------------- /scripts/3.multinode_train_gema7B_rank2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #python *.py 3 | 4 | node_rank=2 5 | master_ip=yourip 6 | 7 | 8 | cd ./src/sft 9 | process_port=29503 10 | experiment_name=Gemma7B_MixTrain 11 | model_dir=/your/path/to/Gemma-7b 12 | # ckpt_dir=./ckpts/ 13 | train_data_file=./data/gemma/MixTrain 14 | dev_data_file=./data/gemma/dev.json 15 | output_dir=./ckpts 16 | log_folder="./logs/${experiment_name}" 17 | mkdir -p $log_folder 18 | log_name=$(date +"%m-%d_%H-%M").log 19 | 20 | 21 | 22 | CUDA_LAUNCH_BLOCKING=1 accelerate launch \ 23 | --config_file ./src/sft/training_config/zero_multi.yaml \ 24 | --num_processes 24 \ 25 | --num_machines 3 \ 26 | --machine_rank ${node_rank} \ 27 | --main_process_ip "${master_ip}" \ 28 | --main_process_port ${process_port} \ 29 | --num_cpu_threads_per_process 8 \ 30 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \ 31 | --model_path ${model_dir} \ 32 | --experiment_name ${experiment_name} \ 33 | --gradient_accumulation_steps 8 \ 34 | --train_data_dir ${train_data_file} \ 35 | --dev_data_dir ${dev_data_file} \ 36 | --output_dir ${output_dir} \ 37 | --log_dir ./wandb_logs \ 38 | --n_epochs 1 \ 39 | --train_bsz_per_gpu 2 \ 40 | --eval_bsz_per_gpu 2 \ 41 | --learning_rate 1e-5 \ 42 | --eval_step -1 \ 43 | --save_step -1 \ 44 | --warmup_rates 0.03 \ 45 | --max_ckpts 3 \ 46 | --gradient_checkpointing > ${log_folder}/rank${node_rank}.log 2>&1 & -------------------------------------------------------------------------------- /src/evaluate/cli_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import torch 4 | from threading import Thread 5 | from transformers import AutoTokenizer 6 | from transformers import AutoModelForCausalLM 7 | import argparse 8 | from transformers import TextIteratorStreamer 9 | 10 | def load_model(model_name): 11 | 
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left', pad_token='<|extra_0|>', eos_token='<|endoftext|>') 12 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype='auto', trust_remote_code=True) 13 | return model, tokenizer 14 | 15 | def generate_prompt(query, history): 16 | if not history: 17 | return f"User:{query}\nAssistant:" 18 | else: 19 | prompt = '' 20 | for i, (old_query, response) in enumerate(history): 21 | prompt += "User:{}\nAssistant:{}\n".format(old_query, response) 22 | prompt += "User:{}\nAssistant:".format(query) 23 | return prompt 24 | 25 | def remove_overlap(str1, str2): 26 | for i in range(len(str1), -1, -1): 27 | if str1.endswith(str2[:i]): 28 | return str2[i:] 29 | return str2 30 | 31 | def main(args): 32 | model, tokenizer = load_model(args.model_name) 33 | sep = tokenizer.convert_ids_to_tokens(tokenizer.eos_token_id) 34 | print(sep) 35 | 36 | model = model.eval() 37 | 38 | gen_kwargs = {'max_new_tokens': 1024, 'do_sample':True, 'top_p':0.7, 'temperature':0.3, 'repetition_penalty':1.1} 39 | 40 | os_name = platform.system() 41 | clear_command = 'cls' if os_name == 'Windows' else 'clear' 42 | history = [] 43 | print("Model: Hello, I am a large model that answers medical and health questions. It is currently in the testing stage. Please follow your doctor's advice. How can I help you? Enter clear to clear the conversation history, stop to terminate the program") 44 | while True: 45 | query = input("\nUser:") 46 | if query == "stop": 47 | break 48 | if query == "clear": 49 | history = [] 50 | os.system(clear_command) 51 | continue 52 | 53 | print(f"Model:", end="", flush=True) 54 | 55 | 56 | prompt = generate_prompt(query, history) 57 | inputs = tokenizer([prompt], return_tensors="pt") 58 | inputs = inputs.to(model.device) 59 | 60 | streamer = TextIteratorStreamer(tokenizer,skip_prompt=True) 61 | generation_kwargs = dict(input_ids=inputs['input_ids'], streamer=streamer, **gen_kwargs) 62 | 63 | thread = Thread(target=model.generate, kwargs=generation_kwargs) 64 | thread.start() 65 | 66 | generated_text = '' 67 | 68 | for new_text in streamer: 69 | if sep in new_text: 70 | new_text = remove_overlap(generated_text,new_text[:-len(sep)]) 71 | for char in new_text: 72 | generated_text += char 73 | print(char,end='',flush = True) 74 | break 75 | for char in new_text: 76 | generated_text += char 77 | print(char,end='',flush = True) 78 | history = history + [(query, generated_text)] 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("--model_name", type=str, default="./ckpts/your/path/tfmr") 84 | args = parser.parse_args() 85 | main(args) 86 | -------------------------------------------------------------------------------- /src/evaluate/eval_72b_34b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 4 | import re 5 | import argparse 6 | import json 7 | from tqdm import tqdm 8 | from torch.utils.data import DataLoader 9 | import torch.distributed as dist 10 | from collections import defaultdict 11 | from llama_index.llms.vllm import Vllm 12 | 13 | 14 | def get_answer(data:str,llm): 15 | response=llm.complete( 16 | data 17 | ) 18 | return response.text 19 | 20 | 21 | def extract_and_choose_answer(pattern, model_answer): 22 | # if '\n' in model_answer: 23 | # model_answer_split = 
model_answer.split('\n') 24 | # for model_answer_i in model_answer_split: 25 | # if len(model_answer_i): 26 | # model_answer = model_answer_i 27 | # break 28 | matches = re.findall(pattern, model_answer) 29 | option_count = {} 30 | for match in matches: 31 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 32 | 33 | if not option_count: 34 | # else use loose pattern 35 | loose_pattern = r'[A-F]' 36 | if pattern == loose_pattern: 37 | if model_answer == 'Yes.': 38 | return 'A' 39 | elif model_answer == 'No.': 40 | return 'B' 41 | else: 42 | return None 43 | else: 44 | return extract_and_choose_answer(loose_pattern, model_answer) 45 | 46 | max_count = max(option_count.values()) 47 | max_options = [option for option, count in option_count.items() if count == max_count] 48 | return max_options[0] 49 | 50 | 51 | def generate_score(result_path, score_path, wrong_item_path): 52 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 53 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 54 | 55 | all = defaultdict(int) 56 | right = defaultdict(int) 57 | accuracy_dict = defaultdict(int) 58 | wrong_item = [] 59 | 60 | print(f'****Total:{len(json_objects)}****') 61 | debug = True 62 | for item in json_objects: 63 | source = item["source"] 64 | for answer in item["model_answer"]: 65 | all[source] += 1 66 | pattern = r'(\(]([A-Fa-f])[)\)' 67 | extract_answer = extract_and_choose_answer(pattern, answer) 68 | item['extract_answer'] = extract_answer 69 | if debug: 70 | debug = False 71 | print(f'extract_answer:{extract_answer}') 72 | right_answer = item['answer'] 73 | print(f'right_answer:{right_answer}') 74 | if item['answer'] == extract_answer: 75 | right[source] += 1 76 | else: 77 | wrong_item.append(item) 78 | 79 | 80 | print(f'all:{all}') 81 | print(f'right:{right}') 82 | 83 | for key in right: 84 | accuracy_dict[key] = right[key] / all[key] 85 | 86 | with open(score_path, "w", encoding="utf8") as f: 87 | json.dump(accuracy_dict, f, indent=4) 88 | 89 | print(f'***********score_result save in {score_path}*************') 90 | 91 | with open(wrong_item_path, "w", encoding="utf8") as f: 92 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 93 | 94 | print(f'***********wrong_item save in {wrong_item_path}*************') 95 | 96 | 97 | def generate_response(args,llm): 98 | 99 | 100 | # model_path = args.model_path 101 | 102 | fp = open(args.output_path,'w') 103 | 104 | with open(args.input_path) as f: 105 | data = json.load(f) 106 | 107 | for item in tqdm(data): 108 | question=item['question'] 109 | answer=get_answer(question,llm) 110 | item['model_answer']=answer 111 | fp.write(json.dumps(item, ensure_ascii=False) +'\n') 112 | fp.flush() 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 118 | parser.add_argument("--input_path", type=str, help="path to the input data") 119 | parser.add_argument("--output_path", type=str, help="path to the output data") 120 | parser.add_argument("--score_path", type=str, help="path to the score") 121 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 122 | args = parser.parse_args() 123 | llm = Vllm( 124 | model=args.model_path, 125 | trust_remote_code=True, 126 | max_new_tokens=64, 127 | temperature=0, 128 | dtype="bfloat16", 129 | tensor_parallel_size=8, 130 | vllm_kwargs={"swap_space": 1}, 131 | ) 132 | generate_response(args,llm) 
133 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 134 | 135 | 136 | ''' 137 | 138 | accelerate launch /mntcephfs/data/med/xidong/Medbase/src/evaluate/eval_qwen.py \ 139 | --model_path=/mntcephfs/data/med/xidong/checkpoints/Qwen-1_8B \ 140 | --input_path=/mntcephfs/data/med/xidong/Medbase/data/Qwen-1.8B/test.json \ 141 | --output_path=/mntcephfs/data/med/xidong/Medbase/result/Qwen-1.8B/model_ans.jsonl \ 142 | --score_path=/mntcephfs/data/med/xidong/Medbase/result/Qwen-1.8B/score.json \ 143 | --batch_size=8 > ${log_folder}/$log_name 2>&1 & 144 | 145 | ''' 146 | 147 | 148 | -------------------------------------------------------------------------------- /src/evaluate/eval_gemma.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | 
loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | 103 | print(f'****Total:{len(json_objects)}****') 104 | debug = True 105 | for item in json_objects: 106 | source = item["source"] 107 | for answer in item["model_answer"]: 108 | all[source] += 1 109 | pattern = r'[(\(]([A-Fa-f])[)\)]' 110 | extract_answer = extract_and_choose_answer(pattern, answer) 111 | if debug: 112 | debug = False 113 | print(f'extract_answer:{extract_answer}') 114 | right_answer = item['answer'] 115 | print(f'right_answer:{right_answer}') 116 | if item['answer'] == extract_answer: 117 | right[source] += 1 118 | 119 | 120 | print(f'all:{all}') 121 | print(f'right:{right}') 122 | 123 | for key in right: 124 | accuracy_dict[key] = right[key] / all[key] 125 | 126 | with open(score_path, "w", encoding="utf8") as f: 127 | json.dump(accuracy_dict, f, indent=4) 128 | 129 | print(f'***********score_result save in {score_path}*************') 130 | 131 | 132 | def generate_response(args): 133 | accelerator = Accelerator() 134 | 135 | model_path = args.model_path 136 | accelerator.print(f'****************model_path:{model_path}******************') 137 | 138 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left', pad_token='', eos_token='') 139 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, ignore_mismatched_sizes=True).half() 140 | # generation_config = GenerationConfig.from_pretrained(model_path, pad_token_id=tokenizer.pad_token_id, num_return_sequences=args.num_return, max_new_tokens=256, min_new_tokens=2, do_sample=False, temperature=1.0, top_k=50, top_p=1.0) 141 | model = model.half().cuda() 142 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 143 | 144 | 145 | dataset = TestDataset(args.input_path, tokenizer) 146 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 147 | 148 | model = model.eval() 149 | if dist.is_initialized(): 150 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 151 | 152 | dataloader = accelerator.prepare(dataloader) 153 | accelerator.print(f'******************load_model from {model_path}******************') 154 | 155 | if accelerator.is_main_process: 156 | fp = open(args.output_path,'w') 157 | 158 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 159 | for batch in dataloader_iterator: 160 | batch_input_ids = batch["input_ids"] 161 | batch_data = batch["data"] 162 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 163 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 164 | 165 | 166 | if 
dist.is_initialized(): 167 | all_batch_data = [None] * dist.get_world_size() 168 | all_batch_responses = [None] * dist.get_world_size() 169 | dist.all_gather_object(all_batch_responses, batch_responses) 170 | dist.all_gather_object(all_batch_data, batch_data) 171 | else: 172 | all_batch_data = [batch_data, ] 173 | all_batch_responses = [batch_responses, ] 174 | 175 | all_data = [item for sublist in all_batch_data for item in sublist] 176 | all_response = [item for sublist in all_batch_responses for item in sublist] 177 | 178 | for data, responses in zip(all_data, all_response): 179 | answer_list = [] 180 | for response in responses: 181 | answer_list.append(response) 182 | data['model_answer'] = answer_list 183 | if accelerator.is_main_process: 184 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 185 | fp.flush() 186 | 187 | if __name__ == "__main__": 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 190 | parser.add_argument("--input_path", type=str, help="path to the input data") 191 | parser.add_argument("--output_path", type=str, help="path to the output data") 192 | parser.add_argument("--score_path", type=str, help="path to the score") 193 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 194 | parser.add_argument("--num_return", type=int, help="number of return sequences") 195 | parser.add_argument("--batch_size", type=int, help="batch size") 196 | args = parser.parse_args() 197 | generate_response(args) 198 | generate_score(args.output_path, args.score_path) 199 | -------------------------------------------------------------------------------- /src/evaluate/eval_huatuo2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | 
self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path, wrong_item_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | wrong_item = [] 103 | 104 | print(f'****Total:{len(json_objects)}****') 105 | debug = True 106 | for item in json_objects: 107 | source = item["source"] 108 | for answer in item["model_answer"]: 109 | all[source] += 1 110 | pattern = r'[(\(]([A-Fa-f])[)\)]' 111 | extract_answer = extract_and_choose_answer(pattern, answer) 112 | item['extract_answer'] = extract_answer 113 | if debug: 114 | debug = False 115 | print(f'extract_answer:{extract_answer}') 116 | right_answer = item['answer'] 117 | print(f'right_answer:{right_answer}') 118 | if item['answer'] == extract_answer: 119 | right[source] += 1 120 | else: 121 | wrong_item.append(item) 122 | 123 | 124 | print(f'all:{all}') 125 | print(f'right:{right}') 126 | 127 | for key in right: 128 | accuracy_dict[key] = right[key] / all[key] 129 | 130 | with open(score_path, "w", encoding="utf8") as f: 131 | json.dump(accuracy_dict, f, indent=4) 132 | 133 | print(f'***********score_result save in {score_path}*************') 134 | 135 | with open(wrong_item_path, "w", encoding="utf8") as f: 136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 137 | 138 | print(f'***********wrong_item save in {wrong_item_path}*************') 139 | 140 | 141 | def generate_response(args): 142 | accelerator = Accelerator() 143 | 144 | model_path = args.model_path 145 | accelerator.print(f'****************model_path:{model_path}******************') 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='',trust_remote_code=True) 148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 149 | model = model.half().cuda() 150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 
'min_new_tokens':2, 'do_sample':False} 151 | 152 | 153 | dataset = TestDataset(args.input_path, tokenizer) 154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 155 | 156 | model = model.eval() 157 | if dist.is_initialized(): 158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 159 | 160 | dataloader = accelerator.prepare(dataloader) 161 | accelerator.print(f'******************load_model from {model_path}******************') 162 | 163 | if accelerator.is_main_process: 164 | fp = open(args.output_path,'w') 165 | 166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 167 | for batch in dataloader_iterator: 168 | batch_input_ids = batch["input_ids"] 169 | batch_data = batch["data"] 170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 172 | 173 | 174 | if dist.is_initialized(): 175 | all_batch_data = [None] * dist.get_world_size() 176 | all_batch_responses = [None] * dist.get_world_size() 177 | dist.all_gather_object(all_batch_responses, batch_responses) 178 | dist.all_gather_object(all_batch_data, batch_data) 179 | else: 180 | all_batch_data = [batch_data, ] 181 | all_batch_responses = [batch_responses, ] 182 | 183 | all_data = [item for sublist in all_batch_data for item in sublist] 184 | all_response = [item for sublist in all_batch_responses for item in sublist] 185 | 186 | for data, responses in zip(all_data, all_response): 187 | answer_list = [] 188 | for response in responses: 189 | answer_list.append(response) 190 | data['model_answer'] = answer_list 191 | if accelerator.is_main_process: 192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 193 | fp.flush() 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 198 | parser.add_argument("--input_path", type=str, help="path to the input data") 199 | parser.add_argument("--output_path", type=str, help="path to the output data") 200 | parser.add_argument("--score_path", type=str, help="path to the score") 201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 202 | parser.add_argument("--num_return", type=int, help="number of return sequences") 203 | parser.add_argument("--batch_size", type=int, help="batch size") 204 | args = parser.parse_args() 205 | # generate_response(args) 206 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 207 | 208 | -------------------------------------------------------------------------------- /src/evaluate/eval_llama2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not 
dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path, wrong_item_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | wrong_item = [] 103 | 104 | print(f'****Total:{len(json_objects)}****') 105 | debug = True 106 | for item in json_objects: 107 | source = item["source"] 108 | for answer in item["model_answer"]: 109 | all[source] += 1 110 | pattern = r'[(\(]([A-Fa-f])[)\)]' 111 | extract_answer = extract_and_choose_answer(pattern, answer) 112 | item['extract_answer'] = extract_answer 113 | if debug: 114 | debug = False 115 | print(f'extract_answer:{extract_answer}') 116 | right_answer = item['answer'] 117 | print(f'right_answer:{right_answer}') 118 | if item['answer'] == extract_answer: 119 | right[source] += 1 120 | else: 121 | 
wrong_item.append(item) 122 | 123 | 124 | print(f'all:{all}') 125 | print(f'right:{right}') 126 | 127 | for key in right: 128 | accuracy_dict[key] = right[key] / all[key] 129 | 130 | with open(score_path, "w", encoding="utf8") as f: 131 | json.dump(accuracy_dict, f, indent=4) 132 | 133 | print(f'***********score_result save in {score_path}*************') 134 | 135 | with open(wrong_item_path, "w", encoding="utf8") as f: 136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 137 | 138 | print(f'***********wrong_item save in {wrong_item_path}*************') 139 | 140 | 141 | def generate_response(args): 142 | accelerator = Accelerator() 143 | 144 | model_path = args.model_path 145 | accelerator.print(f'****************model_path:{model_path}******************') 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='',pad_token='') 148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 149 | model = model.half().cuda() 150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 151 | 152 | 153 | dataset = TestDataset(args.input_path, tokenizer) 154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 155 | 156 | model = model.eval() 157 | if dist.is_initialized(): 158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 159 | 160 | dataloader = accelerator.prepare(dataloader) 161 | accelerator.print(f'******************load_model from {model_path}******************') 162 | 163 | if accelerator.is_main_process: 164 | fp = open(args.output_path,'w') 165 | 166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 167 | for batch in dataloader_iterator: 168 | batch_input_ids = batch["input_ids"] 169 | batch_data = batch["data"] 170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 172 | 173 | 174 | if dist.is_initialized(): 175 | all_batch_data = [None] * dist.get_world_size() 176 | all_batch_responses = [None] * dist.get_world_size() 177 | dist.all_gather_object(all_batch_responses, batch_responses) 178 | dist.all_gather_object(all_batch_data, batch_data) 179 | else: 180 | all_batch_data = [batch_data, ] 181 | all_batch_responses = [batch_responses, ] 182 | 183 | all_data = [item for sublist in all_batch_data for item in sublist] 184 | all_response = [item for sublist in all_batch_responses for item in sublist] 185 | 186 | for data, responses in zip(all_data, all_response): 187 | answer_list = [] 188 | for response in responses: 189 | answer_list.append(response) 190 | data['model_answer'] = answer_list 191 | if accelerator.is_main_process: 192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 193 | fp.flush() 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 198 | parser.add_argument("--input_path", type=str, help="path to the input data") 199 | parser.add_argument("--output_path", type=str, help="path to the output data") 200 | parser.add_argument("--score_path", type=str, help="path to the score") 201 | parser.add_argument("--wrong_item_path", type=str, 
help="path to the wrong_item") 202 | parser.add_argument("--num_return", type=int, help="number of return sequences") 203 | parser.add_argument("--batch_size", type=int, help="batch size") 204 | args = parser.parse_args() 205 | generate_response(args) 206 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 207 | 208 | -------------------------------------------------------------------------------- /src/evaluate/eval_llama70b.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | import json 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | import torch.distributed as dist 11 | from collections import defaultdict 12 | from llama_index.llms.vllm import Vllm 13 | 14 | llm = Vllm( 15 | model='./Llama-2-70b-hf/', 16 | trust_remote_code=True, 17 | max_new_tokens=64, 18 | temperature=0, 19 | dtype="bfloat16", 20 | tensor_parallel_size=8, 21 | vllm_kwargs={"swap_space": 1}, 22 | ) 23 | 24 | def get_answer(data:str): 25 | response=llm.complete( 26 | data 27 | ) 28 | return response.text 29 | 30 | 31 | def extract_and_choose_answer(pattern, model_answer): 32 | if '\n' in model_answer: 33 | model_answer_split = model_answer.split('\n') 34 | for model_answer_i in model_answer_split: 35 | if len(model_answer_i): 36 | model_answer = model_answer_i 37 | break 38 | matches = re.findall(pattern, model_answer) 39 | option_count = {} 40 | for match in matches: 41 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 42 | 43 | if not option_count: 44 | # else use loose pattern 45 | loose_pattern = r'[A-F]' 46 | if pattern == loose_pattern: 47 | if model_answer == 'Yes.': 48 | return 'A' 49 | elif model_answer == 'No.': 50 | return 'B' 51 | else: 52 | return None 53 | else: 54 | return extract_and_choose_answer(loose_pattern, model_answer) 55 | 56 | max_count = max(option_count.values()) 57 | max_options = [option for option, count in option_count.items() if count == max_count] 58 | return max_options[0] 59 | 60 | 61 | 62 | def generate_score(result_path, score_path, wrong_item_path): 63 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 64 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 65 | 66 | all = defaultdict(int) 67 | right = defaultdict(int) 68 | accuracy_dict = defaultdict(int) 69 | wrong_item = [] 70 | 71 | print(f'****Total:{len(json_objects)}****') 72 | debug = True 73 | for item in json_objects: 74 | source = item["source"] 75 | for answer in item["model_answer"]: 76 | all[source] += 1 77 | pattern = r'(\(]([A-Fa-f])[)\)' 78 | extract_answer = extract_and_choose_answer(pattern, answer) 79 | item['extract_answer'] = extract_answer 80 | if debug: 81 | debug = False 82 | print(f'extract_answer:{extract_answer}') 83 | right_answer = item['answer'] 84 | print(f'right_answer:{right_answer}') 85 | if item['answer'] == extract_answer: 86 | right[source] += 1 87 | else: 88 | wrong_item.append(item) 89 | 90 | 91 | print(f'all:{all}') 92 | print(f'right:{right}') 93 | 94 | for key in right: 95 | accuracy_dict[key] = right[key] / all[key] 96 | 97 | with open(score_path, "w", encoding="utf8") as f: 98 | json.dump(accuracy_dict, f, indent=4) 99 | 100 | print(f'***********score_result save in {score_path}*************') 101 | 102 | with open(wrong_item_path, "w", encoding="utf8") as f: 103 | json.dump(wrong_item, f, indent=4, 
ensure_ascii=False) 104 | 105 | print(f'***********wrong_item save in {wrong_item_path}*************') 106 | 107 | 108 | def generate_response(args): 109 | 110 | 111 | model_path = args.model_path 112 | 113 | fp = open(args.output_path,'w') 114 | 115 | with open(args.input_path) as f: 116 | data = json.load(f) 117 | 118 | for item in tqdm(data): 119 | question=item['question'] 120 | answer=get_answer(question) 121 | item['model_answer']=answer 122 | fp.write(json.dumps(item, ensure_ascii=False) +'\n') 123 | fp.flush() 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 129 | parser.add_argument("--input_path", type=str, help="path to the input data") 130 | parser.add_argument("--output_path", type=str, help="path to the output data") 131 | parser.add_argument("--score_path", type=str, help="path to the score") 132 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 133 | args = parser.parse_args() 134 | generate_response(args) 135 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 136 | 137 | 138 | -------------------------------------------------------------------------------- /src/evaluate/eval_meditron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 4 | import re 5 | import argparse 6 | from accelerate import Accelerator 7 | import json 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | import torch.distributed as dist 11 | from collections import defaultdict 12 | 13 | 14 | class TestDataset(torch.utils.data.Dataset): 15 | def __init__(self, data_path,tokenizer): 16 | self.data = [] 17 | with open(data_path) as f: 18 | self.data = json.load(f) 19 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 20 | if dist_flag_0: 21 | print(f'load {len(self.data)} data from {data_path}') 22 | self.tokenizer = tokenizer 23 | self.debug = True 24 | 25 | def __getitem__(self, index): 26 | item = self.data[index] 27 | return { 28 | 'data': item, 29 | 'input': item['question'] 30 | } 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | def collate_fn(self, batch): 36 | batch_query = [x['input'] for x in batch] 37 | batch_data = [x['data'] for x in batch] 38 | out_batch = {} 39 | out_batch['data'] = batch_data 40 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 41 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 42 | if self.debug and dist_flag_0: 43 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 44 | for idx, sample in enumerate(decoded_texts): 45 | print(f'*******************batch_texts[{idx}]**********************************') 46 | print(sample) 47 | self.debug = False 48 | return out_batch 49 | 50 | 51 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 52 | responses_list=[] 53 | batch_return=[] 54 | input_len = len(batch_input_ids[0]) 55 | for idx, output_ids in enumerate(batch_output_ids): 56 | generated_ids = output_ids[input_len:] 57 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 58 | if idx % num_return == 
num_return-1: 59 | responses_list.append(batch_return) 60 | batch_return=[] 61 | return responses_list 62 | 63 | def extract_and_choose_answer(pattern, model_answer): 64 | if '\n' in model_answer: 65 | model_answer_split = model_answer.split('\n') 66 | for model_answer_i in model_answer_split: 67 | if len(model_answer_i): 68 | model_answer = model_answer_i 69 | break 70 | matches = re.findall(pattern, model_answer) 71 | option_count = {} 72 | for match in matches: 73 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 74 | 75 | if not option_count: 76 | # else use loose pattern 77 | loose_pattern = r'[A-F]' 78 | if pattern == loose_pattern: 79 | if model_answer == 'Yes.': 80 | return 'A' 81 | elif model_answer == 'No.': 82 | return 'B' 83 | else: 84 | return None 85 | else: 86 | return extract_and_choose_answer(loose_pattern, model_answer) 87 | 88 | max_count = max(option_count.values()) 89 | max_options = [option for option, count in option_count.items() if count == max_count] 90 | return max_options[0] 91 | 92 | 93 | 94 | def generate_score(result_path, score_path, wrong_item_path): 95 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 96 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 97 | 98 | all = defaultdict(int) 99 | right = defaultdict(int) 100 | accuracy_dict = defaultdict(int) 101 | wrong_item = [] 102 | 103 | print(f'****Total:{len(json_objects)}****') 104 | debug = True 105 | for item in json_objects: 106 | source = item["source"] 107 | for answer in item["model_answer"]: 108 | all[source] += 1 109 | pattern = r'[(\(]([A-Fa-f])[)\)]' 110 | extract_answer = extract_and_choose_answer(pattern, answer) 111 | item['extract_answer'] = extract_answer 112 | if debug: 113 | debug = False 114 | print(f'extract_answer:{extract_answer}') 115 | right_answer = item['answer'] 116 | print(f'right_answer:{right_answer}') 117 | if item['answer'] == extract_answer: 118 | right[source] += 1 119 | else: 120 | wrong_item.append(item) 121 | 122 | 123 | print(f'all:{all}') 124 | print(f'right:{right}') 125 | 126 | for key in right: 127 | accuracy_dict[key] = right[key] / all[key] 128 | 129 | with open(score_path, "w", encoding="utf8") as f: 130 | json.dump(accuracy_dict, f, indent=4) 131 | 132 | print(f'***********score_result save in {score_path}*************') 133 | 134 | with open(wrong_item_path, "w", encoding="utf8") as f: 135 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 136 | 137 | print(f'***********wrong_item save in {wrong_item_path}*************') 138 | 139 | 140 | def generate_response(args): 141 | accelerator = Accelerator() 142 | 143 | model_path = args.model_path 144 | accelerator.print(f'****************model_path:{model_path}******************') 145 | 146 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', pad_token="", eos_token="") 147 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 148 | model = model.half().cuda() 149 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 150 | 151 | 152 | dataset = TestDataset(args.input_path, tokenizer) 153 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 154 | 155 | model = model.eval() 156 | if dist.is_initialized(): 157 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 158 | 159 | dataloader = 
accelerator.prepare(dataloader) 160 | accelerator.print(f'******************load_model from {model_path}******************') 161 | 162 | if accelerator.is_main_process: 163 | fp = open(args.output_path,'w') 164 | 165 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 166 | for batch in dataloader_iterator: 167 | batch_input_ids = batch["input_ids"] 168 | batch_data = batch["data"] 169 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 170 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 171 | 172 | 173 | if dist.is_initialized(): 174 | all_batch_data = [None] * dist.get_world_size() 175 | all_batch_responses = [None] * dist.get_world_size() 176 | dist.all_gather_object(all_batch_responses, batch_responses) 177 | dist.all_gather_object(all_batch_data, batch_data) 178 | else: 179 | all_batch_data = [batch_data, ] 180 | all_batch_responses = [batch_responses, ] 181 | 182 | all_data = [item for sublist in all_batch_data for item in sublist] 183 | all_response = [item for sublist in all_batch_responses for item in sublist] 184 | 185 | for data, responses in zip(all_data, all_response): 186 | answer_list = [] 187 | for response in responses: 188 | answer_list.append(response) 189 | data['model_answer'] = answer_list 190 | if accelerator.is_main_process: 191 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 192 | fp.flush() 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 197 | parser.add_argument("--input_path", type=str, help="path to the input data") 198 | parser.add_argument("--output_path", type=str, help="path to the output data") 199 | parser.add_argument("--score_path", type=str, help="path to the score") 200 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 201 | parser.add_argument("--num_return", type=int, help="number of return sequences") 202 | parser.add_argument("--batch_size", type=int, help="batch size") 203 | args = parser.parse_args() 204 | generate_response(args) 205 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 206 | 207 | 208 | -------------------------------------------------------------------------------- /src/evaluate/eval_meditron70b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 4 | import re 5 | import argparse 6 | import json 7 | from tqdm import tqdm 8 | from torch.utils.data import DataLoader 9 | import torch.distributed as dist 10 | from collections import defaultdict 11 | from llama_index.llms.vllm import Vllm 12 | 13 | llm = Vllm( 14 | model='./models/meditron-70b/', 15 | trust_remote_code=True, 16 | max_new_tokens=64, 17 | temperature=0, 18 | dtype="bfloat16", 19 | tensor_parallel_size=8, 20 | vllm_kwargs={"swap_space": 1}, 21 | ) 22 | 23 | def get_answer(data:str): 24 | response=llm.complete( 25 | data 26 | ) 27 | return response.text 28 | 29 | 30 | def extract_and_choose_answer(pattern, model_answer): 31 | if '\n' in model_answer: 32 | model_answer_split = model_answer.split('\n') 33 | for model_answer_i in model_answer_split: 34 | if len(model_answer_i): 35 | model_answer = model_answer_i 36 | break 37 | matches = re.findall(pattern, 
model_answer) 38 | option_count = {} 39 | for match in matches: 40 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 41 | 42 | if not option_count: 43 | # else use loose pattern 44 | loose_pattern = r'[A-F]' 45 | if pattern == loose_pattern: 46 | if model_answer == 'Yes.': 47 | return 'A' 48 | elif model_answer == 'No.': 49 | return 'B' 50 | else: 51 | return None 52 | else: 53 | return extract_and_choose_answer(loose_pattern, model_answer) 54 | 55 | max_count = max(option_count.values()) 56 | max_options = [option for option, count in option_count.items() if count == max_count] 57 | return max_options[0] 58 | 59 | 60 | def generate_score(result_path, score_path, wrong_item_path): 61 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 62 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 63 | 64 | all = defaultdict(int) 65 | right = defaultdict(int) 66 | accuracy_dict = defaultdict(int) 67 | wrong_item = [] 68 | 69 | print(f'****Total:{len(json_objects)}****') 70 | debug = True 71 | for item in json_objects: 72 | source = item["source"] 73 | for answer in item["model_answer"]: 74 | all[source] += 1 75 | pattern = r'[(\(]([A-Fa-f])[)\)]' 76 | extract_answer = extract_and_choose_answer(pattern, answer) 77 | item['extract_answer'] = extract_answer 78 | if debug: 79 | debug = False 80 | print(f'extract_answer:{extract_answer}') 81 | right_answer = item['answer'] 82 | print(f'right_answer:{right_answer}') 83 | if item['answer'] == extract_answer: 84 | right[source] += 1 85 | else: 86 | wrong_item.append(item) 87 | 88 | 89 | print(f'all:{all}') 90 | print(f'right:{right}') 91 | 92 | for key in right: 93 | accuracy_dict[key] = right[key] / all[key] 94 | 95 | with open(score_path, "w", encoding="utf8") as f: 96 | json.dump(accuracy_dict, f, indent=4) 97 | 98 | print(f'***********score_result save in {score_path}*************') 99 | 100 | with open(wrong_item_path, "w", encoding="utf8") as f: 101 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 102 | 103 | print(f'***********wrong_item save in {wrong_item_path}*************') 104 | 105 | 106 | def generate_response(args): 107 | 108 | 109 | # model_path = args.model_path 110 | 111 | fp = open(args.output_path,'w') 112 | 113 | with open(args.input_path) as f: 114 | data = json.load(f) 115 | 116 | for item in tqdm(data): 117 | question=item['question'] 118 | answer=get_answer(question) 119 | item['model_answer']=answer 120 | fp.write(json.dumps(item, ensure_ascii=False) +'\n') 121 | fp.flush() 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser() 126 | # parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 127 | parser.add_argument("--input_path", type=str, help="path to the input data") 128 | parser.add_argument("--output_path", type=str, help="path to the output data") 129 | parser.add_argument("--score_path", type=str, help="path to the score") 130 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 131 | args = parser.parse_args() 132 | generate_response(args) 133 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 134 | 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /src/evaluate/eval_mistral.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM,
GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path, wrong_item_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | wrong_item = [] 103 | 104 | print(f'****Total:{len(json_objects)}****') 105 | debug = True 106 | for item in json_objects: 107 | source = item["source"] 108 
| for answer in item["model_answer"]: 109 | all[source] += 1 110 | pattern = r'[(\(]([A-Fa-f])[)\)]' 111 | extract_answer = extract_and_choose_answer(pattern, answer) 112 | item['extract_answer'] = extract_answer 113 | if debug: 114 | debug = False 115 | print(f'extract_answer:{extract_answer}') 116 | right_answer = item['answer'] 117 | print(f'right_answer:{right_answer}') 118 | if item['answer'] == extract_answer: 119 | right[source] += 1 120 | else: 121 | wrong_item.append(item) 122 | 123 | 124 | print(f'all:{all}') 125 | print(f'right:{right}') 126 | 127 | for key in right: 128 | accuracy_dict[key] = right[key] / all[key] 129 | 130 | with open(score_path, "w", encoding="utf8") as f: 131 | json.dump(accuracy_dict, f, indent=4) 132 | 133 | print(f'***********score_result save in {score_path}*************') 134 | 135 | with open(wrong_item_path, "w", encoding="utf8") as f: 136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 137 | 138 | print(f'***********wrong_item save in {wrong_item_path}*************') 139 | 140 | 141 | def generate_response(args): 142 | accelerator = Accelerator() 143 | 144 | model_path = args.model_path 145 | accelerator.print(f'****************model_path:{model_path}******************') 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='',pad_token='null') 148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 149 | model = model.half().cuda() 150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 151 | 152 | 153 | dataset = TestDataset(args.input_path, tokenizer) 154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 155 | 156 | model = model.eval() 157 | if dist.is_initialized(): 158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 159 | 160 | dataloader = accelerator.prepare(dataloader) 161 | accelerator.print(f'******************load_model from {model_path}******************') 162 | 163 | if accelerator.is_main_process: 164 | fp = open(args.output_path,'w') 165 | 166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 167 | for batch in dataloader_iterator: 168 | batch_input_ids = batch["input_ids"] 169 | batch_data = batch["data"] 170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 172 | 173 | 174 | if dist.is_initialized(): 175 | all_batch_data = [None] * dist.get_world_size() 176 | all_batch_responses = [None] * dist.get_world_size() 177 | dist.all_gather_object(all_batch_responses, batch_responses) 178 | dist.all_gather_object(all_batch_data, batch_data) 179 | else: 180 | all_batch_data = [batch_data, ] 181 | all_batch_responses = [batch_responses, ] 182 | 183 | all_data = [item for sublist in all_batch_data for item in sublist] 184 | all_response = [item for sublist in all_batch_responses for item in sublist] 185 | 186 | for data, responses in zip(all_data, all_response): 187 | answer_list = [] 188 | for response in responses: 189 | answer_list.append(response) 190 | data['model_answer'] = answer_list 191 | if accelerator.is_main_process: 192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 193 | fp.flush() 194 | 195 | if __name__ == "__main__": 196 | 
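# A hypothetical invocation, mirroring the CLI arguments parsed below (the model path, --num_return and --batch_size values are placeholders, not values taken from this repository):
#   accelerate launch ./src/evaluate/eval_mistral.py --model_path /your/path/to/mistral-7b --input_path ./data/mistral/test.json --output_path ./results/model_ans.jsonl --score_path ./results/score.json --wrong_item_path ./results/wrong_item.json --num_return 1 --batch_size 4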
parser = argparse.ArgumentParser() 197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 198 | parser.add_argument("--input_path", type=str, help="path to the input data") 199 | parser.add_argument("--output_path", type=str, help="path to the output data") 200 | parser.add_argument("--score_path", type=str, help="path to the score") 201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 202 | parser.add_argument("--num_return", type=int, help="number of return sequences") 203 | parser.add_argument("--batch_size", type=int, help="batch size") 204 | args = parser.parse_args() 205 | generate_response(args) 206 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 207 | 208 | 209 | -------------------------------------------------------------------------------- /src/evaluate/eval_mmedlm2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, 
model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path, wrong_item_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | wrong_item = [] 103 | 104 | print(f'****Total:{len(json_objects)}****') 105 | debug = True 106 | for item in json_objects: 107 | source = item["source"] 108 | for answer in item["model_answer"]: 109 | all[source] += 1 110 | pattern = r'[(\(]([A-Fa-f])[)\)]' 111 | extract_answer = extract_and_choose_answer(pattern, answer) 112 | item['extract_answer'] = extract_answer 113 | if debug: 114 | debug = False 115 | print(f'extract_answer:{extract_answer}') 116 | right_answer = item['answer'] 117 | print(f'right_answer:{right_answer}') 118 | if item['answer'] == extract_answer: 119 | right[source] += 1 120 | else: 121 | wrong_item.append(item) 122 | 123 | 124 | print(f'all:{all}') 125 | print(f'right:{right}') 126 | 127 | for key in right: 128 | accuracy_dict[key] = right[key] / all[key] 129 | 130 | with open(score_path, "w", encoding="utf8") as f: 131 | json.dump(accuracy_dict, f, indent=4) 132 | 133 | print(f'***********score_result save in {score_path}*************') 134 | 135 | with open(wrong_item_path, "w", encoding="utf8") as f: 136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 137 | 138 | print(f'***********wrong_item save in {wrong_item_path}*************') 139 | 140 | 141 | def generate_response(args): 142 | accelerator = Accelerator() 143 | 144 | model_path = args.model_path 145 | accelerator.print(f'****************model_path:{model_path}******************') 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='',pad_token='',trust_remote_code=True) 148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 149 | model = model.half().cuda() 150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 151 | 152 | 153 | dataset = TestDataset(args.input_path, tokenizer) 154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 155 | 156 | model = model.eval() 157 | if dist.is_initialized(): 158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 159 | 160 | dataloader = accelerator.prepare(dataloader) 161 | accelerator.print(f'******************load_model from {model_path}******************') 162 | 163 | if accelerator.is_main_process: 164 | fp = open(args.output_path,'w') 165 | 166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 167 | for batch in dataloader_iterator: 168 | batch_input_ids = 
batch["input_ids"] 169 | batch_data = batch["data"] 170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 172 | 173 | 174 | if dist.is_initialized(): 175 | all_batch_data = [None] * dist.get_world_size() 176 | all_batch_responses = [None] * dist.get_world_size() 177 | dist.all_gather_object(all_batch_responses, batch_responses) 178 | dist.all_gather_object(all_batch_data, batch_data) 179 | else: 180 | all_batch_data = [batch_data, ] 181 | all_batch_responses = [batch_responses, ] 182 | 183 | all_data = [item for sublist in all_batch_data for item in sublist] 184 | all_response = [item for sublist in all_batch_responses for item in sublist] 185 | 186 | for data, responses in zip(all_data, all_response): 187 | answer_list = [] 188 | for response in responses: 189 | answer_list.append(response) 190 | data['model_answer'] = answer_list 191 | if accelerator.is_main_process: 192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 193 | fp.flush() 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 198 | parser.add_argument("--input_path", type=str, help="path to the input data") 199 | parser.add_argument("--output_path", type=str, help="path to the output data") 200 | parser.add_argument("--score_path", type=str, help="path to the score") 201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 202 | parser.add_argument("--num_return", type=int, help="number of return sequences") 203 | parser.add_argument("--batch_size", type=int, help="batch size") 204 | args = parser.parse_args() 205 | generate_response(args) 206 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 207 | 208 | 209 | -------------------------------------------------------------------------------- /src/evaluate/eval_qwen.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 
43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path, wrong_item_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | 103 | print(f'****Total:{len(json_objects)}****') 104 | debug = True 105 | for item in json_objects: 106 | source = item["source"] 107 | for answer in item["model_answer"]: 108 | all[source] += 1 109 | pattern = r'[(\(]([A-Fa-f])[)\)]' 110 | extract_answer = extract_and_choose_answer(pattern, answer) 111 | if debug: 112 | debug = False 113 | print(f'extract_answer:{extract_answer}') 114 | right_answer = item['answer'] 115 | print(f'right_answer:{right_answer}') 116 | if item['answer'] == extract_answer: 117 | right[source] += 1 118 | 119 | 120 | print(f'all:{all}') 121 | print(f'right:{right}') 122 | 123 | for key in right: 124 | accuracy_dict[key] = right[key] / all[key] 125 | 126 | with open(score_path, "w", encoding="utf8") as f: 127 | json.dump(accuracy_dict, f, indent=4) 128 | 129 | print(f'***********score_result save in {score_path}*************') 130 | 131 | 132 | def generate_response(args): 133 | accelerator = Accelerator() 134 | 135 | model_path = args.model_path 136 | accelerator.print(f'****************model_path:{model_path}******************') 137 | 138 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left', pad_token='<|extra_0|>', eos_token='<|endoftext|>') 139 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, ignore_mismatched_sizes=True).half() 140 | generation_config = GenerationConfig.from_pretrained(model_path, 
pad_token_id=tokenizer.pad_token_id, num_return_sequences=args.num_return, max_new_tokens=256, min_new_tokens=2, do_sample=False, temperature=1.0, top_k=50, top_p=1.0) 141 | 142 | dataset = TestDataset(args.input_path, tokenizer) 143 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 144 | 145 | model = model.eval() 146 | if dist.is_initialized(): 147 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 148 | 149 | dataloader = accelerator.prepare(dataloader) 150 | accelerator.print(f'******************load_model from {model_path}******************') 151 | 152 | if accelerator.is_main_process: 153 | fp = open(args.output_path,'w') 154 | 155 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 156 | for batch in dataloader_iterator: 157 | batch_input_ids = batch["input_ids"] 158 | batch_data = batch["data"] 159 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, generation_config=generation_config) 160 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 161 | 162 | 163 | if dist.is_initialized(): 164 | all_batch_data = [None] * dist.get_world_size() 165 | all_batch_responses = [None] * dist.get_world_size() 166 | dist.all_gather_object(all_batch_responses, batch_responses) 167 | dist.all_gather_object(all_batch_data, batch_data) 168 | else: 169 | all_batch_data = [batch_data, ] 170 | all_batch_responses = [batch_responses, ] 171 | 172 | all_data = [item for sublist in all_batch_data for item in sublist] 173 | all_response = [item for sublist in all_batch_responses for item in sublist] 174 | 175 | for data, responses in zip(all_data, all_response): 176 | answer_list = [] 177 | for response in responses: 178 | answer_list.append(response) 179 | data['model_answer'] = answer_list 180 | if accelerator.is_main_process: 181 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 182 | fp.flush() 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 187 | parser.add_argument("--input_path", type=str, help="path to the input data") 188 | parser.add_argument("--output_path", type=str, help="path to the output data") 189 | parser.add_argument("--score_path", type=str, help="path to the score") 190 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 191 | parser.add_argument("--num_return", type=int, help="number of return sequences") 192 | parser.add_argument("--batch_size", type=int, help="batch size") 193 | args = parser.parse_args() 194 | generate_response(args) 195 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 196 | -------------------------------------------------------------------------------- /src/evaluate/eval_yi.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, 
data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | matches = re.findall(pattern, model_answer) 72 | option_count = {} 73 | for match in matches: 74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 75 | 76 | if not option_count: 77 | # else use loose pattern 78 | loose_pattern = r'[A-F]' 79 | if pattern == loose_pattern: 80 | if model_answer == 'Yes.': 81 | return 'A' 82 | elif model_answer == 'No.': 83 | return 'B' 84 | else: 85 | return None 86 | else: 87 | return extract_and_choose_answer(loose_pattern, model_answer) 88 | 89 | max_count = max(option_count.values()) 90 | max_options = [option for option, count in option_count.items() if count == max_count] 91 | return max_options[0] 92 | 93 | 94 | 95 | def generate_score(result_path, score_path, wrong_item_path): 96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 97 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 98 | 99 | all = defaultdict(int) 100 | right = defaultdict(int) 101 | accuracy_dict = defaultdict(int) 102 | wrong_item = [] 103 | 104 | print(f'****Total:{len(json_objects)}****') 105 | debug = True 106 | for item in json_objects: 107 | source = item["source"] 108 | for answer in item["model_answer"]: 109 | all[source] += 1 110 | pattern = r'[(\(]([A-Fa-f])[)\)]' 111 | extract_answer = extract_and_choose_answer(pattern, answer) 112 | item['extract_answer'] = extract_answer 113 | if debug: 114 | debug = False 115 | print(f'extract_answer:{extract_answer}') 116 | right_answer = item['answer'] 117 | 
print(f'right_answer:{right_answer}') 118 | if item['answer'] == extract_answer: 119 | right[source] += 1 120 | else: 121 | wrong_item.append(item) 122 | 123 | 124 | print(f'all:{all}') 125 | print(f'right:{right}') 126 | 127 | for key in right: 128 | accuracy_dict[key] = right[key] / all[key] 129 | 130 | with open(score_path, "w", encoding="utf8") as f: 131 | json.dump(accuracy_dict, f, indent=4) 132 | 133 | print(f'***********score_result save in {score_path}*************') 134 | 135 | with open(wrong_item_path, "w", encoding="utf8") as f: 136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 137 | 138 | print(f'***********wrong_item save in {wrong_item_path}*************') 139 | 140 | 141 | def generate_response(args): 142 | accelerator = Accelerator() 143 | 144 | model_path = args.model_path 145 | accelerator.print(f'****************model_path:{model_path}******************') 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='<|endoftext|>') 148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 149 | model = model.half().cuda() 150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 151 | 152 | 153 | dataset = TestDataset(args.input_path, tokenizer) 154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 155 | 156 | model = model.eval() 157 | if dist.is_initialized(): 158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 159 | 160 | dataloader = accelerator.prepare(dataloader) 161 | accelerator.print(f'******************load_model from {model_path}******************') 162 | 163 | if accelerator.is_main_process: 164 | fp = open(args.output_path,'w') 165 | 166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 167 | for batch in dataloader_iterator: 168 | batch_input_ids = batch["input_ids"] 169 | batch_data = batch["data"] 170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 172 | 173 | 174 | if dist.is_initialized(): 175 | all_batch_data = [None] * dist.get_world_size() 176 | all_batch_responses = [None] * dist.get_world_size() 177 | dist.all_gather_object(all_batch_responses, batch_responses) 178 | dist.all_gather_object(all_batch_data, batch_data) 179 | else: 180 | all_batch_data = [batch_data, ] 181 | all_batch_responses = [batch_responses, ] 182 | 183 | all_data = [item for sublist in all_batch_data for item in sublist] 184 | all_response = [item for sublist in all_batch_responses for item in sublist] 185 | 186 | for data, responses in zip(all_data, all_response): 187 | answer_list = [] 188 | for response in responses: 189 | answer_list.append(response) 190 | data['model_answer'] = answer_list 191 | if accelerator.is_main_process: 192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 193 | fp.flush() 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 198 | parser.add_argument("--input_path", type=str, help="path to the input data") 199 | parser.add_argument("--output_path", type=str, help="path to the output data") 200 | 
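# --output_path collects the raw generations as one JSON object per line; --score_path and --wrong_item_path receive the per-source accuracy dict and the items whose extracted answer did not match the reference.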
parser.add_argument("--score_path", type=str, help="path to the score") 201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 202 | parser.add_argument("--num_return", type=int, help="number of return sequences") 203 | parser.add_argument("--batch_size", type=int, help="batch size") 204 | args = parser.parse_args() 205 | generate_response(args) 206 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 207 | 208 | 209 | -------------------------------------------------------------------------------- /src/evaluate/eval_zephyr.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, data_path,tokenizer): 17 | self.data = [] 18 | with open(data_path) as f: 19 | self.data = json.load(f) 20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 21 | if dist_flag_0: 22 | print(f'load {len(self.data)} data from {data_path}') 23 | self.tokenizer = tokenizer 24 | self.debug = True 25 | 26 | def __getitem__(self, index): 27 | item = self.data[index] 28 | return { 29 | 'data': item, 30 | 'input': item['question'] 31 | } 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def collate_fn(self, batch): 37 | batch_query = [x['input'] for x in batch] 38 | batch_data = [x['data'] for x in batch] 39 | out_batch = {} 40 | out_batch['data'] = batch_data 41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids'] 42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False 43 | if self.debug and dist_flag_0: 44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False) 45 | for idx, sample in enumerate(decoded_texts): 46 | print(f'*******************batch_texts[{idx}]**********************************') 47 | print(sample) 48 | self.debug = False 49 | return out_batch 50 | 51 | 52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return): 53 | responses_list=[] 54 | batch_return=[] 55 | input_len = len(batch_input_ids[0]) 56 | for idx, output_ids in enumerate(batch_output_ids): 57 | generated_ids = output_ids[input_len:] 58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True)) 59 | if idx % num_return == num_return-1: 60 | responses_list.append(batch_return) 61 | batch_return=[] 62 | return responses_list 63 | 64 | def extract_and_choose_answer(pattern, model_answer): 65 | if '\n' in model_answer: 66 | model_answer_split = model_answer.split('\n') 67 | for model_answer_i in model_answer_split: 68 | if len(model_answer_i): 69 | model_answer = model_answer_i 70 | break 71 | 72 | matches = re.findall(pattern, model_answer) 73 | option_count = {} 74 | for match in matches: 75 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 76 | 77 | if not option_count: 78 | # else use loose pattern 79 | loose_pattern = r'[A-F]' 80 | if pattern == loose_pattern: 81 | if model_answer == 'Yes.': 82 | return 'A' 83 | elif model_answer == 
'No.': 84 | return 'B' 85 | else: 86 | return None 87 | else: 88 | return extract_and_choose_answer(loose_pattern, model_answer) 89 | 90 | max_count = max(option_count.values()) 91 | max_options = [option for option, count in option_count.items() if count == max_count] 92 | return max_options[0] 93 | 94 | 95 | 96 | def generate_score(result_path, score_path, wrong_item_path): 97 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 98 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 99 | 100 | all = defaultdict(int) 101 | right = defaultdict(int) 102 | accuracy_dict = defaultdict(int) 103 | wrong_item = [] 104 | 105 | print(f'****Total:{len(json_objects)}****') 106 | debug = True 107 | for item in json_objects: 108 | source = item["source"] 109 | for answer in item["model_answer"]: 110 | all[source] += 1 111 | pattern = r'[(\(]([A-Fa-f])[)\)]' 112 | extract_answer = extract_and_choose_answer(pattern, answer) 113 | item['extract_answer'] = extract_answer 114 | if debug: 115 | debug = False 116 | print(f'extract_answer:{extract_answer}') 117 | right_answer = item['answer'] 118 | print(f'right_answer:{right_answer}') 119 | if item['answer'] == extract_answer: 120 | right[source] += 1 121 | else: 122 | wrong_item.append(item) 123 | 124 | 125 | print(f'all:{all}') 126 | print(f'right:{right}') 127 | 128 | for key in right: 129 | accuracy_dict[key] = right[key] / all[key] 130 | 131 | with open(score_path, "w", encoding="utf8") as f: 132 | json.dump(accuracy_dict, f, indent=4) 133 | 134 | print(f'***********score_result save in {score_path}*************') 135 | 136 | with open(wrong_item_path, "w", encoding="utf8") as f: 137 | json.dump(wrong_item, f, indent=4, ensure_ascii=False) 138 | 139 | print(f'***********wrong_item save in {wrong_item_path}*************') 140 | 141 | 142 | def generate_response(args): 143 | accelerator = Accelerator() 144 | 145 | model_path = args.model_path 146 | accelerator.print(f'****************model_path:{model_path}******************') 147 | 148 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='',pad_token='') 149 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 150 | model = model.half().cuda() 151 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False} 152 | 153 | 154 | dataset = TestDataset(args.input_path, tokenizer) 155 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) 156 | 157 | model = model.eval() 158 | if dist.is_initialized(): 159 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************') 160 | 161 | dataloader = accelerator.prepare(dataloader) 162 | accelerator.print(f'******************load_model from {model_path}******************') 163 | 164 | if accelerator.is_main_process: 165 | fp = open(args.output_path,'w') 166 | 167 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader 168 | for batch in dataloader_iterator: 169 | batch_input_ids = batch["input_ids"] 170 | batch_data = batch["data"] 171 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs) 172 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return) 173 | 174 | 175 | if dist.is_initialized(): 176 | all_batch_data = [None] * dist.get_world_size() 177 | 
all_batch_responses = [None] * dist.get_world_size() 178 | dist.all_gather_object(all_batch_responses, batch_responses) 179 | dist.all_gather_object(all_batch_data, batch_data) 180 | else: 181 | all_batch_data = [batch_data, ] 182 | all_batch_responses = [batch_responses, ] 183 | 184 | all_data = [item for sublist in all_batch_data for item in sublist] 185 | all_response = [item for sublist in all_batch_responses for item in sublist] 186 | 187 | for data, responses in zip(all_data, all_response): 188 | answer_list = [] 189 | for response in responses: 190 | answer_list.append(response) 191 | data['model_answer'] = answer_list 192 | if accelerator.is_main_process: 193 | fp.write(json.dumps(data, ensure_ascii=False) +'\n') 194 | fp.flush() 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 199 | parser.add_argument("--input_path", type=str, help="path to the input data") 200 | parser.add_argument("--output_path", type=str, help="path to the output data") 201 | parser.add_argument("--score_path", type=str, help="path to the score") 202 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item") 203 | parser.add_argument("--num_return", type=int, help="number of return sequences") 204 | parser.add_argument("--batch_size", type=int, help="batch size") 205 | args = parser.parse_args() 206 | # generate_response(args) 207 | generate_score(args.output_path, args.score_path, args.wrong_item_path) 208 | 209 | 210 | -------------------------------------------------------------------------------- /src/evaluate/generate_score.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 5 | import re 6 | import argparse 7 | from accelerate import Accelerator 8 | import json 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from collections import defaultdict 13 | import json 14 | 15 | def extract_and_choose_answer(pattern, model_answer): 16 | if '\n' in model_answer: 17 | model_answer_split = model_answer.split('\n') 18 | for model_answer_i in model_answer_split: 19 | if len(model_answer_i): 20 | model_answer = model_answer_i 21 | break 22 | 23 | matches = re.findall(pattern, model_answer) 24 | option_count = {} 25 | for match in matches: 26 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1 27 | 28 | if not option_count: 29 | # else use loose pattern 30 | loose_pattern = r'[A-F]' 31 | if pattern == loose_pattern: 32 | if model_answer == 'Yes.': 33 | return 'A' 34 | elif model_answer == 'No.': 35 | return 'B' 36 | else: 37 | return None 38 | else: 39 | return extract_and_choose_answer(loose_pattern, model_answer) 40 | 41 | max_count = max(option_count.values()) 42 | max_options = [option for option, count in option_count.items() if count == max_count] 43 | return max_options[0] 44 | 45 | 46 | 47 | def generate_score(result_path, score_path): 48 | with open(result_path, 'r', encoding='utf-8') as jsonl_file: 49 | json_objects = [json.loads(line.strip()) for line in jsonl_file] 50 | 51 | all = defaultdict(int) 52 | right = defaultdict(int) 53 | accuracy_dict = defaultdict(int) 54 | 55 | print(f'****Total:{len(json_objects)}****') 56 | debug = True 57 | for item in json_objects: 58 | source = item["source"] 
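        # Scoring note: for each record, the chosen option letter is extracted from the model's answer and tallied into per-source totals and correct counts; per-source accuracy is then written to score_path.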
59 | answer = item["model_answer"] 60 | all[source] += 1 61 | pattern = r'[(\(]([A-Fa-f])[)\)]' 62 | extract_answer = extract_and_choose_answer(pattern, answer) 63 | if debug: 64 | debug = False 65 | print(f'extract_answer:{extract_answer}') 66 | print(answer) 67 | right_answer = item['answer'] 68 | print(f'right_answer:{right_answer}') 69 | if item['answer'] == extract_answer: 70 | right[source] += 1 71 | 72 | 73 | print(f'all:{all}') 74 | print(f'right:{right}') 75 | 76 | for key in right: 77 | accuracy_dict[key] = right[key] / all[key] 78 | 79 | with open(score_path, "w", encoding="utf8") as f: 80 | json.dump(accuracy_dict, f, indent=4) 81 | 82 | print(f'***********score_result save in {score_path}*************') 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--output_path", '-a', type=str, help="path to the output data") 88 | parser.add_argument("--score_path", '-o', type=str, help="path to the score") 89 | args = parser.parse_args() 90 | generate_score(args.output_path, args.score_path) 91 | # merge_result() 92 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/OpenAIGPT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import openai 3 | from retrying import retry 4 | import random 5 | 6 | class OpenAIGPT: 7 | def __init__(self, model_name="gpt-3.5-turbo", keys_path=None): 8 | self.model_name = model_name 9 | with open(keys_path, encoding="utf-8", mode="r") as fr: 10 | self.keys = [line.strip() for line in fr if len(line.strip()) >= 4] 11 | 12 | def __post_process(self, response): 13 | return response["choices"][0]["message"]["content"] 14 | 15 | @retry(wait_fixed=200, stop_max_attempt_number=50) 16 | def __call__(self, message): 17 | if message is None or message == "": 18 | return False, "Your input is empty." 
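        # A key is drawn at random for every request; if the ChatCompletion call below raises (rate limit, invalid key, network error), the @retry decorator waits 200 ms and re-invokes this method with a freshly drawn key, up to 50 attempts in total.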
19 | 20 | current_key = random.choice(self.keys) 21 | openai.api_key = current_key 22 | # openai.organization = "org-DvpMM9lXmpMxnaygs1ixEvZw" 23 | response = openai.ChatCompletion.create( 24 | model=self.model_name, 25 | messages=[{"role": "user", "content": message}], 26 | temperature=0.6, 27 | top_p=0.8, 28 | frequency_penalty=0.6, 29 | presence_penalty=0.8, 30 | n=1, 31 | ) 32 | return self.__post_process(response) 33 | 34 | 35 | if __name__ == "__main__": 36 | # test code 37 | igpt = OpenAIGPT(keys_path="gpt4key.txt", model_name="gpt-4") 38 | string = "Hello" 39 | print(string) 40 | answer = igpt(string) 41 | print(answer) 42 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/OpenAIGPT_datagen_multithread.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | from concurrent.futures import ThreadPoolExecutor 3 | from tqdm import tqdm 4 | import os 5 | import argparse 6 | from OpenAIGPT import OpenAIGPT 7 | 8 | 9 | def OpenAIGPT_datagen(args): 10 | igpt = OpenAIGPT(model_name=args.model_name, keys_path=args.keys_path) 11 | 12 | def process_item(item): 13 | content = igpt(item["query"]) 14 | item["model_answer"] = content 15 | return item 16 | 17 | output_path = args.output_path 18 | input_path = args.input_path 19 | 20 | # Collect the IDs of processed items in the output file 21 | processed_ids = set() 22 | if os.path.exists(output_path): 23 | with jsonlines.open(output_path, "r") as f: 24 | for item in f: 25 | processed_ids.add(item.get("id", None)) 26 | 27 | # Collect unprocessed items 28 | items_to_process = [] 29 | 30 | with jsonlines.open(input_path, "r") as reader: 31 | for item in reader: 32 | item_id = item.get("id", None) 33 | if item_id is not None and item_id in processed_ids: 34 | continue 35 | items_to_process.append(item) 36 | 37 | # Multi-threaded parallel processing 38 | with jsonlines.open( 39 | output_path, "a" if os.path.exists(output_path) else "w" 40 | ) as writer: 41 | with ThreadPoolExecutor(max_workers=args.max_workers) as executor: 42 | futures = { 43 | executor.submit(process_item, item): item for item in items_to_process 44 | } 45 | 46 | # Use tqdm to display progress 47 | for future in tqdm( 48 | futures, total=len(items_to_process), desc="Processing items" 49 | ): 50 | item = future.result() 51 | writer.write(item) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser(description="Process JSONL files concurrently.") 56 | parser.add_argument( 57 | "--model_name", 58 | type=str, 59 | default="gpt-3.5-turbo", 60 | help="Name of the OpenAIGPT model to use.", 61 | ) 62 | parser.add_argument( 63 | "--keys_path", 64 | type=str, 65 | required=True, 66 | help="API key for the OpenAIGPT service.", 67 | ) 68 | parser.add_argument( 69 | "--input_path", type=str, required=True, help="Path to the input JSONL file." 70 | ) 71 | parser.add_argument( 72 | "--output_path", type=str, required=True, help="Path to the output JSONL file." 
73 | ) 74 | parser.add_argument( 75 | "--max_workers", 76 | type=int, 77 | default=10, 78 | help="Maximum number of workers for concurrent processing.", 79 | ) 80 | 81 | args = parser.parse_args() 82 | OpenAIGPT_datagen(args) 83 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/gpt_key.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/src/process/openai_rewrite/gpt_key.txt -------------------------------------------------------------------------------- /src/process/openai_rewrite/guidelines_en/1.2.prepare_data.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | import argparse 4 | 5 | ans_prompt = """You are Medbase, equipped with in-depth knowledge in medicine. Your task is to directly answer the user's in English. In formulating your response, you must thoughtfully reference the , ensuring that your reply does not disclose your reliance on . Aim to provide a comprehensive and informative response, incorporating relevant insights from to best assist the user. Please be cautious to avoid including any content that might raise ethical concerns. 6 | 7 | : {question} 8 | 9 | : {reference} 10 | 11 | : """ 12 | 13 | 14 | def generate_query(data): 15 | chatgpt_query = ans_prompt.format_map( 16 | {"question": data["model_answer"], "reference": data["reference"]} 17 | ) 18 | return chatgpt_query 19 | 20 | 21 | def Prepare_data(args): 22 | data = [] 23 | # Read the uploaded JSONl file 24 | with jsonlines.open(args.input_path, "r") as reader: 25 | data = list(reader) 26 | 27 | print(f"len:{len(data)}") 28 | # Convert as required 29 | jsonl_data = [] 30 | 31 | for id, item in enumerate(data): 32 | jsonl_data.append( 33 | { 34 | "id": id, 35 | "query": generate_query(item), 36 | "model_answer": "", 37 | "model_question": item["model_answer"], 38 | "reference": item["reference"], 39 | } 40 | ) 41 | 42 | # Save the converted data as a JSONL file 43 | with open(args.output_path, "w", encoding="utf-8") as file: 44 | for entry in jsonl_data: 45 | file.write(json.dumps(entry, ensure_ascii=False) + "\n") 46 | 47 | print(f"Prepare finished, output to '{args.output_path}'") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="Prepare data for OpenAIGPT generation" 53 | ) 54 | parser.add_argument( 55 | "--input_path", type=str, required=True, help="Path to the input JSON file." 56 | ) 57 | parser.add_argument( 58 | "--output_path", type=str, required=True, help="Path to the output JSONL file." 59 | ) 60 | args = parser.parse_args() 61 | Prepare_data(args) 62 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/guidelines_en/1.prepare_data.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | import argparse 4 | 5 | query_prompt = """Please create a that closely aligns with the provided . Ensure that the is formulated in English and does not explicitly reference the text. You may incorporate specific scenarios or contexts in the , allowing the to serve as a comprehensive and precise answer. 
6 | 7 | : {text} 8 | 9 | : """ 10 | 11 | 12 | def generate_query(data): 13 | chatgpt_query = query_prompt.format_map({"text": data[0]}) 14 | return chatgpt_query 15 | 16 | 17 | def Prepare_data(args): 18 | data = [] 19 | # Read the uploaded JSONl file 20 | with jsonlines.open(args.input_path, "r") as reader: 21 | data = list(reader) 22 | 23 | print(f"len:{len(data)}") 24 | # Convert as required 25 | jsonl_data = [] 26 | 27 | for id, item in enumerate(data): 28 | jsonl_data.append( 29 | { 30 | "id": id, 31 | "query": generate_query(item), 32 | "model_answer": "", 33 | "reference": item[0], 34 | } 35 | ) 36 | 37 | # Save the converted data as a JSONL file 38 | with open(args.output_path, "w", encoding="utf-8") as file: 39 | for entry in jsonl_data: 40 | file.write(json.dumps(entry, ensure_ascii=False) + "\n") 41 | 42 | print(f"Prepare finished, output to '{args.output_path}'") 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser( 47 | description="Prepare data for OpenAIGPT generation" 48 | ) 49 | parser.add_argument( 50 | "--input_path", type=str, required=True, help="Path to the input JSON file." 51 | ) 52 | parser.add_argument( 53 | "--output_path", type=str, required=True, help="Path to the output JSONL file." 54 | ) 55 | args = parser.parse_args() 56 | Prepare_data(args) 57 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/guidelines_en/1.run_prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python 1.prepare_data.py --input_path ./data/1.dev.jsonl --output_path ./data/2.dev_prepared.jsonl 4 | python 1.2.prepare_data.py --input_path ./data/3.dev_aftgpt.jsonl --output_path ./data/2.dev_prepared.jsonl -------------------------------------------------------------------------------- /src/process/openai_rewrite/guidelines_en/2.run_gpt_datagen_multithread.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_prepared.jsonl --output_path ./data/3.dev_aftgpt.jsonl --max_workers 300 4 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_aftgpt_prepared.jsonl --output_path ./data/3.dev_aftgpt_prepared_aftgpt.jsonl --max_workers 300 -------------------------------------------------------------------------------- /src/process/openai_rewrite/guidelines_en/3.extract.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def extract_and_save(input_jsonl, output_json): 5 | # Create an empty list to store the extracted data 6 | extracted_data = [] 7 | train_data = [] 8 | count = 0 9 | # Read the JSONL file line by line and extract the specified fields 10 | with open(input_jsonl, "r", encoding="utf-8") as file: 11 | for line in file: 12 | json_data = json.loads(line) 13 | if len(json_data.get("model_answer", "")) < 32: 14 | continue 15 | 16 | #Extract specified fields 17 | extracted_item = { 18 | "question": json_data.get("model_question", ""), 19 | "answer": json_data.get("model_answer", ""), 20 | "reference": json_data.get("reference", ""), 21 | "id": json_data.get("id", ""), 22 | } 23 | list_item = [ 24 | json_data.get("model_question", ""), 25 | json_data.get("model_answer", ""), 26 | ] 27 | 28 | # Add the extracted data to the list 29 | extracted_data.append(extracted_item) 30 | 
train_data.append(list_item) 31 | count += 1 32 | 33 | # Write the extracted data into a new JSON file 34 | print(f"sum_count:{count}") 35 | with open(output_json, "w", encoding="utf-8") as output: 36 | json.dump(extracted_data, output, indent=2, ensure_ascii=False) 37 | with open(train_json, "w", encoding="utf-8") as output: 38 | json.dump(train_data, output, indent=2, ensure_ascii=False) 39 | 40 | 41 | #Specify the input JSONL file and output JSON file name 42 | input_jsonl = "./data/3.dev_aftgpt_prepared_aftgpt.jsonl" 43 | output_json = "./data/4.dev_extracted.json" 44 | train_json = "./data/4.dev_train.json" 45 | 46 | # Call the extraction and saving functions 47 | extract_and_save(input_jsonl, output_json) 48 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/patient_en/1.2.prepare_data.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | import argparse 4 | 5 | ans_prompt = """You are Medbase, equipped with in-depth knowledge in medicine. Your task is to directly answer the user's in English. In formulating your response, you must thoughtfully reference the , ensuring that your reply does not disclose your reliance on . Aim to provide a comprehensive and informative response, incorporating relevant insights from to best assist the user. Please be cautious to avoid including any content that might raise ethical concerns. 6 | 7 | : {question} 8 | 9 | : {reference} 10 | 11 | : """ 12 | 13 | 14 | def generate_query(data): 15 | chatgpt_query = ans_prompt.format_map( 16 | {"question": data["model_answer"], "reference": data["reference"]} 17 | ) 18 | return chatgpt_query 19 | 20 | 21 | def Prepare_data(args): 22 | data = [] 23 | # Read the uploaded JSONl file 24 | with jsonlines.open(args.input_path, "r") as reader: 25 | data = list(reader) 26 | 27 | print(f"len:{len(data)}") 28 | # Convert as required 29 | jsonl_data = [] 30 | 31 | for id, item in enumerate(data): 32 | jsonl_data.append( 33 | { 34 | "id": id, 35 | "query": generate_query(item), 36 | "model_answer": "", 37 | "model_question": item["model_answer"], 38 | "reference": item["reference"], 39 | } 40 | ) 41 | 42 | # Save the converted data as a JSONL file 43 | with open(args.output_path, "w", encoding="utf-8") as file: 44 | for entry in jsonl_data: 45 | file.write(json.dumps(entry, ensure_ascii=False) + "\n") 46 | 47 | print(f"Prepare finished, output to '{args.output_path}'") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="Prepare data for OpenAIGPT generation" 53 | ) 54 | parser.add_argument( 55 | "--input_path", type=str, required=True, help="Path to the input JSON file." 56 | ) 57 | parser.add_argument( 58 | "--output_path", type=str, required=True, help="Path to the output JSONL file." 59 | ) 60 | args = parser.parse_args() 61 | Prepare_data(args) 62 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/patient_en/1.prepare_data.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | import argparse 4 | 5 | query_prompt = """ 6 | {text} 7 | Please create some dialogues between patients and doctors in English based on the above text. 
The format is: 8 | Patient’s question 9 | Doctor’s answer 10 | Both patient questions and doctor responses are as complex and detailed as possible.""" 11 | 12 | 13 | def generate_query(data): 14 | chatgpt_query = query_prompt.format_map({"text": data[0]}) 15 | return chatgpt_query 16 | 17 | 18 | def Prepare_data(args): 19 | data = [] 20 | # Read the uploaded JSONl file 21 | with jsonlines.open(args.input_path, "r") as reader: 22 | data = list(reader) 23 | 24 | print(f"len:{len(data)}") 25 | # Convert as required 26 | jsonl_data = [] 27 | 28 | for id, item in enumerate(data): 29 | query = generate_query(item) 30 | if len(query) > 4090: 31 | continue 32 | jsonl_data.append( 33 | { 34 | "id": id, 35 | "query": generate_query(item), 36 | "model_answer": "", 37 | "reference": item[0], 38 | } 39 | ) 40 | 41 | # Save the converted data as a JSONL file 42 | with open(args.output_path, "w", encoding="utf-8") as file: 43 | for entry in jsonl_data: 44 | file.write(json.dumps(entry, ensure_ascii=False) + "\n") 45 | 46 | print(f"Prepare finished, output to '{args.output_path}'") 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser( 51 | description="Prepare data for OpenAIGPT generation" 52 | ) 53 | parser.add_argument( 54 | "--input_path", type=str, required=True, help="Path to the input JSON file." 55 | ) 56 | parser.add_argument( 57 | "--output_path", type=str, required=True, help="Path to the output JSONL file." 58 | ) 59 | args = parser.parse_args() 60 | Prepare_data(args) 61 | -------------------------------------------------------------------------------- /src/process/openai_rewrite/patient_en/1.run_prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the paths for input and output files 4 | input_path="./data/1.rlhf.jsonl" 5 | output_path="./data/2.rlhf_prepared.jsonl" 6 | 7 | # run Python scripts 8 | python 1.prepare_data.py --input_path "$input_path" --output_path "$output_path" 9 | 10 | python 1.prepare_data.py --input_path ./data/1.dev_en.jsonl --output_path ./data/2.dev_prepared.jsonl 11 | python 1.2.prepare_data.py --input_path ./data/3.dev_aftgpt.jsonl --output_path ./data/2.dev_aftgpt_prepared.jsonl -------------------------------------------------------------------------------- /src/process/openai_rewrite/patient_en/2.run_gpt_datagen_multithread.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python "$python_script" --keys_path "$keys_path" --input_path "$input_path" --output_path "$output_path" --max_workers $max_workers 4 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_prepared.jsonl --output_path ./data/3.dev_aftgpt.jsonl --max_workers 300 5 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_aftgpt_prepared.jsonl --output_path ./data/3.dev_aftgpt_prepared_aftgpt.jsonl --max_workers 300 -------------------------------------------------------------------------------- /src/process/openai_rewrite/patient_en/3.extract.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | 5 | def extract_dialogues_from_model_answer(model_answer): 6 | dialogues = [] 7 | try: 8 | # Use regular expressions to extract conversations between patients and doctors 9 | patient_dialogues = re.findall( 10 | r"(.*?)", model_answer, re.DOTALL 11 | ) 12 | doctor_dialogues = re.findall( 13 | 
r"(.*?)", model_answer, re.DOTALL 14 | ) 15 | 16 | # Combine conversations into a list 17 | dialogues = list(zip(patient_dialogues, doctor_dialogues)) 18 | final_list = [] 19 | for item in dialogues: 20 | for i in item: 21 | final_list.append(i) 22 | except Exception as e: 23 | print(f"Error extracting dialogues: {str(e)}") 24 | 25 | return final_list 26 | 27 | 28 | def process_jsonl_file(input_file, output_file): 29 | result_data = [] 30 | 31 | try: 32 | with open(input_file, "r", encoding="utf-8") as file: 33 | #Read JSONL file line by line 34 | for line in file: 35 | json_data = json.loads(line) 36 | 37 | # Check whether the JSON contains the "model_answer" field 38 | if "model_answer" in json_data: 39 | model_answer = json_data["model_answer"] 40 | 41 | # Extract conversations and add to result list 42 | dialogues = extract_dialogues_from_model_answer(model_answer) 43 | if not len(dialogues): 44 | continue 45 | result_data.append(dialogues) 46 | 47 | except Exception as e: 48 | print(f"Error processing JSONL file: {str(e)}") 49 | 50 | # 将结果保存为JSON文件 51 | with open(output_file, "w", encoding="utf-8") as output_file: 52 | json.dump(result_data, output_file, ensure_ascii=False, indent=2) 53 | 54 | 55 | # 用法示例 56 | input_file_path = "./data/3.patients_en_aftgpt.jsonl" 57 | output_file_path = "./data/4.patients_en_aftgpt.json" 58 | 59 | process_jsonl_file(input_file_path, output_file_path) 60 | -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_gemma.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}. 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}. 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}. 
42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}. 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}. 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}. 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي. 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 
93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | if debug: 155 | print(question) 156 | debug=False 157 | 158 | tmp_dict['source_question'] = item['question'] 159 | tmp_dict['source_option'] = item['options'] 160 | tmp_dict['question'] = question 161 | tmp_dict['answer'] = item['answer'][1] 162 | tmp_dict['source'] = item['source'] 163 | data_final.append(tmp_dict) 164 | 165 | with open(args.save_path, 'w', encoding='utf-8') as file: 166 | json.dump(data_final, file, ensure_ascii=False, indent=2) 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 171 | 172 | # Model Args 173 | parser.add_argument('--save_path', default='', type=str) 174 | parser.add_argument('--data_path', default='', type=str) 175 | parser.add_argument('--few_shot', default='', type=int) 176 | args = parser.parse_args() 177 | 178 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_huatuo2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering 
real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}. 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}. 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}. 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}. 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}. 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}. 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي. 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. 
الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | if len(question) + real_question_len < 4096: 149 | question += few_shot_prompt.format(**sample) 150 | question += real_question 151 | if len(question)>4096: 152 | continue 153 | if debug: 154 | print(question) 155 | debug=False 156 | 157 | tmp_dict['source_question'] = item['question'] 158 | tmp_dict['source_option'] = item['options'] 159 | tmp_dict['question'] = question 160 | tmp_dict['answer'] = item['answer'][1] 161 | tmp_dict['source'] = item['source'] 162 | data_final.append(tmp_dict) 163 | 164 | with open(args.save_path, 'w', encoding='utf-8') as file: 165 | json.dump(data_final, file, ensure_ascii=False, indent=2) 166 | 167 | 168 | if __name__ == '__main__': 169 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 170 | 171 | # Model Args 172 | parser.add_argument('--save_path', default='', type=str) 173 | parser.add_argument('--data_path', default='', type=str) 174 | parser.add_argument('--few_shot', default='', type=int) 175 | args = parser.parse_args() 176 | 177 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_llama.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering 
real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}. 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}. 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}. 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}. 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}. 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}. 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي. 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. 
الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | if debug: 155 | print(question) 156 | debug=False 157 | 158 | tmp_dict['source_question'] = item['question'] 159 | tmp_dict['source_option'] = item['options'] 160 | tmp_dict['question'] = question 161 | tmp_dict['answer'] = item['answer'][1] 162 | tmp_dict['source'] = item['source'] 163 | data_final.append(tmp_dict) 164 | 165 | with open(args.save_path, 'w', encoding='utf-8') as file: 166 | json.dump(data_final, file, ensure_ascii=False, indent=2) 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 171 | 172 | # Model Args 173 | parser.add_argument('--save_path', default='', type=str) 174 | parser.add_argument('--data_path', default='', type=str) 175 | parser.add_argument('--few_shot', default='', type=int) 176 | args = parser.parse_args() 177 | 178 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_meditron.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = 
"""User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}. 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}. 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}. 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}. 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}. 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}. 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي. 
91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | 155 | if debug: 156 | print(question) 157 | debug=False 158 | 159 | tmp_dict['source_question'] = item['question'] 160 | tmp_dict['source_option'] = item['options'] 161 | tmp_dict['question'] = question 162 | tmp_dict['answer'] = item['answer'][1] 163 | tmp_dict['source'] = item['source'] 164 | data_final.append(tmp_dict) 165 | 166 | with open(args.save_path, 'w', encoding='utf-8') as file: 167 | json.dump(data_final, file, ensure_ascii=False, indent=2) 168 | 169 | 170 | if __name__ == '__main__': 171 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 172 | 173 | # Model Args 174 | parser.add_argument('--save_path', default='', type=str) 175 | parser.add_argument('--data_path', default='', type=str) 176 | parser.add_argument('--few_shot', default='', type=int) 177 | args = parser.parse_args() 178 | 179 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_mistral.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from 
transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}. 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}. 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}. 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}. 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}. 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}. 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 
87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي. 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | if debug: 155 | print(question) 156 | debug=False 157 | 158 | tmp_dict['source_question'] = item['question'] 159 | tmp_dict['source_option'] = item['options'] 160 | tmp_dict['question'] = question 161 | tmp_dict['answer'] = item['answer'][1] 162 | tmp_dict['source'] = item['source'] 163 | data_final.append(tmp_dict) 164 | 165 | with open(args.save_path, 'w', encoding='utf-8') as file: 166 | json.dump(data_final, file, ensure_ascii=False, indent=2) 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 171 | 172 | # Model Args 173 | parser.add_argument('--save_path', default='', type=str) 174 | parser.add_argument('--data_path', default='', type=str) 175 | parser.add_argument('--few_shot', default='', type=int) 176 | args = parser.parse_args() 177 | 178 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_qwen.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}.<|endoftext|> 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}.<|endoftext|> 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}.<|endoftext|> 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}.<|endoftext|> 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}.<|endoftext|> 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 
69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}.<|endoftext|> 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي.<|endoftext|> 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | if debug: 155 | print(question) 156 | debug=False 157 | 158 | tmp_dict['source_question'] = item['question'] 159 | tmp_dict['source_option'] = item['options'] 160 | tmp_dict['question'] = question 161 | tmp_dict['answer'] = item['answer'][1] 162 | tmp_dict['source'] = item['source'] 163 | data_final.append(tmp_dict) 164 | 
165 | with open(args.save_path, 'w', encoding='utf-8') as file: 166 | json.dump(data_final, file, ensure_ascii=False, indent=2) 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 171 | 172 | # Model Args 173 | parser.add_argument('--save_path', default='', type=str) 174 | parser.add_argument('--data_path', default='', type=str) 175 | parser.add_argument('--few_shot', default='', type=int) 176 | args = parser.parse_args() 177 | 178 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_yi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}.<|endoftext|> 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}.<|endoftext|> 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}.<|endoftext|> 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}.<|endoftext|> 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 
63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}.<|endoftext|> 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}.<|endoftext|> 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي.<|endoftext|> 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | if debug: 155 | 
print(question) 156 | debug=False 157 | 158 | tmp_dict['source_question'] = item['question'] 159 | tmp_dict['source_option'] = item['options'] 160 | tmp_dict['question'] = question 161 | tmp_dict['answer'] = item['answer'][1] 162 | tmp_dict['source'] = item['source'] 163 | data_final.append(tmp_dict) 164 | 165 | with open(args.save_path, 'w', encoding='utf-8') as file: 166 | json.dump(data_final, file, ensure_ascii=False, indent=2) 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 171 | 172 | # Model Args 173 | parser.add_argument('--save_path', default='', type=str) 174 | parser.add_argument('--data_path', default='', type=str) 175 | parser.add_argument('--few_shot', default='', type=int) 176 | args = parser.parse_args() 177 | 178 | preprocess(args) -------------------------------------------------------------------------------- /src/process/prepare/data_process_test_zephyr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 8 | 9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 10 | Question: {question} 11 | Options: 12 | {options} 13 | Assistant:The correct answer is {answer}. 14 | """ 15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D. 16 | Question: {question} 17 | Options: 18 | {options} 19 | Assistant:""" 20 | 21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 22 | Context: {context} 23 | Question: {question} 24 | Options: 25 | {options} 26 | Assistant:The correct answer is {answer}. 27 | """ 28 | 29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer. 30 | Context: {context} 31 | Question: {question} 32 | Options: 33 | {options} 34 | Assistant:""" 35 | 36 | 37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 38 | 问题: {question} 39 | 选项: 40 | {options} 41 | Assistant:正确答案是{answer}. 42 | """ 43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。 44 | 问题: {question} 45 | 选项: 46 | {options} 47 | Assistant:""" 48 | 49 | 50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 51 | pregunta: {question} 52 | Opciones: 53 | {options} 54 | Assistant:La respuesta correcta es {answer}. 55 | """ 56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D. 
57 | pregunta: {question} 58 | Opciones: 59 | {options} 60 | Assistant:""" 61 | 62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 63 | question: {question} 64 | Possibilités: 65 | {options} 66 | Assistant:La bonne réponse est {answer}. 67 | """ 68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E. 69 | question: {question} 70 | Possibilités: 71 | {options} 72 | Assistant:""" 73 | 74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 75 | सवाल: {question} 76 | विकल्प: 77 | {options} 78 | Assistant:सही उत्तर है{answer}. 79 | """ 80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें। 81 | सवाल: {question} 82 | विकल्प: 83 | {options} 84 | Assistant:""" 85 | 86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 87 | سؤال: {question} 88 | خيارات: 89 | {options} 90 | Assistant:{answer}الإجابة الصحيحة هي. 91 | """ 92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د. 93 | سؤال: {question} 94 | خيارات: 95 | {options} 96 | Assistant:""" 97 | 98 | 99 | 100 | def preprocess(args): 101 | data_final = [] 102 | with open(args.data_path, 'r') as file: 103 | data = json.load(file) 104 | grouped_items = {} 105 | for item in data: 106 | source = item.get("source") 107 | if source not in grouped_items: 108 | grouped_items[source] = [] 109 | grouped_items[source].append(item) 110 | 111 | for source, items in grouped_items.items(): 112 | debug = True 113 | print(f'*********************{source}****************************') 114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']: 115 | few_shot_prompt = question_prompt_zh_choice_shot 116 | question_prompt = question_prompt_en_choice 117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']: 118 | few_shot_prompt = question_prompt_en_choice_shot 119 | question_prompt = question_prompt_en_choice 120 | elif source in ['headqa']: 121 | few_shot_prompt = question_prompt_es_choice_shot 122 | question_prompt = question_prompt_es_choice 123 | elif source in ['frenchmedmcqa']: 124 | few_shot_prompt = question_prompt_fr_choice_shot 125 | question_prompt = question_prompt_fr_choice 126 | elif source in ['mmlu-medical-ar']: 127 | few_shot_prompt = question_prompt_ar_choice_shot 128 | question_prompt = question_prompt_ar_choice 129 | elif source in ['mmlu-medical-hi']: 130 | few_shot_prompt = question_prompt_hi_choice_shot 131 | question_prompt = question_prompt_hi_choice 132 | else: 133 | few_shot_prompt = question_prompt_en_pubmed_shot 134 | question_prompt = question_prompt_en_pubmed 135 | 136 | for item in items: 137 | random_samples = random.sample(items, args.few_shot+1) 138 | question = '' 139 | tmp_dict = {} 140 | # in case item in random_samples 141 | if item in random_samples: 142 | random_samples.remove(item) 143 | else: 144 | random_samples = random_samples[:-1] 145 | real_question = question_prompt.format(**item) 146 | real_question_len = len(real_question) 147 | for sample in random_samples: 148 | sample = 
few_shot_prompt.format(**sample) 149 | if len(question) + real_question_len + len(sample) < 4096: 150 | question += sample 151 | question += real_question 152 | if len(question)>4096: 153 | continue 154 | if debug: 155 | print(question) 156 | debug=False 157 | 158 | tmp_dict['source_question'] = item['question'] 159 | tmp_dict['source_option'] = item['options'] 160 | tmp_dict['question'] = question 161 | tmp_dict['answer'] = item['answer'][1] 162 | tmp_dict['source'] = item['source'] 163 | data_final.append(tmp_dict) 164 | 165 | with open(args.save_path, 'w', encoding='utf-8') as file: 166 | json.dump(data_final, file, ensure_ascii=False, indent=2) 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 171 | 172 | # Model Args 173 | parser.add_argument('--save_path', default='', type=str) 174 | parser.add_argument('--data_path', default='', type=str) 175 | parser.add_argument('--few_shot', default='', type=int) 176 | args = parser.parse_args() 177 | 178 | preprocess(args) -------------------------------------------------------------------------------- /src/proxy-tuning/eval/apollodata/run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import pandas as pd 5 | import numpy as np 6 | import json 7 | from eval.utils import ( 8 | ensure_dir, 9 | generate_completions, 10 | load_hf_lm_and_tokenizer, 11 | load_mexperts_model_and_tokenizer, 12 | ) 13 | from transformers import AutoConfig 14 | 15 | def main(args): 16 | ensure_dir(args.save_dir) 17 | 18 | if args.model_name_or_path: 19 | print("Loading model and tokenizer...") 20 | model, tokenizer = load_hf_lm_and_tokenizer( 21 | model_name_or_path=args.model_name_or_path, 22 | tokenizer_name_or_path=args.tokenizer_name_or_path, 23 | load_in_8bit=args.load_in_8bit, 24 | device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto", 25 | use_fast_tokenizer=not args.use_slow_tokenizer, 26 | ) 27 | elif args.base_model_name_or_path: 28 | model, tokenizer = load_mexperts_model_and_tokenizer( 29 | args.base_model_name_or_path, 30 | args.expert_model_name_or_path, 31 | args.antiexpert_model_name_or_path, 32 | model_type=args.model_type, 33 | load_in_8bit=args.load_in_8bit, 34 | use_fast_tokenizer=not args.use_slow_tokenizer, 35 | ) 36 | 37 | # use dev set because test set answers are hidden 38 | test_df = pd.read_json(os.path.join(args.data_dir, "test.json")) 39 | 40 | # Create prompts 41 | prompts = [] 42 | for i, row in test_df.iterrows(): 43 | prompts.append({'question': row["question"], 'source': row["source"], 'answer': row["answer"]}) 44 | 45 | new_line_token = tokenizer.encode("\n\n", add_special_tokens=False)[-1] 46 | outputs = generate_completions( 47 | model, 48 | tokenizer, 49 | prompts, 50 | batch_size=args.eval_batch_size, 51 | do_sample=False, 52 | max_new_tokens=20, 53 | stop_id_sequences=[[new_line_token]], 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument( 60 | "--data_dir", 61 | type=str, 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--save_dir", 66 | type=str, 67 | default=None, 68 | ) 69 | parser.add_argument( 70 | "--model_name_or_path", 71 | type=str, 72 | default=None, 73 | help="if specified, we will load the model to generate the predictions." 
74 | ) 75 | parser.add_argument( 76 | "--tokenizer_name_or_path", 77 | type=str, 78 | default=None, 79 | help="if specified, we will load the tokenizer from here." 80 | ) 81 | parser.add_argument( 82 | "--use_slow_tokenizer", 83 | action="store_true", 84 | help="If given, we will use the slow tokenizer." 85 | ) 86 | parser.add_argument( 87 | "--max_examples", 88 | type=int, 89 | help="if specified, evaluate at most this many examples." 90 | ) 91 | parser.add_argument( 92 | "--eval_batch_size", 93 | type=int, 94 | default=1, 95 | help="batch size for evaluation." 96 | ) 97 | parser.add_argument( 98 | "--load_in_8bit", 99 | action="store_true", 100 | help="load model in 8bit mode, which will reduce memory and speed up inference." 101 | ) 102 | parser.add_argument( 103 | "--base_model_name_or_path", 104 | type=str, 105 | default=None, 106 | ) 107 | parser.add_argument( 108 | "--expert_model_name_or_path", 109 | type=str, 110 | default=None, 111 | ) 112 | parser.add_argument( 113 | "--antiexpert_model_name_or_path", 114 | type=str, 115 | default=None, 116 | ) 117 | parser.add_argument( 118 | "--output_path", 119 | type=str, 120 | default='../../outputs-qwen7b.jsonl', 121 | ) 122 | 123 | args = parser.parse_args() 124 | model_config = AutoConfig.from_pretrained(args.base_model_name_or_path, trust_remote_code=True) 125 | args.model_type = model_config.model_type 126 | main(args) 127 | 
-------------------------------------------------------------------------------- /src/proxy-tuning/scripts/eval/proxy_tuning.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | 3 | # Proxy-tuning evaluation on apollodata: Qwen-7B base model steered by the Apollo-1.8B expert and the Qwen-1.5-1.8B anti-expert 4 | size=14 5 | echo "Results dir: results/apollodata" 6 | CUDA_LAUNCH_BLOCKING=1 python -m eval.apollodata.run_eval \ 7 | --data_dir data/eval/apollodata/ \ 8 | --save_dir results/apollodata \ 9 | --base_model_name_or_path /your_model_path/Qwen-7B/ \ 10 | --expert_model_name_or_path /your_model_path/Apollo-1.8B/ \ 11 | --antiexpert_model_name_or_path /your_model_path/Qwen-1.5-1.8B/ \ 12 | --output_path ../qwen7b.jsonl \ 13 | --eval_batch_size 1 14 | 
-------------------------------------------------------------------------------- /src/sft/training_config/zero.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: False 7 | zero3_save_16bit_model: true 8 | zero_stage: 1 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'yes' 11 | main_training_function: main 12 | mixed_precision: bf16 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | use_cpu: false 
-------------------------------------------------------------------------------- /src/sft/training_config/zero_multi.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: False 7 | zero3_save_16bit_model: true 8 | zero_stage: 3 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'yes' 11 | main_training_function: main 12 | mixed_precision: bf16 13 | num_machines: 8 14 | num_processes: 40 15 | rdzv_backend: static 16 | same_network: true 17 | use_cpu: False 
-------------------------------------------------------------------------------- /utils/check.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "None\n", 13 | "0\n", 14 | "\n", 15 | "\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel\n", 21 | "import transformers\n", 22 | "\n", 23 | "model_path = '/your_data_path/PMC_LLaMA_7B'\n", 24 | "tokenizer = transformers.LlamaTokenizer.from_pretrained(model_path)\n", 25 | "\n", 26 | "# string = 'hello'\n", 27 | "# ids= tokenizer.encode(string)\n", 28 | "# print(ids)\n", 29 | "# for id in ids:\n", 30 | "# print(tokenizer.decode(id))\n", 31 | "print(tokenizer.pad_token_id)\n", 32 | "print(tokenizer.eos_token_id)\n", 33 | "# print(tokenizer.decode(tokenizer.pad_token_id))\n", 34 | "# print(tokenizer.decode(0))\n", 35 | "# print(tokenizer.decode(1))\n", 36 | "\n" 37 | ] 38 | } 39 | ], 40 | "metadata": { 41 | "kernelspec": { 42 | "display_name": "Python 3 (ipykernel)", 43 | "language": "python", 44 | "name": "python3" 45 | }, 46 | "language_info": { 47 | "codemirror_mode": { 48 | "name": "ipython", 49 | "version": 3 50 | }, 51 | "file_extension": ".py", 52 | "mimetype": "text/x-python", 53 | "name": "python", 54 | "nbconvert_exporter": "python", 55 | "pygments_lexer": "ipython3", 56 | "version": "3.10.13" 57 | } 58 | }, 59 | "nbformat": 4, 60 | "nbformat_minor": 4 61 | } 62 | -------------------------------------------------------------------------------- /utils/kill.sh: -------------------------------------------------------------------------------- 1 | pkill -f ./src/sft/train_yi_resume_val.py 2 | pkill -f eval_qwen.py --------------------------------------------------------------------------------
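
Note on /src/proxy-tuning/eval/apollodata/run_eval.py: as shown above, main() builds the prompt list and calls generate_completions(), but the generated completions are never written anywhere, even though the script accepts --save_dir and --output_path. Below is a minimal sketch of how the results could be persisted as JSONL; save_outputs is a hypothetical helper (not part of the repository), and it assumes outputs is a list of generated strings aligned one-to-one with prompts.

import json
import os


def save_outputs(prompts, outputs, output_path):
    # Hypothetical helper: write one JSON line per test item.
    # Assumes `outputs` is a list of generated strings aligned 1:1 with `prompts`.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for prompt, output in zip(prompts, outputs):
            record = {
                "question": prompt["question"],  # few-shot prompt, as built by the data_process_test_* scripts
                "source": prompt["source"],      # benchmark name, e.g. medqa-usmle
                "answer": prompt["answer"],      # gold option letter from test.json
                "model_output": output,          # raw completion, to be scored later
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


# Inside main(), after generate_completions():
#     save_outputs(prompts, outputs, args.output_path)

Keeping the gold option letter next to the raw completion means the file can later be scored by extracting the predicted option from each completion (for example, the first standalone A-E letter after "The correct answer is") and comparing it against the stored answer.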