├── .gitignore
├── 0.download_data.sh
├── 1.data_process_test&dev.sh
├── 2.data_process_train.sh
├── 3.single_node_train_gemma.sh
├── 4.eval.sh
├── LICENSE.txt
├── README.md
├── assets
│   ├── apollo_medium_final.png
│   ├── dataset.png
│   ├── final.png
│   └── result.png
├── metadata
│   ├── dev.json
│   ├── dev
│   │   ├── ar.json
│   │   ├── en.json
│   │   ├── es.json
│   │   ├── fr.json
│   │   ├── hi.json
│   │   └── zh.json
│   ├── merge_json_train.py
│   ├── test.json
│   └── test
│       ├── ar.json
│       ├── en.json
│       ├── es.json
│       ├── fr.json
│       ├── hi.json
│       └── zh.json
├── requirements.txt
├── scripts
│   ├── 3.multinode_train_gema7B_rank0.sh
│   ├── 3.multinode_train_gema7B_rank1.sh
│   └── 3.multinode_train_gema7B_rank2.sh
├── src
│   ├── evaluate
│   │   ├── cli_demo.py
│   │   ├── eval_72b_34b.py
│   │   ├── eval_gemma.py
│   │   ├── eval_huatuo2.py
│   │   ├── eval_llama2.py
│   │   ├── eval_llama70b.py
│   │   ├── eval_meditron.py
│   │   ├── eval_meditron70b.py
│   │   ├── eval_mistral.py
│   │   ├── eval_mmedlm2.py
│   │   ├── eval_qwen.py
│   │   ├── eval_yi.py
│   │   ├── eval_zephyr.py
│   │   └── generate_score.py
│   ├── process
│   │   ├── openai_rewrite
│   │   │   ├── OpenAIGPT.py
│   │   │   ├── OpenAIGPT_datagen_multithread.py
│   │   │   ├── gpt_key.txt
│   │   │   ├── guidelines_en
│   │   │   │   ├── 1.2.prepare_data.py
│   │   │   │   ├── 1.prepare_data.py
│   │   │   │   ├── 1.run_prepare_data.sh
│   │   │   │   ├── 2.run_gpt_datagen_multithread.sh
│   │   │   │   ├── 3.extract.py
│   │   │   │   └── data
│   │   │   │       └── 1.dev.jsonl
│   │   │   └── patient_en
│   │   │       ├── 1.2.prepare_data.py
│   │   │       ├── 1.prepare_data.py
│   │   │       ├── 1.run_prepare_data.sh
│   │   │       ├── 2.run_gpt_datagen_multithread.sh
│   │   │       ├── 3.extract.py
│   │   │       └── data
│   │   │           └── 1.dev.jsonl
│   │   └── prepare
│   │       ├── data_process_test_gemma.py
│   │       ├── data_process_test_huatuo2.py
│   │       ├── data_process_test_llama.py
│   │       ├── data_process_test_meditron.py
│   │       ├── data_process_test_mistral.py
│   │       ├── data_process_test_qwen.py
│   │       ├── data_process_test_yi.py
│   │       ├── data_process_test_zephyr.py
│   │       ├── data_process_train_gemma.py
│   │       ├── data_process_train_qwen.py
│   │       └── data_process_train_yi.py
│   ├── proxy-tuning
│   │   ├── eval
│   │   │   ├── apollodata
│   │   │   │   └── run_eval.py
│   │   │   └── utils.py
│   │   ├── modeling
│   │   │   └── mexperts.py
│   │   └── scripts
│   │       └── eval
│   │           └── proxy_tuning.sh
│   └── sft
│       ├── train_gemma_resume_val.py
│       ├── train_qwen_resume_val.py
│       ├── train_yi_resume_val.py
│       └── training_config
│           ├── zero.yaml
│           └── zero_multi.yaml
└── utils
    ├── check.ipynb
    └── kill.sh
/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/.gitignore
--------------------------------------------------------------------------------
/0.download_data.sh:
--------------------------------------------------------------------------------
1 | # download ApolloCorpus
2 |
3 | cd metadata
4 | wget https://huggingface.co/datasets/FreedomIntelligence/ApolloCorpus/resolve/main/ApolloCorpus.zip
5 | unzip ApolloCorpus.zip
6 |
7 | # Prepare data for mix training
8 | # mixTrain lives under train/ so merge_json_train.py can find it at ./train/mixTrain
9 | mkdir -p train/mixTrain
10 |
11 | cd train/pretrain
12 | # Mix training only uses the QA pairs from the pretrain data
13 | for file in *; do
14 |     if [[ $file == *_qa.json ]]; then
15 |         cp "$file" "../mixTrain/"
16 |     fi
17 | done
18 | cd ../
19 |
20 | # Move all SFT files into mixTrain
21 | mv sft/* mixTrain/
22 | cd ../
23 |
24 | # Merge every file in train/mixTrain into a single JSON file (run from metadata/)
25 | python merge_json_train.py
26 | cd ../
27 |
28 |
--------------------------------------------------------------------------------
/1.data_process_test&dev.sh:
--------------------------------------------------------------------------------
1 | # Take gemma as example, other models' python code is in ./src/process/prepare/data_process_test_{model}.py
2 | mkdir -p ./data/gemma
3 |
4 | python ./src/process/prepare/data_process_test_gemma.py \
5 | --data_path ./metadata/test.json \
6 | --few_shot 3 \
7 | --save_path ./data/gemma/test.json
8 |
9 |
10 | python ./src/process/prepare/data_process_test_gemma.py \
11 | --data_path ./metadata/dev.json \
12 | --few_shot 3 \
13 | --save_path ./data/gemma/dev.json
14 |
--------------------------------------------------------------------------------
/2.data_process_train.sh:
--------------------------------------------------------------------------------
1 | # Need to change 4 places below
2 | # Please set the wandb key in the python file (e.g., ./src/process/prepare/data_process_train_gemma.py)
3 |
4 | mkdir wandb_logs
5 |
6 | experiment_name=Gemma_MixTrain_Data
7 | log_folder="./logs/${experiment_name}"
8 | mkdir -p $log_folder
9 | log_name=$(date +"%m-%d_%H-%M").log
10 |
11 |
12 | python ./src/process/prepare/data_process_train_gemma.py \
13 | --data_path ./metadata/train/mixTrain.json \
14 | --model_path /your/path/to/gemma-2b \
15 | --wandb_log ./wandb_logs \
16 | --experiment_name ${experiment_name} \
17 | --save_path ./data/gemma/MixTrain > ${log_folder}/$log_name 2>&1 &
--------------------------------------------------------------------------------
/3.single_node_train_gemma.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #python *.py
3 |
4 | # Please set the wandb key in the python file (e.g., ./src/sft/train_gemma_resume_val.py)
5 | process_port=29502
6 | experiment_name=Gemma2b_MixTrain_Train
7 | model_dir=/your/path/to/gemma-2b
8 | # ckpt_dir=
9 | train_data_file=./data/gemma/MixTrain
10 | dev_data_file=./data/gemma/dev.json
11 | output_dir=./ckpts
12 | log_folder="./logs/${experiment_name}"
13 | mkdir -p $log_folder
14 | log_name=$(date +"%m-%d_%H-%M").log
15 |
16 | accelerate launch \
17 | --config_file ./src/sft/training_config/zero.yaml \
18 | --num_processes 8 \
19 | --num_machines 1 \
20 | --main_process_port ${process_port} \
21 | --num_cpu_threads_per_process 8 \
22 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \
23 | --model_path ${model_dir} \
24 | --experiment_name ${experiment_name} \
25 | --gradient_accumulation_steps 8 \
26 | --train_data_dir ${train_data_file} \
27 | --dev_data_dir ${dev_data_file} \
28 | --output_dir ${output_dir} \
29 | --log_dir ./wandb_logs \
30 | --n_epochs 1 \
31 | --train_bsz_per_gpu 2 \
32 | --eval_bsz_per_gpu 2 \
33 | --learning_rate 1e-5 \
34 | --eval_step -1 \
35 | --save_step -1 \
36 | --warmup_rates 0.03 \
37 | --max_ckpts 5 \
38 | --gradient_checkpointing > ${log_folder}/$log_name 2>&1 &
39 |
40 |
--------------------------------------------------------------------------------
/4.eval.sh:
--------------------------------------------------------------------------------
1 | experiment_name=Gemma2b_MixTrain_Test
2 | log_folder="./logs/${experiment_name}"
3 | result_folder="./results/${experiment_name}"
4 | mkdir -p $log_folder
5 | mkdir -p $result_folder
6 | log_name=$(date +"%m-%d_%H-%M").log
7 |
8 | CUDA_LAUNCH_BLOCKING=1 accelerate launch --main_process_port 23035 ./src/evaluate/eval_gemma.py \
9 | --model_path=./ckpts/gemma-2b_MixTrain \
10 | --input_path=./data/gemma/test.json \
11 | --output_path=${result_folder}/model_ans.jsonl \
12 | --score_path=${result_folder}/score.json \
13 | --wrong_item_path=${result_folder}/wrong_item.json \
14 | --num_return=1 --batch_size=8 > ${log_folder}/$log_name 2>&1 &  # eval_gemma.py needs these two; adjust batch_size to your GPU memory
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multilingual Medicine: Model, Dataset, Benchmark, Code
2 |
3 | Covering English, Chinese, French, Hindi, Spanish, and Arabic so far
4 |
5 |
6 |
7 | 📃 Paper • 🌐 Demo • 🤗 ApolloCorpus • 🤗 XMedBench • 🌐 ApolloMoE
8 | 中文 | English
9 |
10 |
11 | 
12 |
13 | ## 🌈 Update
14 |
15 | * **[2024.10.15]** [ApolloMoE](https://github.com/FreedomIntelligence/ApolloMoE) repo released, covering 50 languages.
16 | * **[2024.04.25]** [MedJamba](https://huggingface.co/FreedomIntelligence/Apollo-MedJamba) released; for training and evaluation code, see its [repo](https://github.com/FreedomIntelligence/MedJamba).
17 | * **[2024.03.07]** [Paper](https://arxiv.org/abs/2403.03640) released.
18 | * **[2024.02.12]** ApolloCorpus and XMedBench are published! 🎉
19 | * **[2024.01.23]** Apollo repo is published!🎉
20 |
21 |
22 | ## Results
23 | 🤗 Apollo-0.5B • 🤗 Apollo-1.8B • 🤗 Apollo-2B • 🤗 Apollo-6B • 🤗 Apollo-7B • 🤗 Apollo-34B • 🤗 Apollo-72B
24 |
25 | 🤗 MedJamba
26 |
27 | 🤗 Apollo-0.5B-GGUF • 🤗 Apollo-2B-GGUF • 🤗 Apollo-6B-GGUF • 🤗 Apollo-7B-GGUF
28 |
29 |
30 |
31 | 
32 |
33 |
34 |
35 | ## Usage Format
36 |
37 | - 0.5B, 1.8B, 2B, 6B, 7B: User:{query}\nAssistant:{response}<|endoftext|>
38 | - 34B, 72B: <|User|>:{query}\n<|Assistant|>:{response}<|endoftext|>
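
As a quick illustration of these two templates, a prompt can be assembled like this (`build_prompt` is a hypothetical helper, not part of this repo):

```
def build_prompt(query, response=None, large_model=False):
    # Pick the template: <|User|>/<|Assistant|> for 34B/72B, plain User/Assistant otherwise.
    user, assistant = ("<|User|>", "<|Assistant|>") if large_model else ("User", "Assistant")
    prompt = f"{user}:{query}\n{assistant}:"
    if response is not None:          # append a finished turn for training-style text
        prompt += f"{response}<|endoftext|>"
    return prompt

print(build_prompt("What causes iron-deficiency anemia?"))
# User:What causes iron-deficiency anemia?
# Assistant:
```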
39 |
40 | ## Dataset & Evaluation
41 |
42 | - Dataset
43 | 🤗 ApolloCorpus
44 |
45 |
46 |
47 | 
48 |
49 | - [Zip File](https://huggingface.co/datasets/FreedomIntelligence/ApolloCorpus/blob/main/ApolloCorpus.zip)
50 | - [Data category](https://huggingface.co/datasets/FreedomIntelligence/ApolloCorpus/tree/main/train)
51 | - Pretrain:
52 | - data item:
53 | - json_name: {data_source}_{language}_{data_type}.json
54 | - data_type: medicalBook, medicalGuideline, medicalPaper, medicalWeb (from online forums), medicalWiki
55 | - language: en (English), zh (Chinese), es (Spanish), fr (French), hi (Hindi)
56 | - data_type: qa (QA pairs generated from the text)
57 | - data_type==text: list of strings
58 | ```
59 | [
60 | "string1",
61 | "string2",
62 | ...
63 | ]
64 | ```
65 | - data_type==qa: list of QA pairs (each item is a list of strings)
66 | ```
67 | [
68 | [
69 | "q1",
70 | "a1",
71 | "q2",
72 | "a2",
73 | ...
74 | ],
75 | ...
76 | ]
77 | ```
78 | - SFT:
79 | - json_name: {data_source}_{language}.json
80 | - data_type: code, general, math, medicalExam, medicalPatient
81 | - data item: list of QA pairs (each item is a list of strings); see the loading sketch after the example below
82 | ```
83 | [
84 | [
85 | "q1",
86 | "a1",
87 | "q2",
88 | "a2",
89 | ...
90 | ],
91 | ...
92 | ]
93 | ```
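
Reading these files only needs the standard `json` module. A minimal loading sketch (the file name is illustrative; the keys follow the structure above), flattening each dialogue into (question, answer) turns:

```
import json

# Any SFT file has the list-of-QA-pair-lists shape shown above; the name here is just an example.
with open("medicalExam_en.json", encoding="utf-8") as f:
    dialogues = json.load(f)

pairs = []
for dialogue in dialogues:
    # items alternate question/answer, so walk them two at a time
    for q, a in zip(dialogue[0::2], dialogue[1::2]):
        pairs.append({"question": q, "answer": a})

print(f"{len(pairs)} QA pairs loaded")
```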
94 |
95 |
96 |
97 |
98 | - Evaluation
99 | 🤗 XMedBench
100 |
101 |
102 |
103 | - EN:
104 | - [MedQA-USMLE](https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options)
105 | - [MedMCQA](https://huggingface.co/datasets/medmcqa/viewer/default/test)
106 | - [PubMedQA](https://huggingface.co/datasets/pubmed_qa): Because the results fluctuated too much, they were not used in the paper.
107 | - [MMLU-Medical](https://huggingface.co/datasets/cais/mmlu)
108 | - Clinical knowledge, Medical genetics, Anatomy, Professional medicine, College biology, College medicine
109 | - ZH:
110 | - [MedQA-MCMLE](https://huggingface.co/datasets/bigbio/med_qa/viewer/med_qa_zh_4options_bigbio_qa/test)
111 | - [CMB-single](https://huggingface.co/datasets/FreedomIntelligence/CMB): Not used in the paper
112 | - Randomly sampled 2,000 single-answer multiple-choice questions.
113 | - [CMMLU-Medical](https://huggingface.co/datasets/haonan-li/cmmlu)
114 | - Anatomy, Clinical_knowledge, College_medicine, Genetics, Nutrition, Traditional_chinese_medicine, Virology
115 | - [CMExam](https://github.com/williamliujl/CMExam): Not used in the paper
116 | - Randomly sampled 2,000 multiple-choice questions
117 |
118 |
119 | - ES: [Head_qa](https://huggingface.co/datasets/head_qa)
120 | - FR: [Frenchmedmcqa](https://github.com/qanastek/FrenchMedMCQA)
121 | - HI: [MMLU_HI](https://huggingface.co/datasets/FreedomIntelligence/MMLU_Hindi)
122 | - Clinical knowledge, Medical genetics, Anatomy, Professional medicine, College biology, College medicine
123 | - AR: [MMLU_Ara](https://huggingface.co/datasets/FreedomIntelligence/MMLU_Arabic)
124 | - Clinical knowledge, Medical genetics, Anatomy, Professional medicine, College biology, College medicine
125 |
126 |
127 |
128 |
129 |
130 | ## Results reproduction
131 |
132 |
133 |
134 | We take Gemma-2B as an example.
135 | 1. Download Dataset for project:
136 |
137 | ```
138 | bash 0.download_data.sh
139 | ```
140 |
141 | 2. Prepare test and dev for specific model:
142 |
143 |
144 | - Create test/dev data with the model's special tokens; you can use ./utils/check.ipynb to check a model's special tokens (see the sketch after the command below)
145 |
146 | ```
147 | bash 1.data_process_test&dev.sh
148 | ```
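
If you prefer not to open the notebook, a short snippet is enough to inspect a tokenizer's special tokens (the model path is a placeholder):

```
from transformers import AutoTokenizer

# Point this at the model you are preparing data for, e.g. a local gemma-2b checkout.
tokenizer = AutoTokenizer.from_pretrained("/your/path/to/gemma-2b", trust_remote_code=True)
print(tokenizer.special_tokens_map)   # bos/eos/pad/unk tokens
print(tokenizer.all_special_tokens)
```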
149 |
150 | 3. Prepare train data for specific model (Create tokenized data in advance):
151 |
152 |
153 | - You can adjust the training data order and the number of training epochs in this step
154 |
155 | ```
156 | bash 2.data_process_train.sh
157 | ```
158 |
159 | 4. Train the model
160 |
161 |
162 | - If you want to train on multiple nodes, please refer to ./scripts/3.multinode_train_gema7B_rank*.sh
163 |
164 |
165 |
166 |
167 | ```
168 | bash 3.single_node_train_gemma.sh
169 | ```
170 |
171 | 5. (Optional) Proxy-Tuning: Directly improve model capabilities without fine-tuning
172 |
173 | ```
174 | bash src/proxy-tuning/scripts/eval/proxy_tuning.sh
175 | ```
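
Proxy-tuning steers a large untuned base model at decoding time by adding the logit offset between a small tuned expert and its untuned anti-expert; the actual implementation lives in ./src/proxy-tuning/modeling/mexperts.py. Below is only a minimal sketch of the idea, assuming three models that share one tokenizer and using placeholder model names:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder names: a large base model plus a small tuned expert and its untuned anti-expert.
base = AutoModelForCausalLM.from_pretrained("base-7b")
expert = AutoModelForCausalLM.from_pretrained("apollo-expert-2b")
anti = AutoModelForCausalLM.from_pretrained("base-2b")
tok = AutoTokenizer.from_pretrained("base-7b")

ids = tok("User:What causes anemia?\nAssistant:", return_tensors="pt").input_ids
for _ in range(64):                                   # greedy decoding with proxy-tuned logits
    with torch.no_grad():
        logits = base(ids).logits[:, -1] \
               + expert(ids).logits[:, -1] \
               - anti(ids).logits[:, -1]
    next_id = logits.argmax(dim=-1, keepdim=True)
    ids = torch.cat([ids, next_id], dim=-1)
    if next_id.item() == tok.eos_token_id:
        break
print(tok.decode(ids[0], skip_special_tokens=True))
```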
176 | 6. Evaluate your model: generate scores on the benchmark
177 |
178 | ```
179 | bash 4.eval.sh
180 | ```
181 |
182 | 7. Evaluate your model: chat with your checkpoint from the command line
183 |
184 | ```
185 | python ./src/evaluate/cli_demo.py --model_name='./ckpts/your/path/tfmr'
186 | ```
187 |
188 |
189 |
190 |
191 | ## Acknowledgment
192 |
193 | - [HuatuoGPT-II](https://github.com/FreedomIntelligence/HuatuoGPT-II)
194 | - [proxy-tuning](https://github.com/alisawuffles/proxy-tuning)
195 |
196 | ## Citation
197 | Please use the following citation if you intend to use our dataset for training or evaluation:
198 |
199 | ```
200 | @misc{wang2024apollo,
201 | title={Apollo: Lightweight Multilingual Medical LLMs towards Democratizing Medical AI to 6B People},
202 | author={Xidong Wang and Nuo Chen and Junyin Chen and Yan Hu and Yidong Wang and Xiangbo Wu and Anningzhe Gao and Xiang Wan and Haizhou Li and Benyou Wang},
203 | year={2024},
204 | eprint={2403.03640},
205 | archivePrefix={arXiv},
206 | primaryClass={cs.CL}
207 | }
208 | ```
209 |
210 | ## Star History
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
--------------------------------------------------------------------------------
/assets/apollo_medium_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/apollo_medium_final.png
--------------------------------------------------------------------------------
/assets/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/dataset.png
--------------------------------------------------------------------------------
/assets/final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/final.png
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/assets/result.png
--------------------------------------------------------------------------------
/metadata/dev/fr.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "question": "Parmi les propositions suivantes, donner celle qui est exacte. En spectrophotométrie d'adsorption moléculaire, une absorbance de 1 correspond à une absorption du faisceau incident par le composé présent dans la cuve de :",
4 | "options": "(A) 100 %.\n(B) 90 %.\n(C) 50 %.\n(D) 40 %.\n(E) 10 %.",
5 | "answer": "(B)",
6 | "source": "frenchmedmcqa"
7 | },
8 | {
9 | "question": "L'exsanguino transfusion est indiqué dans un type d'intoxication aiguë, lequel?",
10 | "options": "(A) Méthanol.\n(B) Cyanures.\n(C) Agents méthémoglobinisants.\n(D) Métaux lourds.\n(E) Monoxyde de carbone.",
11 | "answer": "(C)",
12 | "source": "frenchmedmcqa"
13 | },
14 | {
15 | "question": "Parmi les propositions suivantes, une seule est exacte, Laquelle? Un complexe zinc-EDTA est dosé par spectrophotométrie d'absorption atomique. Sous quelle forme?",
16 | "options": "(A) Atome de zinc.\n(B) Ion zinc.\n(C) Radical.\n(D) Hydroxyde de zinc.\n(E) Oxyde de zinc.",
17 | "answer": "(A)",
18 | "source": "frenchmedmcqa"
19 | },
20 | {
21 | "question": "Parmi les propositions suivantes, laquelle est exacte ? Les hétérosides cardiotoniques exercent un effet inotrope positif:",
22 | "options": "(A) Par une stimulation des récepteurs bêta-adrénergiques cardiaques.\n(B) Par un effet parasympatholytique.\n(C) Par une inhibition de la Na+/K+ ATPase et une modification secondaire des mouvements du calcium.\n(D) Par une inhibition de l'entrée du calcium dans la cellule cardiaque.\n(E) Par un effet vasodilatateur.",
23 | "answer": "(C)",
24 | "source": "frenchmedmcqa"
25 | },
26 | {
27 | "question": "En chromatographie en phase liquide, l'amélioration de la résolution entre 2 composés dépend des paramètres suivants sauf un. Lequel?",
28 | "options": "(A) Le nombre de plateaux théoriques.\n(B) Le facteur de capacité du composé le plus retenu.\n(C) Le facteur de sélectivité.\n(D) La vitesse de déroulement du papier.\n(E) La longueur de la colonne.",
29 | "answer": "(D)",
30 | "source": "frenchmedmcqa"
31 | },
32 | {
33 | "question": "Lors de l'examen direct du LCR, la découverte de diplocoques à Gram négatif doit faire évoquer le germe suivant. Donner la réponse exacte.",
34 | "options": "(A) Streptococcus pneumoniae.\n(B) Escherichia coli.\n(C) Neisseria meningitidis.\n(D) Cryptococcus neoformans.\n(E) Haemophilus influenzae.",
35 | "answer": "(C)",
36 | "source": "frenchmedmcqa"
37 | },
38 | {
39 | "question": "Parmi ces propositions concernant la sérotonine, une seule est fausse, laquelle?",
40 | "options": "(A) C'est une amine biogène endogène.\n(B) Elle franchit la barrière hémato-encéphalique.\n(C) Elle est synthétisée à partir du tryptophane alimentaire.\n(D) Elle n'est pas synthétisée dans les plaquettes.\n(E) Son métabolite principal est constitué par le 5 HlAA.",
41 | "answer": "(B)",
42 | "source": "frenchmedmcqa"
43 | },
44 | {
45 | "question": "Parmi ces propositions, une seule est fausse, laquelle?",
46 | "options": "(A) L'émission de lumière après excitation d'une molécule par radiations s'appelle fluorescence.\n(B) La raie de diffusion Rayleigh est un artefact de la fluorescence moléculaire.\n(C) Un spectrofluorimètre possède deux monochromateurs.\n(D) la phosphorescence est l'émission de lumière de molécules contenant des atomes de phosphore.\n(E) La diffusion Raman est un artefact de la fluorescence moléculaire.",
47 | "answer": "(D)",
48 | "source": "frenchmedmcqa"
49 | },
50 | {
51 | "question": "Parmi les propositions suivantes, indiquer celle qui correspond le mieux à la définition d'une clairance:",
52 | "options": "(A) Volume de plasma contenant une substance donnée, filtré par unité de temps.\n(B) Volume de plasma épuré en une substance donnée par unité de temps.\n(C) Volume d'urine contenant la même quantité d'une substance que 1 mL de plasma.\n(D) Volume de plasma contenant une substance donnée, passant par le rein par unité de temps.\n(E) Volume d'urine d'où une substance peut être réabsorbée par unité de temps.",
53 | "answer": "(B)",
54 | "source": "frenchmedmcqa"
55 | },
56 | {
57 | "question": "Quel(s) élément(s) ne peut (peuvent) pas être dosé(s) par photométrie de flamme ?",
58 | "options": "(A) Aluminium.\n(B) Sodium.\n(C) Potassium.\n(D) Césium.\n(E) Lithium.",
59 | "answer": "(A)",
60 | "source": "frenchmedmcqa"
61 | },
62 | {
63 | "question": "Parmi les propositions suivantes, quelle est celle qui s'applique à Clostridium perfringens?",
64 | "options": "(A) C'est une bactérie qui peut former des spores.\n(B) C'est une bactérie aéro-anaérobie facultative.\n(C) Il n'est jamais responsable de toxi-infections alimentaires.\n(D) C'est un agent fréquent d'infections urinaires.\n(E) Il est toujours sensible aux aminosides.",
65 | "answer": "(A)",
66 | "source": "frenchmedmcqa"
67 | },
68 | {
69 | "question": "En 1998, sur 85.000 hommes de 60 à 64 ans, 300 cas d'infarctus du myocarde ont été dénombrés dont 270 nouveaux. Parmi les nouveaux cas, 135 sont décédés dans l'année. Le nombre total de décès dus à cette maladie pour l'année est de 150; Parmi les résultats ci-dessous, indiquer celui qui représente le taux d'incidence:",
70 | "options": "(A) 135/85.000.\n(B) 300/85.000.\n(C) 150/85.000.\n(D) 270 /85.000.\n(E) 135/270.",
71 | "answer": "(D)",
72 | "source": "frenchmedmcqa"
73 | },
74 | {
75 | "question": "Parmi les unités de radioactivité suivantes, indiquer celle qui est utilisée pour exprimer l'équivalent de dose (H) :",
76 | "options": "(A) Becquerel.\n(B) Curie.\n(C) Gray.\n(D) Sievert.\n(E) Rad.",
77 | "answer": "(D)",
78 | "source": "frenchmedmcqa"
79 | },
80 | {
81 | "question": "Parmi les propositions suivantes concernant les phénomènes pouvant entraîner une inhibition de la fluorescence. Laquelle est fausse?",
82 | "options": "(A) La variation du pH peut modifier l'intensité de la fluorescence.\n(B) L'oxygène dissous inhibe la fluorescence.\n(C) Si la température augmente, l'intensité de la fluorescence augmente.\n(D) Le solvant peut inhiber la fluorescence.\n(E) La présence d'impuretés dans la solution peut provoquer une inhibition de la fluorescence.",
83 | "answer": "(C)",
84 | "source": "frenchmedmcqa"
85 | },
86 | {
87 | "question": "Parmi les propositions suivantes, une seule est exacte, laquelle? La caractérisation de Cryptococcus neoformans dans un LCR est réalisée par la présence à l'examen direct",
88 | "options": "(A) De blastospores.\n(B) D'une capsule polysaccharidique.\n(C) D'arthrospores.\n(D) De vrai mycélium.\n(E) De pseudomycelium.",
89 | "answer": "(B)",
90 | "source": "frenchmedmcqa"
91 | },
92 | {
93 | "question": "Parmi les propositions suivantes, indiquer celle qui est juste. Chez un sujet intoxiqué par un agent méthémoglobinisant, le traitement par le bleu de méthylène IV sera inefficace si le malade:",
94 | "options": "(A) Est à jeun.\n(B) Est déficitaire en glucose-6-phosphate déshydrogénase.\n(C) Est dans le coma.\n(D) Est déficitaire en cytochrome c réductase.\n(E) A absorbé des psychotropes.",
95 | "answer": "(B)",
96 | "source": "frenchmedmcqa"
97 | }
98 | ]
--------------------------------------------------------------------------------
/metadata/merge_json_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | def merge_json_files(root_path):
5 | merged_data = {}
6 | for file_name in os.listdir(root_path):
7 |
8 | if file_name.endswith(".json"):
9 | file_path = os.path.join(root_path, file_name)
10 | print(file_path)
11 | # Read the content of the JSON file
12 | with open(file_path, "r", encoding="utf-8") as json_file:
13 | json_data = json.load(json_file)
14 |
15 | # Add the data to the merged dictionary using the file name as the key
16 | merged_data[file_name[:-5]] = json_data
17 |
18 | # Path for the merged file
19 |     merged_file_path = "./train/mixTrain.json"  # matches the --data_path used in 2.data_process_train.sh
20 |
21 | # Write the merged JSON data to a new file
22 | with open(merged_file_path, "w", encoding="utf-8") as merged_file:
23 | json.dump(merged_data, merged_file, ensure_ascii=False, indent=2)
24 |
25 | print(f"Merged file has been saved as: {merged_file_path}")
26 |
27 | # Replace with the actual directory path
28 | merge_json_files("./train/mixTrain")
29 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | llama-index-llms-vllm
--------------------------------------------------------------------------------
/scripts/3.multinode_train_gema7B_rank0.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #python *.py
3 |
4 | node_rank=0
5 | master_ip=yourip
6 |
7 |
8 | # Run this script from the repo root; all paths below are relative to it
9 | process_port=29503
10 | experiment_name=Gemma7B_MixTrain
11 | model_dir=/your/path/to/Gemma-7b
12 | # ckpt_dir=./ckpts/
13 | train_data_file=./data/gemma/MixTrain
14 | dev_data_file=./data/gemma/dev.json
15 | output_dir=./ckpts
16 | log_folder="./logs/${experiment_name}"
17 | mkdir -p $log_folder
18 | log_name=$(date +"%m-%d_%H-%M").log
19 |
20 |
21 |
22 | CUDA_LAUNCH_BLOCKING=1 accelerate launch \
23 | --config_file ./src/sft/training_config/zero_multi.yaml \
24 | --num_processes 24 \
25 | --num_machines 3 \
26 | --machine_rank ${node_rank} \
27 | --main_process_ip "${master_ip}" \
28 | --main_process_port ${process_port} \
29 | --num_cpu_threads_per_process 8 \
30 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \
31 | --model_path ${model_dir} \
32 | --experiment_name ${experiment_name} \
33 | --gradient_accumulation_steps 8 \
34 | --train_data_dir ${train_data_file} \
35 | --dev_data_dir ${dev_data_file} \
36 | --output_dir ${output_dir} \
37 | --log_dir ./wandb_logs \
38 | --n_epochs 1 \
39 | --train_bsz_per_gpu 2 \
40 | --eval_bsz_per_gpu 2 \
41 | --learning_rate 1e-5 \
42 | --eval_step -1 \
43 | --save_step -1 \
44 | --warmup_rates 0.03 \
45 | --max_ckpts 3 \
46 | --gradient_checkpointing > ${log_folder}/rank${node_rank}.log 2>&1 &
--------------------------------------------------------------------------------
/scripts/3.multinode_train_gema7B_rank1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #python *.py
3 |
4 | node_rank=1
5 | master_ip=yourip
6 |
7 |
8 | # Run this script from the repo root; all paths below are relative to it
9 | process_port=29503
10 | experiment_name=Gemma7B_MixTrain
11 | model_dir=/your/path/to/Gemma-7b
12 | # ckpt_dir=./ckpts/
13 | train_data_file=./data/gemma/MixTrain
14 | dev_data_file=./data/gemma/dev.json
15 | output_dir=./ckpts
16 | log_folder="./logs/${experiment_name}"
17 | mkdir -p $log_folder
18 | log_name=$(date +"%m-%d_%H-%M").log
19 |
20 |
21 |
22 | CUDA_LAUNCH_BLOCKING=1 accelerate launch \
23 | --config_file ./src/sft/training_config/zero_multi.yaml \
24 | --num_processes 24 \
25 | --num_machines 3 \
26 | --machine_rank ${node_rank} \
27 | --main_process_ip "${master_ip}" \
28 | --main_process_port ${process_port} \
29 | --num_cpu_threads_per_process 8 \
30 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \
31 | --model_path ${model_dir} \
32 | --experiment_name ${experiment_name} \
33 | --gradient_accumulation_steps 8 \
34 | --train_data_dir ${train_data_file} \
35 | --dev_data_dir ${dev_data_file} \
36 | --output_dir ${output_dir} \
37 | --log_dir ./wandb_logs \
38 | --n_epochs 1 \
39 | --train_bsz_per_gpu 2 \
40 | --eval_bsz_per_gpu 2 \
41 | --learning_rate 1e-5 \
42 | --eval_step -1 \
43 | --save_step -1 \
44 | --warmup_rates 0.03 \
45 | --max_ckpts 3 \
46 | --gradient_checkpointing > ${log_folder}/rank${node_rank}.log 2>&1 &
--------------------------------------------------------------------------------
/scripts/3.multinode_train_gema7B_rank2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #python *.py
3 |
4 | node_rank=2
5 | master_ip=yourip
6 |
7 |
8 | # Run this script from the repo root; all paths below are relative to it
9 | process_port=29503
10 | experiment_name=Gemma7B_MixTrain
11 | model_dir=/your/path/to/Gemma-7b
12 | # ckpt_dir=./ckpts/
13 | train_data_file=./data/gemma/MixTrain
14 | dev_data_file=./data/gemma/dev.json
15 | output_dir=./ckpts
16 | log_folder="./logs/${experiment_name}"
17 | mkdir -p $log_folder
18 | log_name=$(date +"%m-%d_%H-%M").log
19 |
20 |
21 |
22 | CUDA_LAUNCH_BLOCKING=1 accelerate launch \
23 | --config_file ./src/sft/training_config/zero_multi.yaml \
24 | --num_processes 24 \
25 | --num_machines 3 \
26 | --machine_rank ${node_rank} \
27 | --main_process_ip "${master_ip}" \
28 | --main_process_port ${process_port} \
29 | --num_cpu_threads_per_process 8 \
30 | --deepspeed_multinode_launcher standard ./src/sft/train_gemma_resume_val.py \
31 | --model_path ${model_dir} \
32 | --experiment_name ${experiment_name} \
33 | --gradient_accumulation_steps 8 \
34 | --train_data_dir ${train_data_file} \
35 | --dev_data_dir ${dev_data_file} \
36 | --output_dir ${output_dir} \
37 | --log_dir ./wandb_logs \
38 | --n_epochs 1 \
39 | --train_bsz_per_gpu 2 \
40 | --eval_bsz_per_gpu 2 \
41 | --learning_rate 1e-5 \
42 | --eval_step -1 \
43 | --save_step -1 \
44 | --warmup_rates 0.03 \
45 | --max_ckpts 3 \
46 | --gradient_checkpointing > ${log_folder}/rank${node_rank}.log 2>&1 &
--------------------------------------------------------------------------------
/src/evaluate/cli_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import torch
4 | from threading import Thread
5 | from transformers import AutoTokenizer
6 | from transformers import AutoModelForCausalLM
7 | import argparse
8 | from transformers import TextIteratorStreamer
9 |
10 | def load_model(model_name):
11 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left', pad_token='<|extra_0|>', eos_token='<|endoftext|>')
12 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype='auto', trust_remote_code=True)
13 | return model, tokenizer
14 |
15 | def generate_prompt(query, history):
16 | if not history:
17 | return f"User:{query}\nAssistant:"
18 | else:
19 | prompt = ''
20 | for i, (old_query, response) in enumerate(history):
21 | prompt += "User:{}\nAssistant:{}\n".format(old_query, response)
22 | prompt += "User:{}\nAssistant:".format(query)
23 | return prompt
24 |
25 | def remove_overlap(str1, str2):
26 | for i in range(len(str1), -1, -1):
27 | if str1.endswith(str2[:i]):
28 | return str2[i:]
29 | return str2
30 |
31 | def main(args):
32 | model, tokenizer = load_model(args.model_name)
33 | sep = tokenizer.convert_ids_to_tokens(tokenizer.eos_token_id)
34 | print(sep)
35 |
36 | model = model.eval()
37 |
38 | gen_kwargs = {'max_new_tokens': 1024, 'do_sample':True, 'top_p':0.7, 'temperature':0.3, 'repetition_penalty':1.1}
39 |
40 | os_name = platform.system()
41 | clear_command = 'cls' if os_name == 'Windows' else 'clear'
42 | history = []
43 | print("Model: Hello, I am a large model that answers medical and health questions. It is currently in the testing stage. Please follow your doctor's advice. How can I help you? Enter clear to clear the conversation history, stop to terminate the program")
44 | while True:
45 | query = input("\nUser:")
46 | if query == "stop":
47 | break
48 | if query == "clear":
49 | history = []
50 | os.system(clear_command)
51 | continue
52 |
53 | print(f"Model:", end="", flush=True)
54 |
55 |
56 | prompt = generate_prompt(query, history)
57 | inputs = tokenizer([prompt], return_tensors="pt")
58 | inputs = inputs.to(model.device)
59 |
60 | streamer = TextIteratorStreamer(tokenizer,skip_prompt=True)
61 | generation_kwargs = dict(input_ids=inputs['input_ids'], streamer=streamer, **gen_kwargs)
62 |
63 | thread = Thread(target=model.generate, kwargs=generation_kwargs)
64 | thread.start()
65 |
66 | generated_text = ''
67 |
68 | for new_text in streamer:
69 | if sep in new_text:
70 | new_text = remove_overlap(generated_text,new_text[:-len(sep)])
71 | for char in new_text:
72 | generated_text += char
73 | print(char,end='',flush = True)
74 | break
75 | for char in new_text:
76 | generated_text += char
77 | print(char,end='',flush = True)
78 | history = history + [(query, generated_text)]
79 |
80 |
81 | if __name__ == "__main__":
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument("--model_name", type=str, default="./ckpts/your/path/tfmr")
84 | args = parser.parse_args()
85 | main(args)
86 |
--------------------------------------------------------------------------------
/src/evaluate/eval_72b_34b.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
4 | import re
5 | import argparse
6 | import json
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader
9 | import torch.distributed as dist
10 | from collections import defaultdict
11 | from llama_index.llms.vllm import Vllm
12 |
13 |
14 | def get_answer(data:str,llm):
15 | response=llm.complete(
16 | data
17 | )
18 | return response.text
19 |
20 |
21 | def extract_and_choose_answer(pattern, model_answer):
22 | # if '\n' in model_answer:
23 | # model_answer_split = model_answer.split('\n')
24 | # for model_answer_i in model_answer_split:
25 | # if len(model_answer_i):
26 | # model_answer = model_answer_i
27 | # break
28 | matches = re.findall(pattern, model_answer)
29 | option_count = {}
30 | for match in matches:
31 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
32 |
33 | if not option_count:
34 | # else use loose pattern
35 | loose_pattern = r'[A-F]'
36 | if pattern == loose_pattern:
37 | if model_answer == 'Yes.':
38 | return 'A'
39 | elif model_answer == 'No.':
40 | return 'B'
41 | else:
42 | return None
43 | else:
44 | return extract_and_choose_answer(loose_pattern, model_answer)
45 |
46 | max_count = max(option_count.values())
47 | max_options = [option for option, count in option_count.items() if count == max_count]
48 | return max_options[0]
49 |
50 |
51 | def generate_score(result_path, score_path, wrong_item_path):
52 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
53 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
54 |
55 | all = defaultdict(int)
56 | right = defaultdict(int)
57 | accuracy_dict = defaultdict(int)
58 | wrong_item = []
59 |
60 | print(f'****Total:{len(json_objects)}****')
61 | debug = True
62 | for item in json_objects:
63 | source = item["source"]
64 | for answer in item["model_answer"]:
65 | all[source] += 1
66 | pattern = r'[(\(]([A-Fa-f])[)\)]'
67 | extract_answer = extract_and_choose_answer(pattern, answer)
68 | item['extract_answer'] = extract_answer
69 | if debug:
70 | debug = False
71 | print(f'extract_answer:{extract_answer}')
72 | right_answer = item['answer']
73 | print(f'right_answer:{right_answer}')
74 | if item['answer'] == extract_answer:
75 | right[source] += 1
76 | else:
77 | wrong_item.append(item)
78 |
79 |
80 | print(f'all:{all}')
81 | print(f'right:{right}')
82 |
83 | for key in right:
84 | accuracy_dict[key] = right[key] / all[key]
85 |
86 | with open(score_path, "w", encoding="utf8") as f:
87 | json.dump(accuracy_dict, f, indent=4)
88 |
89 | print(f'***********score_result save in {score_path}*************')
90 |
91 | with open(wrong_item_path, "w", encoding="utf8") as f:
92 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
93 |
94 | print(f'***********wrong_item save in {wrong_item_path}*************')
95 |
96 |
97 | def generate_response(args,llm):
98 |
99 |
100 | # model_path = args.model_path
101 |
102 | fp = open(args.output_path,'w')
103 |
104 | with open(args.input_path) as f:
105 | data = json.load(f)
106 |
107 | for item in tqdm(data):
108 | question=item['question']
109 | answer=get_answer(question,llm)
110 | item['model_answer'] = [answer]  # wrap in a list so generate_score iterates answers, not characters
111 | fp.write(json.dumps(item, ensure_ascii=False) +'\n')
112 | fp.flush()
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
118 | parser.add_argument("--input_path", type=str, help="path to the input data")
119 | parser.add_argument("--output_path", type=str, help="path to the output data")
120 | parser.add_argument("--score_path", type=str, help="path to the score")
121 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
122 | args = parser.parse_args()
123 | llm = Vllm(
124 | model=args.model_path,
125 | trust_remote_code=True,
126 | max_new_tokens=64,
127 | temperature=0,
128 | dtype="bfloat16",
129 | tensor_parallel_size=8,
130 | vllm_kwargs={"swap_space": 1},
131 | )
132 | generate_response(args,llm)
133 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
134 |
135 |
136 | '''
137 |
138 | accelerate launch /mntcephfs/data/med/xidong/Medbase/src/evaluate/eval_qwen.py \
139 | --model_path=/mntcephfs/data/med/xidong/checkpoints/Qwen-1_8B \
140 | --input_path=/mntcephfs/data/med/xidong/Medbase/data/Qwen-1.8B/test.json \
141 | --output_path=/mntcephfs/data/med/xidong/Medbase/result/Qwen-1.8B/model_ans.jsonl \
142 | --score_path=/mntcephfs/data/med/xidong/Medbase/result/Qwen-1.8B/score.json \
143 | --batch_size=8 > ${log_folder}/$log_name 2>&1 &
144 |
145 | '''
146 |
147 |
148 |
--------------------------------------------------------------------------------
/src/evaluate/eval_gemma.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 |
103 | print(f'****Total:{len(json_objects)}****')
104 | debug = True
105 | for item in json_objects:
106 | source = item["source"]
107 | for answer in item["model_answer"]:
108 | all[source] += 1
109 | pattern = r'[(\(]([A-Fa-f])[)\)]'
110 | extract_answer = extract_and_choose_answer(pattern, answer)
111 | if debug:
112 | debug = False
113 | print(f'extract_answer:{extract_answer}')
114 | right_answer = item['answer']
115 | print(f'right_answer:{right_answer}')
116 | if item['answer'] == extract_answer:
117 | right[source] += 1
118 |
119 |
120 | print(f'all:{all}')
121 | print(f'right:{right}')
122 |
123 | for key in right:
124 | accuracy_dict[key] = right[key] / all[key]
125 |
126 | with open(score_path, "w", encoding="utf8") as f:
127 | json.dump(accuracy_dict, f, indent=4)
128 |
129 | print(f'***********score_result save in {score_path}*************')
130 |
131 |
132 | def generate_response(args):
133 | accelerator = Accelerator()
134 |
135 | model_path = args.model_path
136 | accelerator.print(f'****************model_path:{model_path}******************')
137 |
138 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left', pad_token='<pad>', eos_token='<eos>')  # Gemma special tokens
139 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, ignore_mismatched_sizes=True).half()
140 | # generation_config = GenerationConfig.from_pretrained(model_path, pad_token_id=tokenizer.pad_token_id, num_return_sequences=args.num_return, max_new_tokens=256, min_new_tokens=2, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
141 | model = model.half().cuda()
142 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
143 |
144 |
145 | dataset = TestDataset(args.input_path, tokenizer)
146 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
147 |
148 | model = model.eval()
149 | if dist.is_initialized():
150 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
151 |
152 | dataloader = accelerator.prepare(dataloader)
153 | accelerator.print(f'******************load_model from {model_path}******************')
154 |
155 | if accelerator.is_main_process:
156 | fp = open(args.output_path,'w')
157 |
158 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
159 | for batch in dataloader_iterator:
160 | batch_input_ids = batch["input_ids"]
161 | batch_data = batch["data"]
162 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
163 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
164 |
165 |
166 | if dist.is_initialized():
167 | all_batch_data = [None] * dist.get_world_size()
168 | all_batch_responses = [None] * dist.get_world_size()
169 | dist.all_gather_object(all_batch_responses, batch_responses)
170 | dist.all_gather_object(all_batch_data, batch_data)
171 | else:
172 | all_batch_data = [batch_data, ]
173 | all_batch_responses = [batch_responses, ]
174 |
175 | all_data = [item for sublist in all_batch_data for item in sublist]
176 | all_response = [item for sublist in all_batch_responses for item in sublist]
177 |
178 | for data, responses in zip(all_data, all_response):
179 | answer_list = []
180 | for response in responses:
181 | answer_list.append(response)
182 | data['model_answer'] = answer_list
183 | if accelerator.is_main_process:
184 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
185 | fp.flush()
186 |
187 | if __name__ == "__main__":
188 | parser = argparse.ArgumentParser()
189 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
190 | parser.add_argument("--input_path", type=str, help="path to the input data")
191 | parser.add_argument("--output_path", type=str, help="path to the output data")
192 | parser.add_argument("--score_path", type=str, help="path to the score")
193 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
194 | parser.add_argument("--num_return", type=int, help="number of return sequences")
195 | parser.add_argument("--batch_size", type=int, help="batch size")
196 | args = parser.parse_args()
197 | generate_response(args)
198 | generate_score(args.output_path, args.score_path)
199 |
--------------------------------------------------------------------------------
/src/evaluate/eval_huatuo2.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path, wrong_item_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 | wrong_item = []
103 |
104 | print(f'****Total:{len(json_objects)}****')
105 | debug = True
106 | for item in json_objects:
107 | source = item["source"]
108 | for answer in item["model_answer"]:
109 | all[source] += 1
110 | pattern = r'[(\(]([A-Fa-f])[)\)]'
111 | extract_answer = extract_and_choose_answer(pattern, answer)
112 | item['extract_answer'] = extract_answer
113 | if debug:
114 | debug = False
115 | print(f'extract_answer:{extract_answer}')
116 | right_answer = item['answer']
117 | print(f'right_answer:{right_answer}')
118 | if item['answer'] == extract_answer:
119 | right[source] += 1
120 | else:
121 | wrong_item.append(item)
122 |
123 |
124 | print(f'all:{all}')
125 | print(f'right:{right}')
126 |
127 | for key in right:
128 | accuracy_dict[key] = right[key] / all[key]
129 |
130 | with open(score_path, "w", encoding="utf8") as f:
131 | json.dump(accuracy_dict, f, indent=4)
132 |
133 | print(f'***********score_result save in {score_path}*************')
134 |
135 | with open(wrong_item_path, "w", encoding="utf8") as f:
136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
137 |
138 | print(f'***********wrong_item save in {wrong_item_path}*************')
139 |
140 |
141 | def generate_response(args):
142 | accelerator = Accelerator()
143 |
144 | model_path = args.model_path
145 | accelerator.print(f'****************model_path:{model_path}******************')
146 |
147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='</s>', trust_remote_code=True)  # '</s>' assumed as eos token
148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
149 | model = model.half().cuda()
150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
151 |
152 |
153 | dataset = TestDataset(args.input_path, tokenizer)
154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
155 |
156 | model = model.eval()
157 | if dist.is_initialized():
158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
159 |
160 | dataloader = accelerator.prepare(dataloader)
161 | accelerator.print(f'******************load_model from {model_path}******************')
162 |
163 | if accelerator.is_main_process:
164 | fp = open(args.output_path,'w')
165 |
166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
167 | for batch in dataloader_iterator:
168 | batch_input_ids = batch["input_ids"]
169 | batch_data = batch["data"]
170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
172 |
173 |
174 | if dist.is_initialized():
175 | all_batch_data = [None] * dist.get_world_size()
176 | all_batch_responses = [None] * dist.get_world_size()
177 | dist.all_gather_object(all_batch_responses, batch_responses)
178 | dist.all_gather_object(all_batch_data, batch_data)
179 | else:
180 | all_batch_data = [batch_data, ]
181 | all_batch_responses = [batch_responses, ]
182 |
183 | all_data = [item for sublist in all_batch_data for item in sublist]
184 | all_response = [item for sublist in all_batch_responses for item in sublist]
185 |
186 | for data, responses in zip(all_data, all_response):
187 | answer_list = []
188 | for response in responses:
189 | answer_list.append(response)
190 | data['model_answer'] = answer_list
191 | if accelerator.is_main_process:
192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
193 | fp.flush()
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
198 | parser.add_argument("--input_path", type=str, help="path to the input data")
199 | parser.add_argument("--output_path", type=str, help="path to the output data")
200 | parser.add_argument("--score_path", type=str, help="path to the score")
201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
202 | parser.add_argument("--num_return", type=int, help="number of return sequences")
203 | parser.add_argument("--batch_size", type=int, help="batch size")
204 | args = parser.parse_args()
205 | # generate_response(args)
206 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
207 |
208 |
--------------------------------------------------------------------------------
/src/evaluate/eval_llama2.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path, wrong_item_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 | wrong_item = []
103 |
104 | print(f'****Total:{len(json_objects)}****')
105 | debug = True
106 | for item in json_objects:
107 | source = item["source"]
108 | for answer in item["model_answer"]:
109 | all[source] += 1
110 | pattern = r'[(\(]([A-Fa-f])[)\)]'
111 | extract_answer = extract_and_choose_answer(pattern, answer)
112 | item['extract_answer'] = extract_answer
113 | if debug:
114 | debug = False
115 | print(f'extract_answer:{extract_answer}')
116 | right_answer = item['answer']
117 | print(f'right_answer:{right_answer}')
118 | if item['answer'] == extract_answer:
119 | right[source] += 1
120 | else:
121 | wrong_item.append(item)
122 |
123 |
124 | print(f'all:{all}')
125 | print(f'right:{right}')
126 |
127 | for key in right:
128 | accuracy_dict[key] = right[key] / all[key]
129 |
130 | with open(score_path, "w", encoding="utf8") as f:
131 | json.dump(accuracy_dict, f, indent=4)
132 |
133 | print(f'***********score_result save in {score_path}*************')
134 |
135 | with open(wrong_item_path, "w", encoding="utf8") as f:
136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
137 |
138 | print(f'***********wrong_item save in {wrong_item_path}*************')
139 |
140 |
141 | def generate_response(args):
142 | accelerator = Accelerator()
143 |
144 | model_path = args.model_path
145 | accelerator.print(f'****************model_path:{model_path}******************')
146 |
147 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='</s>', pad_token='</s>')  # assumption: '</s>' used for both eos and pad (LLaMA-2 convention)
148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
149 | model = model.half().cuda()
150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
151 |
152 |
153 | dataset = TestDataset(args.input_path, tokenizer)
154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
155 |
156 | model = model.eval()
157 | if dist.is_initialized():
158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
159 |
160 | dataloader = accelerator.prepare(dataloader)
161 | accelerator.print(f'******************load_model from {model_path}******************')
162 |
163 | if accelerator.is_main_process:
164 | fp = open(args.output_path,'w')
165 |
166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
167 | for batch in dataloader_iterator:
168 | batch_input_ids = batch["input_ids"]
169 | batch_data = batch["data"]
170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
172 |
173 |
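    |         # Gather each rank's batch and responses onto every process before writing.
    |         # Note (assumption): if the sharded dataloader pads the final batch, the
    |         # merged output may contain a few duplicated items.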
174 | if dist.is_initialized():
175 | all_batch_data = [None] * dist.get_world_size()
176 | all_batch_responses = [None] * dist.get_world_size()
177 | dist.all_gather_object(all_batch_responses, batch_responses)
178 | dist.all_gather_object(all_batch_data, batch_data)
179 | else:
180 | all_batch_data = [batch_data, ]
181 | all_batch_responses = [batch_responses, ]
182 |
183 | all_data = [item for sublist in all_batch_data for item in sublist]
184 | all_response = [item for sublist in all_batch_responses for item in sublist]
185 |
186 | for data, responses in zip(all_data, all_response):
187 | answer_list = []
188 | for response in responses:
189 | answer_list.append(response)
190 | data['model_answer'] = answer_list
191 | if accelerator.is_main_process:
192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
193 | fp.flush()
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
198 | parser.add_argument("--input_path", type=str, help="path to the input data")
199 | parser.add_argument("--output_path", type=str, help="path to the output data")
200 | parser.add_argument("--score_path", type=str, help="path to the score")
201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
202 | parser.add_argument("--num_return", type=int, help="number of return sequences")
203 | parser.add_argument("--batch_size", type=int, help="batch size")
204 | args = parser.parse_args()
205 | generate_response(args)
206 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
207 |
208 |
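    | # Example invocation (illustrative only; all paths below are placeholders):
    | #   accelerate launch <this_script>.py --model_path <ckpt_dir> --input_path <test.json> \
    | #     --output_path <result.jsonl> --score_path <score.json> \
    | #     --wrong_item_path <wrong.json> --num_return 1 --batch_size 4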
--------------------------------------------------------------------------------
/src/evaluate/eval_llama70b.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | import json
8 | from tqdm import tqdm
9 | from torch.utils.data import DataLoader
10 | import torch.distributed as dist
11 | from collections import defaultdict
12 | from llama_index.llms.vllm import Vllm
13 |
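    | # Llama-2-70B served through the llama_index vLLM wrapper: temperature=0 gives
    | # greedy decoding, and tensor_parallel_size=8 shards the weights across 8 GPUs.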
14 | llm = Vllm(
15 | model='./Llama-2-70b-hf/',
16 | trust_remote_code=True,
17 | max_new_tokens=64,
18 | temperature=0,
19 | dtype="bfloat16",
20 | tensor_parallel_size=8,
21 | vllm_kwargs={"swap_space": 1},
22 | )
23 |
24 | def get_answer(data:str):
25 | response=llm.complete(
26 | data
27 | )
28 | return response.text
29 |
30 |
31 | def extract_and_choose_answer(pattern, model_answer):
32 | if '\n' in model_answer:
33 | model_answer_split = model_answer.split('\n')
34 | for model_answer_i in model_answer_split:
35 | if len(model_answer_i):
36 | model_answer = model_answer_i
37 | break
38 | matches = re.findall(pattern, model_answer)
39 | option_count = {}
40 | for match in matches:
41 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
42 |
43 | if not option_count:
44 | # else use loose pattern
45 | loose_pattern = r'[A-F]'
46 | if pattern == loose_pattern:
47 | if model_answer == 'Yes.':
48 | return 'A'
49 | elif model_answer == 'No.':
50 | return 'B'
51 | else:
52 | return None
53 | else:
54 | return extract_and_choose_answer(loose_pattern, model_answer)
55 |
56 | max_count = max(option_count.values())
57 | max_options = [option for option, count in option_count.items() if count == max_count]
58 | return max_options[0]
59 |
60 |
61 |
62 | def generate_score(result_path, score_path, wrong_item_path):
63 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
64 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
65 |
66 | all = defaultdict(int)
67 | right = defaultdict(int)
68 | accuracy_dict = defaultdict(int)
69 | wrong_item = []
70 |
71 | print(f'****Total:{len(json_objects)}****')
72 | debug = True
73 | for item in json_objects:
74 | source = item["source"]
75 | for answer in item["model_answer"]:
76 | all[source] += 1
77 |             pattern = r'[(\(]([A-Fa-f])[)\)]'  # match a parenthesised option letter, e.g. '(A)' (same pattern as the other eval scripts)
78 | extract_answer = extract_and_choose_answer(pattern, answer)
79 | item['extract_answer'] = extract_answer
80 | if debug:
81 | debug = False
82 | print(f'extract_answer:{extract_answer}')
83 | right_answer = item['answer']
84 | print(f'right_answer:{right_answer}')
85 | if item['answer'] == extract_answer:
86 | right[source] += 1
87 | else:
88 | wrong_item.append(item)
89 |
90 |
91 | print(f'all:{all}')
92 | print(f'right:{right}')
93 |
94 | for key in right:
95 | accuracy_dict[key] = right[key] / all[key]
96 |
97 | with open(score_path, "w", encoding="utf8") as f:
98 | json.dump(accuracy_dict, f, indent=4)
99 |
100 | print(f'***********score_result save in {score_path}*************')
101 |
102 | with open(wrong_item_path, "w", encoding="utf8") as f:
103 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
104 |
105 | print(f'***********wrong_item save in {wrong_item_path}*************')
106 |
107 |
108 | def generate_response(args):
109 |
110 |
111 | model_path = args.model_path
112 |
113 | fp = open(args.output_path,'w')
114 |
115 | with open(args.input_path) as f:
116 | data = json.load(f)
117 |
118 | for item in tqdm(data):
119 | question=item['question']
120 | answer=get_answer(question)
121 |         item['model_answer'] = [answer]  # wrap in a list so generate_score can iterate over the answers
122 | fp.write(json.dumps(item, ensure_ascii=False) +'\n')
123 | fp.flush()
124 |
125 |
126 | if __name__ == "__main__":
127 | parser = argparse.ArgumentParser()
128 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
129 | parser.add_argument("--input_path", type=str, help="path to the input data")
130 | parser.add_argument("--output_path", type=str, help="path to the output data")
131 | parser.add_argument("--score_path", type=str, help="path to the score")
132 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
133 | args = parser.parse_args()
134 | generate_response(args)
135 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
136 |
137 |
138 |
--------------------------------------------------------------------------------
/src/evaluate/eval_meditron.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
4 | import re
5 | import argparse
6 | from accelerate import Accelerator
7 | import json
8 | from tqdm import tqdm
9 | from torch.utils.data import DataLoader
10 | import torch.distributed as dist
11 | from collections import defaultdict
12 |
13 |
14 | class TestDataset(torch.utils.data.Dataset):
15 | def __init__(self, data_path,tokenizer):
16 | self.data = []
17 | with open(data_path) as f:
18 | self.data = json.load(f)
19 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
20 | if dist_flag_0:
21 | print(f'load {len(self.data)} data from {data_path}')
22 | self.tokenizer = tokenizer
23 | self.debug = True
24 |
25 | def __getitem__(self, index):
26 | item = self.data[index]
27 | return {
28 | 'data': item,
29 | 'input': item['question']
30 | }
31 |
32 | def __len__(self):
33 | return len(self.data)
34 |
35 | def collate_fn(self, batch):
36 | batch_query = [x['input'] for x in batch]
37 | batch_data = [x['data'] for x in batch]
38 | out_batch = {}
39 | out_batch['data'] = batch_data
40 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
41 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
42 | if self.debug and dist_flag_0:
43 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
44 | for idx, sample in enumerate(decoded_texts):
45 | print(f'*******************batch_texts[{idx}]**********************************')
46 | print(sample)
47 | self.debug = False
48 | return out_batch
49 |
50 |
51 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
52 | responses_list=[]
53 | batch_return=[]
54 | input_len = len(batch_input_ids[0])
55 | for idx, output_ids in enumerate(batch_output_ids):
56 | generated_ids = output_ids[input_len:]
57 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
58 | if idx % num_return == num_return-1:
59 | responses_list.append(batch_return)
60 | batch_return=[]
61 | return responses_list
62 |
63 | def extract_and_choose_answer(pattern, model_answer):
64 | if '\n' in model_answer:
65 | model_answer_split = model_answer.split('\n')
66 | for model_answer_i in model_answer_split:
67 | if len(model_answer_i):
68 | model_answer = model_answer_i
69 | break
70 | matches = re.findall(pattern, model_answer)
71 | option_count = {}
72 | for match in matches:
73 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
74 |
75 | if not option_count:
76 | # else use loose pattern
77 | loose_pattern = r'[A-F]'
78 | if pattern == loose_pattern:
79 | if model_answer == 'Yes.':
80 | return 'A'
81 | elif model_answer == 'No.':
82 | return 'B'
83 | else:
84 | return None
85 | else:
86 | return extract_and_choose_answer(loose_pattern, model_answer)
87 |
88 | max_count = max(option_count.values())
89 | max_options = [option for option, count in option_count.items() if count == max_count]
90 | return max_options[0]
91 |
92 |
93 |
94 | def generate_score(result_path, score_path, wrong_item_path):
95 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
96 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
97 |
98 | all = defaultdict(int)
99 | right = defaultdict(int)
100 | accuracy_dict = defaultdict(int)
101 | wrong_item = []
102 |
103 | print(f'****Total:{len(json_objects)}****')
104 | debug = True
105 | for item in json_objects:
106 | source = item["source"]
107 | for answer in item["model_answer"]:
108 | all[source] += 1
109 | pattern = r'[(\(]([A-Fa-f])[)\)]'
110 | extract_answer = extract_and_choose_answer(pattern, answer)
111 | item['extract_answer'] = extract_answer
112 | if debug:
113 | debug = False
114 | print(f'extract_answer:{extract_answer}')
115 | right_answer = item['answer']
116 | print(f'right_answer:{right_answer}')
117 | if item['answer'] == extract_answer:
118 | right[source] += 1
119 | else:
120 | wrong_item.append(item)
121 |
122 |
123 | print(f'all:{all}')
124 | print(f'right:{right}')
125 |
126 | for key in right:
127 | accuracy_dict[key] = right[key] / all[key]
128 |
129 | with open(score_path, "w", encoding="utf8") as f:
130 | json.dump(accuracy_dict, f, indent=4)
131 |
132 | print(f'***********score_result save in {score_path}*************')
133 |
134 | with open(wrong_item_path, "w", encoding="utf8") as f:
135 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
136 |
137 | print(f'***********wrong_item save in {wrong_item_path}*************')
138 |
139 |
140 | def generate_response(args):
141 | accelerator = Accelerator()
142 |
143 | model_path = args.model_path
144 | accelerator.print(f'****************model_path:{model_path}******************')
145 |
146 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', pad_token="</s>", eos_token="</s>")  # assumption: '</s>' used for both pad and eos (Meditron is LLaMA-2 based)
147 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
148 | model = model.half().cuda()
149 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
150 |
151 |
152 | dataset = TestDataset(args.input_path, tokenizer)
153 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
154 |
155 | model = model.eval()
156 | if dist.is_initialized():
157 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
158 |
159 | dataloader = accelerator.prepare(dataloader)
160 | accelerator.print(f'******************load_model from {model_path}******************')
161 |
162 | if accelerator.is_main_process:
163 | fp = open(args.output_path,'w')
164 |
165 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
166 | for batch in dataloader_iterator:
167 | batch_input_ids = batch["input_ids"]
168 | batch_data = batch["data"]
169 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
170 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
171 |
172 |
173 | if dist.is_initialized():
174 | all_batch_data = [None] * dist.get_world_size()
175 | all_batch_responses = [None] * dist.get_world_size()
176 | dist.all_gather_object(all_batch_responses, batch_responses)
177 | dist.all_gather_object(all_batch_data, batch_data)
178 | else:
179 | all_batch_data = [batch_data, ]
180 | all_batch_responses = [batch_responses, ]
181 |
182 | all_data = [item for sublist in all_batch_data for item in sublist]
183 | all_response = [item for sublist in all_batch_responses for item in sublist]
184 |
185 | for data, responses in zip(all_data, all_response):
186 | answer_list = []
187 | for response in responses:
188 | answer_list.append(response)
189 | data['model_answer'] = answer_list
190 | if accelerator.is_main_process:
191 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
192 | fp.flush()
193 |
194 | if __name__ == "__main__":
195 | parser = argparse.ArgumentParser()
196 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
197 | parser.add_argument("--input_path", type=str, help="path to the input data")
198 | parser.add_argument("--output_path", type=str, help="path to the output data")
199 | parser.add_argument("--score_path", type=str, help="path to the score")
200 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
201 | parser.add_argument("--num_return", type=int, help="number of return sequences")
202 | parser.add_argument("--batch_size", type=int, help="batch size")
203 | args = parser.parse_args()
204 | generate_response(args)
205 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
206 |
207 |
208 |
--------------------------------------------------------------------------------
/src/evaluate/eval_meditron70b.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
4 | import re
5 | import argparse
6 | import json
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader
9 | import torch.distributed as dist
10 | from collections import defaultdict
11 | from llama_index.llms.vllm import Vllm
12 |
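    | # Same vLLM setup as eval_llama70b.py, pointed at the local meditron-70b checkpoint.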
13 | llm = Vllm(
14 | model='./models/meditron-70b/',
15 | trust_remote_code=True,
16 | max_new_tokens=64,
17 | temperature=0,
18 | dtype="bfloat16",
19 | tensor_parallel_size=8,
20 | vllm_kwargs={"swap_space": 1},
21 | )
22 |
23 | def get_answer(data:str):
24 | response=llm.complete(
25 | data
26 | )
27 | return response.text
28 |
29 |
30 | def extract_and_choose_answer(pattern, model_answer):
31 | if '\n' in model_answer:
32 | model_answer_split = model_answer.split('\n')
33 | for model_answer_i in model_answer_split:
34 | if len(model_answer_i):
35 | model_answer = model_answer_i
36 | break
37 | matches = re.findall(pattern, model_answer)
38 | option_count = {}
39 | for match in matches:
40 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
41 |
42 | if not option_count:
43 | # else use loose pattern
44 | loose_pattern = r'[A-F]'
45 | if pattern == loose_pattern:
46 | if model_answer == 'Yes.':
47 | return 'A'
48 | elif model_answer == 'No.':
49 | return 'B'
50 | else:
51 | return None
52 | else:
53 | return extract_and_choose_answer(loose_pattern, model_answer)
54 |
55 | max_count = max(option_count.values())
56 | max_options = [option for option, count in option_count.items() if count == max_count]
57 | return max_options[0]
58 |
59 |
60 | def generate_score(result_path, score_path, wrong_item_path):
61 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
62 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
63 |
64 | all = defaultdict(int)
65 | right = defaultdict(int)
66 | accuracy_dict = defaultdict(int)
67 | wrong_item = []
68 |
69 | print(f'****Total:{len(json_objects)}****')
70 | debug = True
71 | for item in json_objects:
72 | source = item["source"]
73 | for answer in item["model_answer"]:
74 | all[source] += 1
75 |             pattern = r'[(\(]([A-Fa-f])[)\)]'  # match a parenthesised option letter, e.g. '(A)' (same pattern as the other eval scripts)
76 | extract_answer = extract_and_choose_answer(pattern, answer)
77 | item['extract_answer'] = extract_answer
78 | if debug:
79 | debug = False
80 | print(f'extract_answer:{extract_answer}')
81 | right_answer = item['answer']
82 | print(f'right_answer:{right_answer}')
83 | if item['answer'] == extract_answer:
84 | right[source] += 1
85 | else:
86 | wrong_item.append(item)
87 |
88 |
89 | print(f'all:{all}')
90 | print(f'right:{right}')
91 |
92 | for key in right:
93 | accuracy_dict[key] = right[key] / all[key]
94 |
95 | with open(score_path, "w", encoding="utf8") as f:
96 | json.dump(accuracy_dict, f, indent=4)
97 |
98 | print(f'***********score_result save in {score_path}*************')
99 |
100 | with open(wrong_item_path, "w", encoding="utf8") as f:
101 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
102 |
103 | print(f'***********wrong_item save in {wrong_item_path}*************')
104 |
105 |
106 | def generate_response(args):
107 |
108 |
109 | # model_path = args.model_path
110 |
111 | fp = open(args.output_path,'w')
112 |
113 | with open(args.input_path) as f:
114 | data = json.load(f)
115 |
116 | for item in tqdm(data):
117 | question=item['question']
118 | answer=get_answer(question)
119 |         item['model_answer'] = [answer]  # wrap in a list so generate_score can iterate over the answers
120 | fp.write(json.dumps(item, ensure_ascii=False) +'\n')
121 | fp.flush()
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = argparse.ArgumentParser()
126 | # parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
127 | parser.add_argument("--input_path", type=str, help="path to the input data")
128 | parser.add_argument("--output_path", type=str, help="path to the output data")
129 | parser.add_argument("--score_path", type=str, help="path to the score")
130 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
131 | args = parser.parse_args()
132 | generate_response(args)
133 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
134 |
135 |
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
/src/evaluate/eval_mistral.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path, wrong_item_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 | wrong_item = []
103 |
104 | print(f'****Total:{len(json_objects)}****')
105 | debug = True
106 | for item in json_objects:
107 | source = item["source"]
108 | for answer in item["model_answer"]:
109 | all[source] += 1
110 | pattern = r'[(\(]([A-Fa-f])[)\)]'
111 | extract_answer = extract_and_choose_answer(pattern, answer)
112 | item['extract_answer'] = extract_answer
113 | if debug:
114 | debug = False
115 | print(f'extract_answer:{extract_answer}')
116 | right_answer = item['answer']
117 | print(f'right_answer:{right_answer}')
118 | if item['answer'] == extract_answer:
119 | right[source] += 1
120 | else:
121 | wrong_item.append(item)
122 |
123 |
124 | print(f'all:{all}')
125 | print(f'right:{right}')
126 |
127 | for key in right:
128 | accuracy_dict[key] = right[key] / all[key]
129 |
130 | with open(score_path, "w", encoding="utf8") as f:
131 | json.dump(accuracy_dict, f, indent=4)
132 |
133 | print(f'***********score_result save in {score_path}*************')
134 |
135 | with open(wrong_item_path, "w", encoding="utf8") as f:
136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
137 |
138 | print(f'***********wrong_item save in {wrong_item_path}*************')
139 |
140 |
141 | def generate_response(args):
142 | accelerator = Accelerator()
143 |
144 | model_path = args.model_path
145 | accelerator.print(f'****************model_path:{model_path}******************')
146 |
147 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='</s>', pad_token='null')  # assumption: '</s>' is the Mistral eos token
148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
149 | model = model.half().cuda()
150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
151 |
152 |
153 | dataset = TestDataset(args.input_path, tokenizer)
154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
155 |
156 | model = model.eval()
157 | if dist.is_initialized():
158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
159 |
160 | dataloader = accelerator.prepare(dataloader)
161 | accelerator.print(f'******************load_model from {model_path}******************')
162 |
163 | if accelerator.is_main_process:
164 | fp = open(args.output_path,'w')
165 |
166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
167 | for batch in dataloader_iterator:
168 | batch_input_ids = batch["input_ids"]
169 | batch_data = batch["data"]
170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
172 |
173 |
174 | if dist.is_initialized():
175 | all_batch_data = [None] * dist.get_world_size()
176 | all_batch_responses = [None] * dist.get_world_size()
177 | dist.all_gather_object(all_batch_responses, batch_responses)
178 | dist.all_gather_object(all_batch_data, batch_data)
179 | else:
180 | all_batch_data = [batch_data, ]
181 | all_batch_responses = [batch_responses, ]
182 |
183 | all_data = [item for sublist in all_batch_data for item in sublist]
184 | all_response = [item for sublist in all_batch_responses for item in sublist]
185 |
186 | for data, responses in zip(all_data, all_response):
187 | answer_list = []
188 | for response in responses:
189 | answer_list.append(response)
190 | data['model_answer'] = answer_list
191 | if accelerator.is_main_process:
192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
193 | fp.flush()
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
198 | parser.add_argument("--input_path", type=str, help="path to the input data")
199 | parser.add_argument("--output_path", type=str, help="path to the output data")
200 | parser.add_argument("--score_path", type=str, help="path to the score")
201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
202 | parser.add_argument("--num_return", type=int, help="number of return sequences")
203 | parser.add_argument("--batch_size", type=int, help="batch size")
204 | args = parser.parse_args()
205 | generate_response(args)
206 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
207 |
208 |
209 |
--------------------------------------------------------------------------------
/src/evaluate/eval_mmedlm2.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path, wrong_item_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 | wrong_item = []
103 |
104 | print(f'****Total:{len(json_objects)}****')
105 | debug = True
106 | for item in json_objects:
107 | source = item["source"]
108 | for answer in item["model_answer"]:
109 | all[source] += 1
110 | pattern = r'[(\(]([A-Fa-f])[)\)]'
111 | extract_answer = extract_and_choose_answer(pattern, answer)
112 | item['extract_answer'] = extract_answer
113 | if debug:
114 | debug = False
115 | print(f'extract_answer:{extract_answer}')
116 | right_answer = item['answer']
117 | print(f'right_answer:{right_answer}')
118 | if item['answer'] == extract_answer:
119 | right[source] += 1
120 | else:
121 | wrong_item.append(item)
122 |
123 |
124 | print(f'all:{all}')
125 | print(f'right:{right}')
126 |
127 | for key in right:
128 | accuracy_dict[key] = right[key] / all[key]
129 |
130 | with open(score_path, "w", encoding="utf8") as f:
131 | json.dump(accuracy_dict, f, indent=4)
132 |
133 | print(f'***********score_result save in {score_path}*************')
134 |
135 | with open(wrong_item_path, "w", encoding="utf8") as f:
136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
137 |
138 | print(f'***********wrong_item save in {wrong_item_path}*************')
139 |
140 |
141 | def generate_response(args):
142 | accelerator = Accelerator()
143 |
144 | model_path = args.model_path
145 | accelerator.print(f'****************model_path:{model_path}******************')
146 |
147 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='</s>', pad_token='</s>', trust_remote_code=True)  # assumption: '</s>' used for both eos and pad
148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
149 | model = model.half().cuda()
150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
151 |
152 |
153 | dataset = TestDataset(args.input_path, tokenizer)
154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
155 |
156 | model = model.eval()
157 | if dist.is_initialized():
158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
159 |
160 | dataloader = accelerator.prepare(dataloader)
161 | accelerator.print(f'******************load_model from {model_path}******************')
162 |
163 | if accelerator.is_main_process:
164 | fp = open(args.output_path,'w')
165 |
166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
167 | for batch in dataloader_iterator:
168 | batch_input_ids = batch["input_ids"]
169 | batch_data = batch["data"]
170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
172 |
173 |
174 | if dist.is_initialized():
175 | all_batch_data = [None] * dist.get_world_size()
176 | all_batch_responses = [None] * dist.get_world_size()
177 | dist.all_gather_object(all_batch_responses, batch_responses)
178 | dist.all_gather_object(all_batch_data, batch_data)
179 | else:
180 | all_batch_data = [batch_data, ]
181 | all_batch_responses = [batch_responses, ]
182 |
183 | all_data = [item for sublist in all_batch_data for item in sublist]
184 | all_response = [item for sublist in all_batch_responses for item in sublist]
185 |
186 | for data, responses in zip(all_data, all_response):
187 | answer_list = []
188 | for response in responses:
189 | answer_list.append(response)
190 | data['model_answer'] = answer_list
191 | if accelerator.is_main_process:
192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
193 | fp.flush()
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
198 | parser.add_argument("--input_path", type=str, help="path to the input data")
199 | parser.add_argument("--output_path", type=str, help="path to the output data")
200 | parser.add_argument("--score_path", type=str, help="path to the score")
201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
202 | parser.add_argument("--num_return", type=int, help="number of return sequences")
203 | parser.add_argument("--batch_size", type=int, help="batch size")
204 | args = parser.parse_args()
205 | generate_response(args)
206 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
207 |
208 |
209 |
--------------------------------------------------------------------------------
/src/evaluate/eval_qwen.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path, wrong_item_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 |
103 | print(f'****Total:{len(json_objects)}****')
104 | debug = True
105 | for item in json_objects:
106 | source = item["source"]
107 | for answer in item["model_answer"]:
108 | all[source] += 1
109 | pattern = r'[(\(]([A-Fa-f])[)\)]'
110 | extract_answer = extract_and_choose_answer(pattern, answer)
111 | if debug:
112 | debug = False
113 | print(f'extract_answer:{extract_answer}')
114 | right_answer = item['answer']
115 | print(f'right_answer:{right_answer}')
116 | if item['answer'] == extract_answer:
117 | right[source] += 1
118 |
119 |
120 | print(f'all:{all}')
121 | print(f'right:{right}')
122 |
123 | for key in right:
124 | accuracy_dict[key] = right[key] / all[key]
125 |
126 | with open(score_path, "w", encoding="utf8") as f:
127 | json.dump(accuracy_dict, f, indent=4)
128 |
129 | print(f'***********score_result save in {score_path}*************')
130 |
131 |
132 | def generate_response(args):
133 | accelerator = Accelerator()
134 |
135 | model_path = args.model_path
136 | accelerator.print(f'****************model_path:{model_path}******************')
137 |
138 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left', pad_token='<|extra_0|>', eos_token='<|endoftext|>')
139 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, ignore_mismatched_sizes=True).half()
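    |     # Greedy decoding: with do_sample=False the temperature/top_k/top_p settings below
    |     # have no effect; max_new_tokens caps the length of each generated answer.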
140 | generation_config = GenerationConfig.from_pretrained(model_path, pad_token_id=tokenizer.pad_token_id, num_return_sequences=args.num_return, max_new_tokens=256, min_new_tokens=2, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
141 |
142 | dataset = TestDataset(args.input_path, tokenizer)
143 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
144 |
145 | model = model.eval()
146 | if dist.is_initialized():
147 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
148 |
149 | dataloader = accelerator.prepare(dataloader)
150 | accelerator.print(f'******************load_model from {model_path}******************')
151 |
152 | if accelerator.is_main_process:
153 | fp = open(args.output_path,'w')
154 |
155 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
156 | for batch in dataloader_iterator:
157 | batch_input_ids = batch["input_ids"]
158 | batch_data = batch["data"]
159 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, generation_config=generation_config)
160 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
161 |
162 |
163 | if dist.is_initialized():
164 | all_batch_data = [None] * dist.get_world_size()
165 | all_batch_responses = [None] * dist.get_world_size()
166 | dist.all_gather_object(all_batch_responses, batch_responses)
167 | dist.all_gather_object(all_batch_data, batch_data)
168 | else:
169 | all_batch_data = [batch_data, ]
170 | all_batch_responses = [batch_responses, ]
171 |
172 | all_data = [item for sublist in all_batch_data for item in sublist]
173 | all_response = [item for sublist in all_batch_responses for item in sublist]
174 |
175 | for data, responses in zip(all_data, all_response):
176 | answer_list = []
177 | for response in responses:
178 | answer_list.append(response)
179 | data['model_answer'] = answer_list
180 | if accelerator.is_main_process:
181 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
182 | fp.flush()
183 |
184 | if __name__ == "__main__":
185 | parser = argparse.ArgumentParser()
186 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
187 | parser.add_argument("--input_path", type=str, help="path to the input data")
188 | parser.add_argument("--output_path", type=str, help="path to the output data")
189 | parser.add_argument("--score_path", type=str, help="path to the score")
190 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
191 | parser.add_argument("--num_return", type=int, help="number of return sequences")
192 | parser.add_argument("--batch_size", type=int, help="batch size")
193 | args = parser.parse_args()
194 | generate_response(args)
195 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
196 |
--------------------------------------------------------------------------------
/src/evaluate/eval_yi.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 | matches = re.findall(pattern, model_answer)
72 | option_count = {}
73 | for match in matches:
74 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
75 |
76 | if not option_count:
77 | # else use loose pattern
78 | loose_pattern = r'[A-F]'
79 | if pattern == loose_pattern:
80 | if model_answer == 'Yes.':
81 | return 'A'
82 | elif model_answer == 'No.':
83 | return 'B'
84 | else:
85 | return None
86 | else:
87 | return extract_and_choose_answer(loose_pattern, model_answer)
88 |
89 | max_count = max(option_count.values())
90 | max_options = [option for option, count in option_count.items() if count == max_count]
91 | return max_options[0]
92 |
93 |
94 |
95 | def generate_score(result_path, score_path, wrong_item_path):
96 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
97 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
98 |
99 | all = defaultdict(int)
100 | right = defaultdict(int)
101 | accuracy_dict = defaultdict(int)
102 | wrong_item = []
103 |
104 | print(f'****Total:{len(json_objects)}****')
105 | debug = True
106 | for item in json_objects:
107 | source = item["source"]
108 | for answer in item["model_answer"]:
109 | all[source] += 1
110 | pattern = r'[(\(]([A-Fa-f])[)\)]'
111 | extract_answer = extract_and_choose_answer(pattern, answer)
112 | item['extract_answer'] = extract_answer
113 | if debug:
114 | debug = False
115 | print(f'extract_answer:{extract_answer}')
116 | right_answer = item['answer']
117 | print(f'right_answer:{right_answer}')
118 | if item['answer'] == extract_answer:
119 | right[source] += 1
120 | else:
121 | wrong_item.append(item)
122 |
123 |
124 | print(f'all:{all}')
125 | print(f'right:{right}')
126 |
127 | for key in right:
128 | accuracy_dict[key] = right[key] / all[key]
129 |
130 | with open(score_path, "w", encoding="utf8") as f:
131 | json.dump(accuracy_dict, f, indent=4)
132 |
133 | print(f'***********score_result save in {score_path}*************')
134 |
135 | with open(wrong_item_path, "w", encoding="utf8") as f:
136 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
137 |
138 | print(f'***********wrong_item save in {wrong_item_path}*************')
139 |
140 |
141 | def generate_response(args):
142 | accelerator = Accelerator()
143 |
144 | model_path = args.model_path
145 | accelerator.print(f'****************model_path:{model_path}******************')
146 |
147 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='<|endoftext|>')
148 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
149 | model = model.half().cuda()
150 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
151 |
152 |
153 | dataset = TestDataset(args.input_path, tokenizer)
154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
155 |
156 | model = model.eval()
157 | if dist.is_initialized():
158 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
159 |
160 | dataloader = accelerator.prepare(dataloader)
161 | accelerator.print(f'******************load_model from {model_path}******************')
162 |
163 | if accelerator.is_main_process:
164 | fp = open(args.output_path,'w')
165 |
166 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
167 | for batch in dataloader_iterator:
168 | batch_input_ids = batch["input_ids"]
169 | batch_data = batch["data"]
170 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
171 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
172 |
173 |
174 | if dist.is_initialized():
175 | all_batch_data = [None] * dist.get_world_size()
176 | all_batch_responses = [None] * dist.get_world_size()
177 | dist.all_gather_object(all_batch_responses, batch_responses)
178 | dist.all_gather_object(all_batch_data, batch_data)
179 | else:
180 | all_batch_data = [batch_data, ]
181 | all_batch_responses = [batch_responses, ]
182 |
183 | all_data = [item for sublist in all_batch_data for item in sublist]
184 | all_response = [item for sublist in all_batch_responses for item in sublist]
185 |
186 | for data, responses in zip(all_data, all_response):
187 | answer_list = []
188 | for response in responses:
189 | answer_list.append(response)
190 | data['model_answer'] = answer_list
191 | if accelerator.is_main_process:
192 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
193 | fp.flush()
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
198 | parser.add_argument("--input_path", type=str, help="path to the input data")
199 | parser.add_argument("--output_path", type=str, help="path to the output data")
200 | parser.add_argument("--score_path", type=str, help="path to the score")
201 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
202 | parser.add_argument("--num_return", type=int, help="number of return sequences")
203 | parser.add_argument("--batch_size", type=int, help="batch size")
204 | args = parser.parse_args()
205 | generate_response(args)
206 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
207 |
208 |
209 |
--------------------------------------------------------------------------------
/src/evaluate/eval_zephyr.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, data_path,tokenizer):
17 | self.data = []
18 | with open(data_path) as f:
19 | self.data = json.load(f)
20 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
21 | if dist_flag_0:
22 | print(f'load {len(self.data)} data from {data_path}')
23 | self.tokenizer = tokenizer
24 | self.debug = True
25 |
26 | def __getitem__(self, index):
27 | item = self.data[index]
28 | return {
29 | 'data': item,
30 | 'input': item['question']
31 | }
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def collate_fn(self, batch):
37 | batch_query = [x['input'] for x in batch]
38 | batch_data = [x['data'] for x in batch]
39 | out_batch = {}
40 | out_batch['data'] = batch_data
41 | out_batch['input_ids'] = self.tokenizer(batch_query, return_tensors='pt', padding=True)['input_ids']
42 | dist_flag_0 = True if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) else False
43 | if self.debug and dist_flag_0:
44 | decoded_texts = self.tokenizer.batch_decode(out_batch['input_ids'], skip_special_tokens=False)
45 | for idx, sample in enumerate(decoded_texts):
46 | print(f'*******************batch_texts[{idx}]**********************************')
47 | print(sample)
48 | self.debug = False
49 | return out_batch
50 |
51 |
52 | def get_response(batch_input_ids, batch_output_ids, tokenizer, num_return):
53 | responses_list=[]
54 | batch_return=[]
55 | input_len = len(batch_input_ids[0])
56 | for idx, output_ids in enumerate(batch_output_ids):
57 | generated_ids = output_ids[input_len:]
58 | batch_return.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
59 | if idx % num_return == num_return-1:
60 | responses_list.append(batch_return)
61 | batch_return=[]
62 | return responses_list
63 |
64 | def extract_and_choose_answer(pattern, model_answer):
65 | if '\n' in model_answer:
66 | model_answer_split = model_answer.split('\n')
67 | for model_answer_i in model_answer_split:
68 | if len(model_answer_i):
69 | model_answer = model_answer_i
70 | break
71 |
72 | matches = re.findall(pattern, model_answer)
73 | option_count = {}
74 | for match in matches:
75 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
76 |
77 | if not option_count:
78 | # else use loose pattern
79 | loose_pattern = r'[A-F]'
80 | if pattern == loose_pattern:
81 | if model_answer == 'Yes.':
82 | return 'A'
83 | elif model_answer == 'No.':
84 | return 'B'
85 | else:
86 | return None
87 | else:
88 | return extract_and_choose_answer(loose_pattern, model_answer)
89 |
90 | max_count = max(option_count.values())
91 | max_options = [option for option, count in option_count.items() if count == max_count]
92 | return max_options[0]
93 |
94 |
95 |
96 | def generate_score(result_path, score_path, wrong_item_path):
97 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
98 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
99 |
100 | all = defaultdict(int)
101 | right = defaultdict(int)
102 | accuracy_dict = defaultdict(int)
103 | wrong_item = []
104 |
105 | print(f'****Total:{len(json_objects)}****')
106 | debug = True
107 | for item in json_objects:
108 | source = item["source"]
109 | for answer in item["model_answer"]:
110 | all[source] += 1
111 | pattern = r'[(\(]([A-Fa-f])[)\)]'
112 | extract_answer = extract_and_choose_answer(pattern, answer)
113 | item['extract_answer'] = extract_answer
114 | if debug:
115 | debug = False
116 | print(f'extract_answer:{extract_answer}')
117 | right_answer = item['answer']
118 | print(f'right_answer:{right_answer}')
119 | if item['answer'] == extract_answer:
120 | right[source] += 1
121 | else:
122 | wrong_item.append(item)
123 |
124 |
125 | print(f'all:{all}')
126 | print(f'right:{right}')
127 |
128 | for key in right:
129 | accuracy_dict[key] = right[key] / all[key]
130 |
131 | with open(score_path, "w", encoding="utf8") as f:
132 | json.dump(accuracy_dict, f, indent=4)
133 |
134 | print(f'***********score_result save in {score_path}*************')
135 |
136 | with open(wrong_item_path, "w", encoding="utf8") as f:
137 | json.dump(wrong_item, f, indent=4, ensure_ascii=False)
138 |
139 | print(f'***********wrong_item save in {wrong_item_path}*************')
140 |
141 |
142 | def generate_response(args):
143 | accelerator = Accelerator()
144 |
145 | model_path = args.model_path
146 | accelerator.print(f'****************model_path:{model_path}******************')
147 |
148 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left', eos_token='</s>', pad_token='</s>')  # eos/pad assumed to be '</s>' for Zephyr
149 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
150 | model = model.half().cuda()
151 | gen_kwargs = {'num_return_sequences': args.num_return, 'max_new_tokens': 128, 'min_new_tokens':2, 'do_sample':False}
152 |
153 |
154 | dataset = TestDataset(args.input_path, tokenizer)
155 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn)
156 |
157 | model = model.eval()
158 | if dist.is_initialized():
159 | accelerator.print(f'****************dist.get_world_size():{dist.get_world_size()}******************')
160 |
161 | dataloader = accelerator.prepare(dataloader)
162 | accelerator.print(f'******************load_model from {model_path}******************')
163 |
164 | if accelerator.is_main_process:
165 | fp = open(args.output_path,'w')
166 |
167 | dataloader_iterator = tqdm(dataloader, total=len(dataloader)) if accelerator.is_main_process else dataloader
168 | for batch in dataloader_iterator:
169 | batch_input_ids = batch["input_ids"]
170 | batch_data = batch["data"]
171 | batch_output_ids = accelerator.unwrap_model(model).generate(batch_input_ids, **gen_kwargs)
172 | batch_responses = get_response(batch_input_ids, batch_output_ids, tokenizer, args.num_return)
173 |
174 |
175 | if dist.is_initialized():
176 | all_batch_data = [None] * dist.get_world_size()
177 | all_batch_responses = [None] * dist.get_world_size()
178 | dist.all_gather_object(all_batch_responses, batch_responses)
179 | dist.all_gather_object(all_batch_data, batch_data)
180 | else:
181 | all_batch_data = [batch_data, ]
182 | all_batch_responses = [batch_responses, ]
183 |
184 | all_data = [item for sublist in all_batch_data for item in sublist]
185 | all_response = [item for sublist in all_batch_responses for item in sublist]
186 |
187 | for data, responses in zip(all_data, all_response):
188 | answer_list = []
189 | for response in responses:
190 | answer_list.append(response)
191 | data['model_answer'] = answer_list
192 | if accelerator.is_main_process:
193 | fp.write(json.dumps(data, ensure_ascii=False) +'\n')
194 | fp.flush()
195 |
196 | if __name__ == "__main__":
197 | parser = argparse.ArgumentParser()
198 | parser.add_argument("--model_path", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
199 | parser.add_argument("--input_path", type=str, help="path to the input data")
200 | parser.add_argument("--output_path", type=str, help="path to the output data")
201 | parser.add_argument("--score_path", type=str, help="path to the score")
202 | parser.add_argument("--wrong_item_path", type=str, help="path to the wrong_item")
203 | parser.add_argument("--num_return", type=int, help="number of return sequences")
204 | parser.add_argument("--batch_size", type=int, help="batch size")
205 | args = parser.parse_args()
206 | generate_response(args)
207 | generate_score(args.output_path, args.score_path, args.wrong_item_path)
208 |
209 |
210 |
--------------------------------------------------------------------------------
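A minimal, self-contained sketch of the answer-extraction heuristic used above: try the strict parenthesised-letter pattern first, fall back to bare capital letters, then majority-vote over the matches. The sample strings are invented, and the first-line split and the Yes./No. fallback of extract_and_choose_answer are omitted for brevity.

import re
from collections import Counter

strict = r'[(\(]([A-Fa-f])[)\)]'  # option letter wrapped in parentheses, e.g. "(C)"
loose = r'[A-F]'                  # fallback: any bare capital option letter

for text in ["The correct answer is (C).", "the answer is B.", "no option mentioned"]:
    hits = re.findall(strict, text) or re.findall(loose, text)
    # majority vote over the matched letters
    choice = Counter(h.upper() for h in hits).most_common(1)[0][0] if hits else None
    print(f'{text!r} -> {choice}')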
/src/evaluate/generate_score.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5 | import re
6 | import argparse
7 | from accelerate import Accelerator
8 | import json
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from collections import defaultdict
13 |
14 |
15 | def extract_and_choose_answer(pattern, model_answer):
16 | if '\n' in model_answer:
17 | model_answer_split = model_answer.split('\n')
18 | for model_answer_i in model_answer_split:
19 | if len(model_answer_i):
20 | model_answer = model_answer_i
21 | break
22 |
23 | matches = re.findall(pattern, model_answer)
24 | option_count = {}
25 | for match in matches:
26 | option_count[match.upper()] = option_count.get(match.upper(), 0) + 1
27 |
28 | if not option_count:
29 | # else use loose pattern
30 | loose_pattern = r'[A-F]'
31 | if pattern == loose_pattern:
32 | if model_answer == 'Yes.':
33 | return 'A'
34 | elif model_answer == 'No.':
35 | return 'B'
36 | else:
37 | return None
38 | else:
39 | return extract_and_choose_answer(loose_pattern, model_answer)
40 |
41 | max_count = max(option_count.values())
42 | max_options = [option for option, count in option_count.items() if count == max_count]
43 | return max_options[0]
44 |
45 |
46 |
47 | def generate_score(result_path, score_path):
48 | with open(result_path, 'r', encoding='utf-8') as jsonl_file:
49 | json_objects = [json.loads(line.strip()) for line in jsonl_file]
50 |
51 | all = defaultdict(int)
52 | right = defaultdict(int)
53 | accuracy_dict = defaultdict(int)
54 |
55 | print(f'****Total:{len(json_objects)}****')
56 | debug = True
57 | for item in json_objects:
58 | source = item["source"]
59 | answer = item["model_answer"]
60 | all[source] += 1
61 | pattern = r'[(\(]([A-Fa-f])[)\)]'
62 | extract_answer = extract_and_choose_answer(pattern, answer)
63 | if debug:
64 | debug = False
65 | print(f'extract_answer:{extract_answer}')
66 | print(answer)
67 | right_answer = item['answer']
68 | print(f'right_answer:{right_answer}')
69 | if item['answer'] == extract_answer:
70 | right[source] += 1
71 |
72 |
73 | print(f'all:{all}')
74 | print(f'right:{right}')
75 |
76 | for key in right:
77 | accuracy_dict[key] = right[key] / all[key]
78 |
79 | with open(score_path, "w", encoding="utf8") as f:
80 | json.dump(accuracy_dict, f, indent=4)
81 |
82 | print(f'***********score_result save in {score_path}*************')
83 |
84 |
85 | if __name__ == "__main__":
86 | parser = argparse.ArgumentParser()
87 | parser.add_argument("--output_path", '-a', type=str, help="path to the output data")
88 | parser.add_argument("--score_path", '-o', type=str, help="path to the score")
89 | args = parser.parse_args()
90 | generate_score(args.output_path, args.score_path)
91 | # merge_result()
92 |
--------------------------------------------------------------------------------
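The scorer above expects one JSON object per line with source, answer and model_answer fields and reports per-source accuracy. A toy sketch with invented rows; a simple substring check stands in for extract_and_choose_answer.

from collections import defaultdict

rows = [
    {"source": "medqa-usmle", "answer": "A", "model_answer": "The correct answer is (A)."},
    {"source": "medqa-usmle", "answer": "C", "model_answer": "(B) fits best."},
    {"source": "headqa",      "answer": "D", "model_answer": "(D)."},
]

all_cnt, right_cnt = defaultdict(int), defaultdict(int)
for row in rows:
    all_cnt[row["source"]] += 1
    # a real run would call extract_and_choose_answer here
    if f'({row["answer"]})' in row["model_answer"]:
        right_cnt[row["source"]] += 1

print({k: right_cnt[k] / all_cnt[k] for k in all_cnt})  # {'medqa-usmle': 0.5, 'headqa': 1.0}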
/src/process/openai_rewrite/OpenAIGPT.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import openai
3 | from retrying import retry
4 | import random
5 |
6 | class OpenAIGPT:
7 | def __init__(self, model_name="gpt-3.5-turbo", keys_path=None):
8 | self.model_name = model_name
9 | with open(keys_path, encoding="utf-8", mode="r") as fr:
10 | self.keys = [line.strip() for line in fr if len(line.strip()) >= 4]
11 |
12 | def __post_process(self, response):
13 | return response["choices"][0]["message"]["content"]
14 |
15 | @retry(wait_fixed=200, stop_max_attempt_number=50)
16 | def __call__(self, message):
17 | if message is None or message == "":
18 | return False, "Your input is empty."
19 |
20 | current_key = random.choice(self.keys)
21 | openai.api_key = current_key
22 | # openai.organization = "org-DvpMM9lXmpMxnaygs1ixEvZw"
23 | response = openai.ChatCompletion.create(
24 | model=self.model_name,
25 | messages=[{"role": "user", "content": message}],
26 | temperature=0.6,
27 | top_p=0.8,
28 | frequency_penalty=0.6,
29 | presence_penalty=0.8,
30 | n=1,
31 | )
32 | return self.__post_process(response)
33 |
34 |
35 | if __name__ == "__main__":
36 | # test code
37 | igpt = OpenAIGPT(keys_path="gpt4key.txt", model_name="gpt-4")
38 | string = "Hello"
39 | print(string)
40 | answer = igpt(string)
41 | print(answer)
42 |
--------------------------------------------------------------------------------
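The @retry decorator above comes from the retrying package: the wrapped call is re-attempted up to stop_max_attempt_number times, waiting wait_fixed milliseconds between tries, whenever it raises. A toy illustration with an invented flaky function:

from retrying import retry

attempts = {"n": 0}

@retry(wait_fixed=200, stop_max_attempt_number=50)
def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return f"succeeded on attempt {attempts['n']}"

print(flaky_call())  # succeeded on attempt 3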
/src/process/openai_rewrite/OpenAIGPT_datagen_multithread.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | from concurrent.futures import ThreadPoolExecutor
3 | from tqdm import tqdm
4 | import os
5 | import argparse
6 | from OpenAIGPT import OpenAIGPT
7 |
8 |
9 | def OpenAIGPT_datagen(args):
10 | igpt = OpenAIGPT(model_name=args.model_name, keys_path=args.keys_path)
11 |
12 | def process_item(item):
13 | content = igpt(item["query"])
14 | item["model_answer"] = content
15 | return item
16 |
17 | output_path = args.output_path
18 | input_path = args.input_path
19 |
20 | # Collect the IDs of processed items in the output file
21 | processed_ids = set()
22 | if os.path.exists(output_path):
23 | with jsonlines.open(output_path, "r") as f:
24 | for item in f:
25 | processed_ids.add(item.get("id", None))
26 |
27 | # Collect unprocessed items
28 | items_to_process = []
29 |
30 | with jsonlines.open(input_path, "r") as reader:
31 | for item in reader:
32 | item_id = item.get("id", None)
33 | if item_id is not None and item_id in processed_ids:
34 | continue
35 | items_to_process.append(item)
36 |
37 | # Multi-threaded parallel processing
38 | with jsonlines.open(
39 | output_path, "a" if os.path.exists(output_path) else "w"
40 | ) as writer:
41 | with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
42 | futures = {
43 | executor.submit(process_item, item): item for item in items_to_process
44 | }
45 |
46 | # Use tqdm to display progress
47 | for future in tqdm(
48 | futures, total=len(items_to_process), desc="Processing items"
49 | ):
50 | item = future.result()
51 | writer.write(item)
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser(description="Process JSONL files concurrently.")
56 | parser.add_argument(
57 | "--model_name",
58 | type=str,
59 | default="gpt-3.5-turbo",
60 | help="Name of the OpenAIGPT model to use.",
61 | )
62 | parser.add_argument(
63 | "--keys_path",
64 | type=str,
65 | required=True,
66 | help="Path to a text file of OpenAI API keys, one key per line.",
67 | )
68 | parser.add_argument(
69 | "--input_path", type=str, required=True, help="Path to the input JSONL file."
70 | )
71 | parser.add_argument(
72 | "--output_path", type=str, required=True, help="Path to the output JSONL file."
73 | )
74 | parser.add_argument(
75 | "--max_workers",
76 | type=int,
77 | default=10,
78 | help="Maximum number of workers for concurrent processing.",
79 | )
80 |
81 | args = parser.parse_args()
82 | OpenAIGPT_datagen(args)
83 |
--------------------------------------------------------------------------------
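The generator above is resumable: ids already written to the output JSONL are skipped on the next run. A stripped-down sketch of that bookkeeping (the function name and file handling below are illustrative, not part of the script):

import json
import os

def pending_items(input_path, output_path):
    """Return the input rows whose id is not yet present in the output file."""
    done = set()
    if os.path.exists(output_path):
        with open(output_path, encoding="utf-8") as f:
            done = {json.loads(line).get("id") for line in f if line.strip()}
    with open(input_path, encoding="utf-8") as f:
        items = [json.loads(line) for line in f if line.strip()]
    return [item for item in items if item.get("id") not in done]

# e.g. pending_items("./data/2.dev_prepared.jsonl", "./data/3.dev_aftgpt.jsonl")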
/src/process/openai_rewrite/gpt_key.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreedomIntelligence/Apollo/505c48e82207b36cf3921df61fcd58895fb6ca4e/src/process/openai_rewrite/gpt_key.txt
--------------------------------------------------------------------------------
/src/process/openai_rewrite/guidelines_en/1.2.prepare_data.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import argparse
4 |
5 | ans_prompt = """You are Medbase, equipped with in-depth knowledge in medicine. Your task is to directly answer the user's <question> in English. In formulating your response, you must thoughtfully reference the <reference>, ensuring that your reply does not disclose your reliance on <reference>. Aim to provide a comprehensive and informative response, incorporating relevant insights from <reference> to best assist the user. Please be cautious to avoid including any content that might raise ethical concerns.
6 |
7 | <question>: {question}
8 |
9 | <reference>: {reference}
10 |
11 | <answer>: """
12 |
13 |
14 | def generate_query(data):
15 | chatgpt_query = ans_prompt.format_map(
16 | {"question": data["model_answer"], "reference": data["reference"]}
17 | )
18 | return chatgpt_query
19 |
20 |
21 | def Prepare_data(args):
22 | data = []
23 | # Read the input JSONL file
24 | with jsonlines.open(args.input_path, "r") as reader:
25 | data = list(reader)
26 |
27 | print(f"len:{len(data)}")
28 | # Convert as required
29 | jsonl_data = []
30 |
31 | for id, item in enumerate(data):
32 | jsonl_data.append(
33 | {
34 | "id": id,
35 | "query": generate_query(item),
36 | "model_answer": "",
37 | "model_question": item["model_answer"],
38 | "reference": item["reference"],
39 | }
40 | )
41 |
42 | # Save the converted data as a JSONL file
43 | with open(args.output_path, "w", encoding="utf-8") as file:
44 | for entry in jsonl_data:
45 | file.write(json.dumps(entry, ensure_ascii=False) + "\n")
46 |
47 | print(f"Prepare finished, output to '{args.output_path}'")
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser(
52 | description="Prepare data for OpenAIGPT generation"
53 | )
54 | parser.add_argument(
55 | "--input_path", type=str, required=True, help="Path to the input JSON file."
56 | )
57 | parser.add_argument(
58 | "--output_path", type=str, required=True, help="Path to the output JSONL file."
59 | )
60 | args = parser.parse_args()
61 | Prepare_data(args)
62 |
--------------------------------------------------------------------------------
/src/process/openai_rewrite/guidelines_en/1.prepare_data.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import argparse
4 |
5 | query_prompt = """Please create a <question> that closely aligns with the provided <text>. Ensure that the <question> is formulated in English and does not explicitly reference the <text>. You may incorporate specific scenarios or contexts in the <question>, allowing the <text> to serve as a comprehensive and precise answer.
6 |
7 | <text>: {text}
8 |
9 | <question>: """
10 |
11 |
12 | def generate_query(data):
13 | chatgpt_query = query_prompt.format_map({"text": data[0]})
14 | return chatgpt_query
15 |
16 |
17 | def Prepare_data(args):
18 | data = []
19 | # Read the input JSONL file
20 | with jsonlines.open(args.input_path, "r") as reader:
21 | data = list(reader)
22 |
23 | print(f"len:{len(data)}")
24 | # Convert as required
25 | jsonl_data = []
26 |
27 | for id, item in enumerate(data):
28 | jsonl_data.append(
29 | {
30 | "id": id,
31 | "query": generate_query(item),
32 | "model_answer": "",
33 | "reference": item[0],
34 | }
35 | )
36 |
37 | # Save the converted data as a JSONL file
38 | with open(args.output_path, "w", encoding="utf-8") as file:
39 | for entry in jsonl_data:
40 | file.write(json.dumps(entry, ensure_ascii=False) + "\n")
41 |
42 | print(f"Prepare finished, output to '{args.output_path}'")
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser(
47 | description="Prepare data for OpenAIGPT generation"
48 | )
49 | parser.add_argument(
50 | "--input_path", type=str, required=True, help="Path to the input JSON file."
51 | )
52 | parser.add_argument(
53 | "--output_path", type=str, required=True, help="Path to the output JSONL file."
54 | )
55 | args = parser.parse_args()
56 | Prepare_data(args)
57 |
--------------------------------------------------------------------------------
/src/process/openai_rewrite/guidelines_en/1.run_prepare_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python 1.prepare_data.py --input_path ./data/1.dev.jsonl --output_path ./data/2.dev_prepared.jsonl
4 | python 1.2.prepare_data.py --input_path ./data/3.dev_aftgpt.jsonl --output_path ./data/2.dev_aftgpt_prepared.jsonl
--------------------------------------------------------------------------------
/src/process/openai_rewrite/guidelines_en/2.run_gpt_datagen_multithread.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_prepared.jsonl --output_path ./data/3.dev_aftgpt.jsonl --max_workers 300
4 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_aftgpt_prepared.jsonl --output_path ./data/3.dev_aftgpt_prepared_aftgpt.jsonl --max_workers 300
--------------------------------------------------------------------------------
/src/process/openai_rewrite/guidelines_en/3.extract.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def extract_and_save(input_jsonl, output_json, train_json):
5 | # Create an empty list to store the extracted data
6 | extracted_data = []
7 | train_data = []
8 | count = 0
9 | # Read the JSONL file line by line and extract the specified fields
10 | with open(input_jsonl, "r", encoding="utf-8") as file:
11 | for line in file:
12 | json_data = json.loads(line)
13 | if len(json_data.get("model_answer", "")) < 32:
14 | continue
15 |
16 | #Extract specified fields
17 | extracted_item = {
18 | "question": json_data.get("model_question", ""),
19 | "answer": json_data.get("model_answer", ""),
20 | "reference": json_data.get("reference", ""),
21 | "id": json_data.get("id", ""),
22 | }
23 | list_item = [
24 | json_data.get("model_question", ""),
25 | json_data.get("model_answer", ""),
26 | ]
27 |
28 | # Add the extracted data to the list
29 | extracted_data.append(extracted_item)
30 | train_data.append(list_item)
31 | count += 1
32 |
33 | # Write the extracted data into a new JSON file
34 | print(f"sum_count:{count}")
35 | with open(output_json, "w", encoding="utf-8") as output:
36 | json.dump(extracted_data, output, indent=2, ensure_ascii=False)
37 | with open(train_json, "w", encoding="utf-8") as output:
38 | json.dump(train_data, output, indent=2, ensure_ascii=False)
39 |
40 |
41 | #Specify the input JSONL file and output JSON file name
42 | input_jsonl = "./data/3.dev_aftgpt_prepared_aftgpt.jsonl"
43 | output_json = "./data/4.dev_extracted.json"
44 | train_json = "./data/4.dev_train.json"
45 |
46 | # Call the extraction and saving functions
47 | extract_and_save(input_jsonl, output_json, train_json)
48 |
--------------------------------------------------------------------------------
/src/process/openai_rewrite/patient_en/1.2.prepare_data.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import argparse
4 |
5 | ans_prompt = """You are Medbase, equipped with in-depth knowledge in medicine. Your task is to directly answer the user's <question> in English. In formulating your response, you must thoughtfully reference the <reference>, ensuring that your reply does not disclose your reliance on <reference>. Aim to provide a comprehensive and informative response, incorporating relevant insights from <reference> to best assist the user. Please be cautious to avoid including any content that might raise ethical concerns.
6 |
7 | <question>: {question}
8 |
9 | <reference>: {reference}
10 |
11 | <answer>: """
12 |
13 |
14 | def generate_query(data):
15 | chatgpt_query = ans_prompt.format_map(
16 | {"question": data["model_answer"], "reference": data["reference"]}
17 | )
18 | return chatgpt_query
19 |
20 |
21 | def Prepare_data(args):
22 | data = []
23 | # Read the input JSONL file
24 | with jsonlines.open(args.input_path, "r") as reader:
25 | data = list(reader)
26 |
27 | print(f"len:{len(data)}")
28 | # Convert as required
29 | jsonl_data = []
30 |
31 | for id, item in enumerate(data):
32 | jsonl_data.append(
33 | {
34 | "id": id,
35 | "query": generate_query(item),
36 | "model_answer": "",
37 | "model_question": item["model_answer"],
38 | "reference": item["reference"],
39 | }
40 | )
41 |
42 | # Save the converted data as a JSONL file
43 | with open(args.output_path, "w", encoding="utf-8") as file:
44 | for entry in jsonl_data:
45 | file.write(json.dumps(entry, ensure_ascii=False) + "\n")
46 |
47 | print(f"Prepare finished, output to '{args.output_path}'")
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser(
52 | description="Prepare data for OpenAIGPT generation"
53 | )
54 | parser.add_argument(
55 | "--input_path", type=str, required=True, help="Path to the input JSON file."
56 | )
57 | parser.add_argument(
58 | "--output_path", type=str, required=True, help="Path to the output JSONL file."
59 | )
60 | args = parser.parse_args()
61 | Prepare_data(args)
62 |
--------------------------------------------------------------------------------
/src/process/openai_rewrite/patient_en/1.prepare_data.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import argparse
4 |
5 | query_prompt = """
6 | {text}
7 | Please create some dialogues between patients and doctors in English based on the above text. The format is:
8 | <patient>Patient’s question</patient>
9 | <doctor>Doctor’s answer</doctor>
10 | Both patient questions and doctor responses are as complex and detailed as possible."""
11 |
12 |
13 | def generate_query(data):
14 | chatgpt_query = query_prompt.format_map({"text": data[0]})
15 | return chatgpt_query
16 |
17 |
18 | def Prepare_data(args):
19 | data = []
20 | # Read the input JSONL file
21 | with jsonlines.open(args.input_path, "r") as reader:
22 | data = list(reader)
23 |
24 | print(f"len:{len(data)}")
25 | # Convert as required
26 | jsonl_data = []
27 |
28 | for id, item in enumerate(data):
29 | query = generate_query(item)
30 | if len(query) > 4090:
31 | continue
32 | jsonl_data.append(
33 | {
34 | "id": id,
35 | "query": query,
36 | "model_answer": "",
37 | "reference": item[0],
38 | }
39 | )
40 |
41 | # Save the converted data as a JSONL file
42 | with open(args.output_path, "w", encoding="utf-8") as file:
43 | for entry in jsonl_data:
44 | file.write(json.dumps(entry, ensure_ascii=False) + "\n")
45 |
46 | print(f"Prepare finished, output to '{args.output_path}'")
47 |
48 |
49 | if __name__ == "__main__":
50 | parser = argparse.ArgumentParser(
51 | description="Prepare data for OpenAIGPT generation"
52 | )
53 | parser.add_argument(
54 | "--input_path", type=str, required=True, help="Path to the input JSON file."
55 | )
56 | parser.add_argument(
57 | "--output_path", type=str, required=True, help="Path to the output JSONL file."
58 | )
59 | args = parser.parse_args()
60 | Prepare_data(args)
61 |
--------------------------------------------------------------------------------
/src/process/openai_rewrite/patient_en/1.run_prepare_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the paths for input and output files
4 | input_path="./data/1.rlhf.jsonl"
5 | output_path="./data/2.rlhf_prepared.jsonl"
6 |
7 | # run Python scripts
8 | python 1.prepare_data.py --input_path "$input_path" --output_path "$output_path"
9 |
10 | python 1.prepare_data.py --input_path ./data/1.dev_en.jsonl --output_path ./data/2.dev_prepared.jsonl
11 | python 1.2.prepare_data.py --input_path ./data/3.dev_aftgpt.jsonl --output_path ./data/2.dev_aftgpt_prepared.jsonl
--------------------------------------------------------------------------------
/src/process/openai_rewrite/patient_en/2.run_gpt_datagen_multithread.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Usage: python ../OpenAIGPT_datagen_multithread.py --keys_path <keys_path> --input_path <input_path> --output_path <output_path> --max_workers <max_workers>
4 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_prepared.jsonl --output_path ./data/3.dev_aftgpt.jsonl --max_workers 300
5 | python ../OpenAIGPT_datagen_multithread.py --keys_path ../gpt4key.txt --input_path ./data/2.dev_aftgpt_prepared.jsonl --output_path ./data/3.dev_aftgpt_prepared_aftgpt.jsonl --max_workers 300
--------------------------------------------------------------------------------
/src/process/openai_rewrite/patient_en/3.extract.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 |
4 |
5 | def extract_dialogues_from_model_answer(model_answer):
6 | final_list = []
7 | try:
8 | # Use regular expressions to extract the patient/doctor turns (the <patient>/<doctor> tags follow the generation prompt)
9 | patient_dialogues = re.findall(
10 | r"<patient>(.*?)</patient>", model_answer, re.DOTALL
11 | )
12 | doctor_dialogues = re.findall(
13 | r"<doctor>(.*?)</doctor>", model_answer, re.DOTALL
14 | )
15 |
16 | # Combine conversations into a list
17 | dialogues = list(zip(patient_dialogues, doctor_dialogues))
18 | final_list = []
19 | for item in dialogues:
20 | for i in item:
21 | final_list.append(i)
22 | except Exception as e:
23 | print(f"Error extracting dialogues: {str(e)}")
24 |
25 | return final_list
26 |
27 |
28 | def process_jsonl_file(input_file, output_file):
29 | result_data = []
30 |
31 | try:
32 | with open(input_file, "r", encoding="utf-8") as file:
33 | #Read JSONL file line by line
34 | for line in file:
35 | json_data = json.loads(line)
36 |
37 | # Check whether the JSON contains the "model_answer" field
38 | if "model_answer" in json_data:
39 | model_answer = json_data["model_answer"]
40 |
41 | # Extract conversations and add to result list
42 | dialogues = extract_dialogues_from_model_answer(model_answer)
43 | if not len(dialogues):
44 | continue
45 | result_data.append(dialogues)
46 |
47 | except Exception as e:
48 | print(f"Error processing JSONL file: {str(e)}")
49 |
50 | # Save the results as a JSON file
51 | with open(output_file, "w", encoding="utf-8") as out_f:
52 | json.dump(result_data, out_f, ensure_ascii=False, indent=2)
53 |
54 |
55 | # Usage example
56 | input_file_path = "./data/3.patients_en_aftgpt.jsonl"
57 | output_file_path = "./data/4.patients_en_aftgpt.json"
58 |
59 | process_jsonl_file(input_file_path, output_file_path)
60 |
--------------------------------------------------------------------------------
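An illustrative check of the tag-based extraction above; the <patient>/<doctor> tag names are assumed from the generation prompt, and the sample answer is invented.

import re

model_answer = (
    "<patient>I have had a dry cough for two weeks. Should I be worried?</patient>\n"
    "<doctor>A persistent cough lasting more than two weeks deserves a check-up.</doctor>"
)
patients = re.findall(r"<patient>(.*?)</patient>", model_answer, re.DOTALL)
doctors = re.findall(r"<doctor>(.*?)</doctor>", model_answer, re.DOTALL)
# zip the turns and flatten, as in extract_dialogues_from_model_answer
pairs = [turn for pair in zip(patients, doctors) for turn in pair]
print(pairs)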
/src/process/prepare/data_process_test_gemma.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 | question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 | if debug:
155 | print(question)
156 | debug=False
157 |
158 | tmp_dict['source_question'] = item['question']
159 | tmp_dict['source_option'] = item['options']
160 | tmp_dict['question'] = question
161 | tmp_dict['answer'] = item['answer'][1]
162 | tmp_dict['source'] = item['source']
163 | data_final.append(tmp_dict)
164 |
165 | with open(args.save_path, 'w', encoding='utf-8') as file:
166 | json.dump(data_final, file, ensure_ascii=False, indent=2)
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
171 |
172 | # Model Args
173 | parser.add_argument('--save_path', default='', type=str)
174 | parser.add_argument('--data_path', default='', type=str)
175 | parser.add_argument('--few_shot', default='', type=int)
176 | args = parser.parse_args()
177 |
178 | preprocess(args)
--------------------------------------------------------------------------------
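A compact sketch of the few-shot assembly loop used by the preprocessing script above (and its sibling data_process_test_* variants): sample few_shot+1 candidates, drop the current item (or one surplus example), and append shot texts only while the prompt stays under the character budget. Items, templates and the budget below are toy values.

import random

items = [{"question": f"Q{i}?", "options": f"A. x{i}\nB. y{i}", "answer": "(A) x"} for i in range(5)]
shot_tpl = "Q: {question}\n{options}\nAnswer: {answer}\n"
ask_tpl = "Q: {question}\n{options}\nAnswer: "
few_shot, budget = 2, 400

item = items[0]
pool = random.sample(items, few_shot + 1)
if item in pool:
    pool.remove(item)   # never use the question itself as a shot
else:
    pool.pop()          # keep exactly few_shot examples
real_question = ask_tpl.format(**item)
prompt = ""
for shot in pool:
    shot_text = shot_tpl.format(**shot)
    if len(prompt) + len(real_question) + len(shot_text) < budget:
        prompt += shot_text
prompt += real_question
print(prompt)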
/src/process/prepare/data_process_test_huatuo2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 | question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | if len(question) + real_question_len + len(few_shot_prompt.format(**sample)) < 4096:
149 | question += few_shot_prompt.format(**sample)
150 | question += real_question
151 | if len(question)>4096:
152 | continue
153 | if debug:
154 | print(question)
155 | debug=False
156 |
157 | tmp_dict['source_question'] = item['question']
158 | tmp_dict['source_option'] = item['options']
159 | tmp_dict['question'] = question
160 | tmp_dict['answer'] = item['answer'][1]
161 | tmp_dict['source'] = item['source']
162 | data_final.append(tmp_dict)
163 |
164 | with open(args.save_path, 'w', encoding='utf-8') as file:
165 | json.dump(data_final, file, ensure_ascii=False, indent=2)
166 |
167 |
168 | if __name__ == '__main__':
169 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
170 |
171 | # Model Args
172 | parser.add_argument('--save_path', default='', type=str)
173 | parser.add_argument('--data_path', default='', type=str)
174 | parser.add_argument('--few_shot', default='', type=int)
175 | args = parser.parse_args()
176 |
177 | preprocess(args)
--------------------------------------------------------------------------------
/src/process/prepare/data_process_test_llama.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 | question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 | if debug:
155 | print(question)
156 | debug=False
157 |
158 | tmp_dict['source_question'] = item['question']
159 | tmp_dict['source_option'] = item['options']
160 | tmp_dict['question'] = question
161 | tmp_dict['answer'] = item['answer'][1]
162 | tmp_dict['source'] = item['source']
163 | data_final.append(tmp_dict)
164 |
165 | with open(args.save_path, 'w', encoding='utf-8') as file:
166 | json.dump(data_final, file, ensure_ascii=False, indent=2)
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
171 |
172 | # Model Args
173 | parser.add_argument('--save_path', default='', type=str)
174 | parser.add_argument('--data_path', default='', type=str)
175 | parser.add_argument('--few_shot', default='', type=int)
176 | args = parser.parse_args()
177 |
178 | preprocess(args)
--------------------------------------------------------------------------------
/src/process/prepare/data_process_test_meditron.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 | question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 |
155 | if debug:
156 | print(question)
157 | debug=False
158 |
159 | tmp_dict['source_question'] = item['question']
160 | tmp_dict['source_option'] = item['options']
161 | tmp_dict['question'] = question
162 | tmp_dict['answer'] = item['answer'][1]
163 | tmp_dict['source'] = item['source']
164 | data_final.append(tmp_dict)
165 |
166 | with open(args.save_path, 'w', encoding='utf-8') as file:
167 | json.dump(data_final, file, ensure_ascii=False, indent=2)
168 |
169 |
170 | if __name__ == '__main__':
171 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
172 |
173 | # Model Args
174 | parser.add_argument('--save_path', default='', type=str)
175 | parser.add_argument('--data_path', default='', type=str)
176 | parser.add_argument('--few_shot', default='', type=int)
177 | args = parser.parse_args()
178 |
179 | preprocess(args)
--------------------------------------------------------------------------------
/src/process/prepare/data_process_test_mistral.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 |             question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 | if debug:
155 | print(question)
156 | debug=False
157 |
158 | tmp_dict['source_question'] = item['question']
159 | tmp_dict['source_option'] = item['options']
160 | tmp_dict['question'] = question
161 | tmp_dict['answer'] = item['answer'][1]
162 | tmp_dict['source'] = item['source']
163 | data_final.append(tmp_dict)
164 |
165 | with open(args.save_path, 'w', encoding='utf-8') as file:
166 | json.dump(data_final, file, ensure_ascii=False, indent=2)
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
171 |
172 | # Model Args
173 | parser.add_argument('--save_path', default='', type=str)
174 | parser.add_argument('--data_path', default='', type=str)
175 |     parser.add_argument('--few_shot', default=0, type=int)
176 | args = parser.parse_args()
177 |
178 | preprocess(args)
--------------------------------------------------------------------------------
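Note: the data_process_test_*.py scripts above and below share the same structure and differ only in the chat template (for example, whether an <|endoftext|> terminator follows each few-shot answer). Each one groups the evaluation items by `source`, picks a language-specific prompt, packs up to `--few_shot` in-context examples under a 4096-character budget, and writes the result to `--save_path`. A minimal invocation might look like the following; the output file name and shot count are placeholders, not values taken from the repository:

python ./src/process/prepare/data_process_test_mistral.py \
    --data_path ./metadata/test.json \
    --save_path ./metadata/test_mistral.json \
    --few_shot 3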
/src/process/prepare/data_process_test_qwen.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.<|endoftext|>
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.<|endoftext|>
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.<|endoftext|>
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.<|endoftext|>
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.<|endoftext|>
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.<|endoftext|>
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.<|endoftext|>
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 |             question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 | if debug:
155 | print(question)
156 | debug=False
157 |
158 | tmp_dict['source_question'] = item['question']
159 | tmp_dict['source_option'] = item['options']
160 | tmp_dict['question'] = question
161 | tmp_dict['answer'] = item['answer'][1]
162 | tmp_dict['source'] = item['source']
163 | data_final.append(tmp_dict)
164 |
165 | with open(args.save_path, 'w', encoding='utf-8') as file:
166 | json.dump(data_final, file, ensure_ascii=False, indent=2)
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
171 |
172 | # Model Args
173 | parser.add_argument('--save_path', default='', type=str)
174 | parser.add_argument('--data_path', default='', type=str)
175 |     parser.add_argument('--few_shot', default=0, type=int)
176 | args = parser.parse_args()
177 |
178 | preprocess(args)
--------------------------------------------------------------------------------
/src/process/prepare/data_process_test_yi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.<|endoftext|>
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.<|endoftext|>
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.<|endoftext|>
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.<|endoftext|>
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.<|endoftext|>
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.<|endoftext|>
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.<|endoftext|>
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 |             question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 | if debug:
155 | print(question)
156 | debug=False
157 |
158 | tmp_dict['source_question'] = item['question']
159 | tmp_dict['source_option'] = item['options']
160 | tmp_dict['question'] = question
161 | tmp_dict['answer'] = item['answer'][1]
162 | tmp_dict['source'] = item['source']
163 | data_final.append(tmp_dict)
164 |
165 | with open(args.save_path, 'w', encoding='utf-8') as file:
166 | json.dump(data_final, file, ensure_ascii=False, indent=2)
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
171 |
172 | # Model Args
173 | parser.add_argument('--save_path', default='', type=str)
174 | parser.add_argument('--data_path', default='', type=str)
175 |     parser.add_argument('--few_shot', default=0, type=int)
176 | args = parser.parse_args()
177 |
178 | preprocess(args)
--------------------------------------------------------------------------------
/src/process/prepare/data_process_test_zephyr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
8 |
9 | question_prompt_en_choice_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
10 | Question: {question}
11 | Options:
12 | {options}
13 | Assistant:The correct answer is {answer}.
14 | """
15 | question_prompt_en_choice = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to D.
16 | Question: {question}
17 | Options:
18 | {options}
19 | Assistant:"""
20 |
21 | question_prompt_en_pubmed_shot = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
22 | Context: {context}
23 | Question: {question}
24 | Options:
25 | {options}
26 | Assistant:The correct answer is {answer}.
27 | """
28 |
29 | question_prompt_en_pubmed = """User:You are a medical doctor answering real-world medical exam questions. Select one correct answer from A to C. Choose ‘yes’ or ‘no’ if the evidence in the context supports a definitive answer. Choose ‘maybe’ if the evidence in the context does not support a definitive answer.
30 | Context: {context}
31 | Question: {question}
32 | Options:
33 | {options}
34 | Assistant:"""
35 |
36 |
37 | question_prompt_zh_choice_shot = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
38 | 问题: {question}
39 | 选项:
40 | {options}
41 | Assistant:正确答案是{answer}.
42 | """
43 | question_prompt_zh_choice = """User:您是一名医生,正在回答现实世界的医学考试问题。请从A到D中选择一个正确答案。
44 | 问题: {question}
45 | 选项:
46 | {options}
47 | Assistant:"""
48 |
49 |
50 | question_prompt_es_choice_shot = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
51 | pregunta: {question}
52 | Opciones:
53 | {options}
54 | Assistant:La respuesta correcta es {answer}.
55 | """
56 | question_prompt_es_choice = """User:Usted es un médico que responde preguntas de exámenes médicos del mundo real. Elija una respuesta correcta de la A a la D.
57 | pregunta: {question}
58 | Opciones:
59 | {options}
60 | Assistant:"""
61 |
62 | question_prompt_fr_choice_shot = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
63 | question: {question}
64 | Possibilités:
65 | {options}
66 | Assistant:La bonne réponse est {answer}.
67 | """
68 | question_prompt_fr_choice = """User:Vous êtes un médecin et répondez à des questions d'examen médical du monde réel. Veuillez choisir une bonne réponse de A à E.
69 | question: {question}
70 | Possibilités:
71 | {options}
72 | Assistant:"""
73 |
74 | question_prompt_hi_choice_shot = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
75 | सवाल: {question}
76 | विकल्प:
77 | {options}
78 | Assistant:सही उत्तर है{answer}.
79 | """
80 | question_prompt_hi_choice = """User:आप एक डॉक्टर हैं जो वास्तविक दुनिया की मेडिकल परीक्षा के सवालों का जवाब दे रहे हैं। कृपया A से D तक सही उत्तर चुनें।
81 | सवाल: {question}
82 | विकल्प:
83 | {options}
84 | Assistant:"""
85 |
86 | question_prompt_ar_choice_shot = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
87 | سؤال: {question}
88 | خيارات:
89 | {options}
90 | Assistant:{answer}الإجابة الصحيحة هي.
91 | """
92 | question_prompt_ar_choice = """User:أنت طبيب يجيب على أسئلة الفحص الطبي في العالم الحقيقي. الرجاء اختيار الإجابة الصحيحة من أ إلى د.
93 | سؤال: {question}
94 | خيارات:
95 | {options}
96 | Assistant:"""
97 |
98 |
99 |
100 | def preprocess(args):
101 | data_final = []
102 | with open(args.data_path, 'r') as file:
103 | data = json.load(file)
104 | grouped_items = {}
105 | for item in data:
106 | source = item.get("source")
107 | if source not in grouped_items:
108 | grouped_items[source] = []
109 | grouped_items[source].append(item)
110 |
111 | for source, items in grouped_items.items():
112 | debug = True
113 | print(f'*********************{source}****************************')
114 | if source in ['cmb-single', 'cmexam', 'cmmlu-medical', 'medqa-mcmle']:
115 | few_shot_prompt = question_prompt_zh_choice_shot
116 |             question_prompt = question_prompt_zh_choice
117 | elif source in ['medmcqa', 'medqa-usmle', 'mmlu-medical']:
118 | few_shot_prompt = question_prompt_en_choice_shot
119 | question_prompt = question_prompt_en_choice
120 | elif source in ['headqa']:
121 | few_shot_prompt = question_prompt_es_choice_shot
122 | question_prompt = question_prompt_es_choice
123 | elif source in ['frenchmedmcqa']:
124 | few_shot_prompt = question_prompt_fr_choice_shot
125 | question_prompt = question_prompt_fr_choice
126 | elif source in ['mmlu-medical-ar']:
127 | few_shot_prompt = question_prompt_ar_choice_shot
128 | question_prompt = question_prompt_ar_choice
129 | elif source in ['mmlu-medical-hi']:
130 | few_shot_prompt = question_prompt_hi_choice_shot
131 | question_prompt = question_prompt_hi_choice
132 | else:
133 | few_shot_prompt = question_prompt_en_pubmed_shot
134 | question_prompt = question_prompt_en_pubmed
135 |
136 | for item in items:
137 | random_samples = random.sample(items, args.few_shot+1)
138 | question = ''
139 | tmp_dict = {}
140 | # in case item in random_samples
141 | if item in random_samples:
142 | random_samples.remove(item)
143 | else:
144 | random_samples = random_samples[:-1]
145 | real_question = question_prompt.format(**item)
146 | real_question_len = len(real_question)
147 | for sample in random_samples:
148 | sample = few_shot_prompt.format(**sample)
149 | if len(question) + real_question_len + len(sample) < 4096:
150 | question += sample
151 | question += real_question
152 | if len(question)>4096:
153 | continue
154 | if debug:
155 | print(question)
156 | debug=False
157 |
158 | tmp_dict['source_question'] = item['question']
159 | tmp_dict['source_option'] = item['options']
160 | tmp_dict['question'] = question
161 | tmp_dict['answer'] = item['answer'][1]
162 | tmp_dict['source'] = item['source']
163 | data_final.append(tmp_dict)
164 |
165 | with open(args.save_path, 'w', encoding='utf-8') as file:
166 | json.dump(data_final, file, ensure_ascii=False, indent=2)
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser(description='Args of Data Preprocess')
171 |
172 | # Model Args
173 | parser.add_argument('--save_path', default='', type=str)
174 | parser.add_argument('--data_path', default='', type=str)
175 |     parser.add_argument('--few_shot', default=0, type=int)
176 | args = parser.parse_args()
177 |
178 | preprocess(args)
--------------------------------------------------------------------------------
/src/proxy-tuning/eval/apollodata/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import pandas as pd
5 | import numpy as np
6 | import json
7 | from eval.utils import (
8 | ensure_dir,
9 | generate_completions,
10 | load_hf_lm_and_tokenizer,
11 | load_mexperts_model_and_tokenizer,
12 | )
13 | from transformers import AutoConfig
14 |
15 | def main(args):
16 | ensure_dir(args.save_dir)
17 |
18 | if args.model_name_or_path:
19 | print("Loading model and tokenizer...")
20 | model, tokenizer = load_hf_lm_and_tokenizer(
21 | model_name_or_path=args.model_name_or_path,
22 | tokenizer_name_or_path=args.tokenizer_name_or_path,
23 | load_in_8bit=args.load_in_8bit,
24 | device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
25 | use_fast_tokenizer=not args.use_slow_tokenizer,
26 | )
27 | elif args.base_model_name_or_path:
28 | model, tokenizer = load_mexperts_model_and_tokenizer(
29 | args.base_model_name_or_path,
30 | args.expert_model_name_or_path,
31 | args.antiexpert_model_name_or_path,
32 | model_type=args.model_type,
33 | load_in_8bit=args.load_in_8bit,
34 | use_fast_tokenizer=not args.use_slow_tokenizer,
35 | )
36 |
37 |     # load the evaluation set (test.json under --data_dir)
38 | test_df = pd.read_json(os.path.join(args.data_dir, "test.json"))
39 |
40 | # Create prompts
41 | prompts = []
42 | for i, row in test_df.iterrows():
43 | prompts.append({'question': row["question"], 'source': row["source"], 'answer': row["answer"]})
44 |
45 | new_line_token = tokenizer.encode("\n\n", add_special_tokens=False)[-1]
46 | outputs = generate_completions(
47 | model,
48 | tokenizer,
49 | prompts,
50 | batch_size=args.eval_batch_size,
51 | do_sample=False,
52 | max_new_tokens=20,
53 | stop_id_sequences=[[new_line_token]],
54 | )
55 |
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument(
60 | "--data_dir",
61 | type=str,
62 | default=None,
63 | )
64 | parser.add_argument(
65 | "--save_dir",
66 | type=str,
67 | default=None,
68 | )
69 | parser.add_argument(
70 | "--model_name_or_path",
71 | type=str,
72 | default=None,
73 | help="if specified, we will load the model to generate the predictions."
74 | )
75 | parser.add_argument(
76 | "--tokenizer_name_or_path",
77 | type=str,
78 | default=None,
79 | help="if specified, we will load the tokenizer from here."
80 | )
81 | parser.add_argument(
82 | "--use_slow_tokenizer",
83 | action="store_true",
84 | help="If given, we will use the slow tokenizer."
85 | )
86 | parser.add_argument(
87 | "--max_examples",
88 | type=int,
89 |         help="if specified, evaluate at most this many examples."
90 | )
91 | parser.add_argument(
92 | "--eval_batch_size",
93 | type=int,
94 | default=1,
95 | help="batch size for evaluation."
96 | )
97 | parser.add_argument(
98 | "--load_in_8bit",
99 | action="store_true",
100 |         help="load model in 8-bit mode to reduce memory usage."
101 | )
102 | parser.add_argument(
103 | "--base_model_name_or_path",
104 | type=str,
105 | default=None,
106 | )
107 | parser.add_argument(
108 | "--expert_model_name_or_path",
109 | type=str,
110 | default=None,
111 | )
112 | parser.add_argument(
113 | "--antiexpert_model_name_or_path",
114 | type=str,
115 | default=None,
116 | )
117 | parser.add_argument(
118 | "--output_path",
119 | type=str,
120 | default='../../outputs-qwen7b.jsonl',
121 | )
122 |
123 | args = parser.parse_args()
124 | model_config = AutoConfig.from_pretrained(args.base_model_name_or_path,trust_remote_code=True)
125 | args.model_type = model_config.model_type
126 | main(args)
127 |
--------------------------------------------------------------------------------
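Note: as written, run_eval.py builds the prompts and calls generate_completions, but it never writes the returned outputs anywhere, even though an --output_path argument is defined. A minimal sketch of how the results could be persisted, assuming generate_completions returns one decoded string per prompt (the record field names are illustrative, not part of the original script):

    # hypothetical tail of main(): pair each prompt with its completion and
    # dump the records as JSON Lines to args.output_path
    with open(args.output_path, 'w', encoding='utf-8') as f:
        for prompt, output in zip(prompts, outputs):
            record = {
                'question': prompt['question'],   # few-shot prompt fed to the model
                'answer': prompt['answer'],       # gold answer from test.json
                'source': prompt['source'],
                'prediction': output,             # raw model completion
            }
            f.write(json.dumps(record, ensure_ascii=False) + '\n')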
/src/proxy-tuning/scripts/eval/proxy_tuning.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
2 |
3 | # Evaluate on Apollo data: Qwen-7B base guided by an Apollo-1.8B expert and a Qwen-1.5-1.8B anti-expert
4 | size=14
5 | echo "Results dir: results/apollodata"
6 | CUDA_LAUNCH_BLOCKING=1 python -m eval.apollodata.run_eval \
7 | --data_dir data/eval/apollodata/ \
8 | --save_dir results/apollodata \
9 | --base_model_name_or_path /your_model_path/Qwen-7B/ \
10 | --expert_model_name_or_path /your_model_path/Apollo-1.8B/ \
11 | --antiexpert_model_name_or_path /your_model_path/Qwen-1.5-1.8B/ \
12 | --output_path ../qwen7b.jsonl \
13 | --eval_batch_size 1
14 |
--------------------------------------------------------------------------------
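Note: because the script invokes `python -m eval.apollodata.run_eval`, it presumably has to be launched from the src/proxy-tuning directory (or with that directory on PYTHONPATH) so that the eval package resolves, e.g.:

    cd src/proxy-tuning
    bash scripts/eval/proxy_tuning.sh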
/src/sft/training_config/zero.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_clipping: 1.0
4 | offload_optimizer_device: none
5 | offload_param_device: none
6 | zero3_init_flag: False
7 | zero3_save_16bit_model: true
8 | zero_stage: 1
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'yes'
11 | main_training_function: main
12 | mixed_precision: bf16
13 | num_machines: 1
14 | num_processes: 8
15 | rdzv_backend: static
16 | same_network: true
17 | use_cpu: false
--------------------------------------------------------------------------------
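Note: zero.yaml is an accelerate config that wraps DeepSpeed ZeRO stage 1 for a single node with 8 processes, while zero_multi.yaml below targets a multi-node ZeRO stage 3 run. A single-node launch might look like the following; the training flags are placeholders, and the authoritative command lives in 3.single_node_train_gemma.sh:

    accelerate launch --config_file ./src/sft/training_config/zero.yaml \
        ./src/sft/train_gemma_resume_val.py \
        --model_path /your_model_path/gemma-7b \
        --data_path ./metadata/train.json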
/src/sft/training_config/zero_multi.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_clipping: 1.0
4 | offload_optimizer_device: none
5 | offload_param_device: none
6 | zero3_init_flag: False
7 | zero3_save_16bit_model: true
8 | zero_stage: 3
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'yes'
11 | main_training_function: main
12 | mixed_precision: bf16
13 | num_machines: 8
14 | num_processes: 40
15 | rdzv_backend: static
16 | same_network: true
17 | use_cpu: False
--------------------------------------------------------------------------------
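Note: for the multi-node config, each machine would typically run the same accelerate launch command with its own rank and the address of the rank-0 node supplied on the command line; the per-rank commands in scripts/3.multinode_train_gema7B_rank*.sh are the authoritative source. A sketch with placeholder values:

    accelerate launch --config_file ./src/sft/training_config/zero_multi.yaml \
        --machine_rank 0 --main_process_ip 10.0.0.1 --main_process_port 29500 \
        ./src/sft/train_gemma_resume_val.py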
/utils/check.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 12,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "None\n",
13 | "0\n",
14 | "\n",
15 | "\n"
16 | ]
17 | }
18 | ],
19 | "source": [
20 | "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel\n",
21 | "import transformers\n",
22 | "\n",
23 | "model_path = '/your_data_path/PMC_LLaMA_7B'\n",
24 | "tokenizer = transformers.LlamaTokenizer.from_pretrained(model_path)\n",
25 | "\n",
26 | "# string = 'hello'\n",
27 | "# ids= tokenizer.encode(string)\n",
28 | "# print(ids)\n",
29 | "# for id in ids:\n",
30 | "# print(tokenizer.decode(id))\n",
31 | "print(tokenizer.pad_token_id)\n",
32 | "print(tokenizer.eos_token_id)\n",
33 | "# print(tokenizer.decode(tokenizer.pad_token_id))\n",
34 | "# print(tokenizer.decode(0))\n",
35 | "# print(tokenizer.decode(1))\n",
36 | "\n"
37 | ]
38 | }
39 | ],
40 | "metadata": {
41 | "kernelspec": {
42 | "display_name": "Python 3 (ipykernel)",
43 | "language": "python",
44 | "name": "python3"
45 | },
46 | "language_info": {
47 | "codemirror_mode": {
48 | "name": "ipython",
49 | "version": 3
50 | },
51 | "file_extension": ".py",
52 | "mimetype": "text/x-python",
53 | "name": "python",
54 | "nbconvert_exporter": "python",
55 | "pygments_lexer": "ipython3",
56 | "version": "3.10.13"
57 | }
58 | },
59 | "nbformat": 4,
60 | "nbformat_minor": 4
61 | }
62 |
--------------------------------------------------------------------------------
/utils/kill.sh:
--------------------------------------------------------------------------------
1 | pkill -f ./src/sft/train_yi_resume_val.py
2 | pkill -f eval_qwen.py
--------------------------------------------------------------------------------