├── .DS_Store ├── README.md ├── SFT ├── README.md ├── code-gemma-7B-it.yaml ├── gemma-2-9b-it.yaml ├── gemma-7b-it.yaml ├── mistral-v0.3.yaml └── sft_configs │ ├── sft_ds2.json │ └── sft_ds3.json ├── alignment_algorithms ├── README.md ├── dpo.py ├── dpo_trainer.py ├── kto_trainer.py ├── run_dpo.py ├── run_dpo.sh ├── run_kto.py ├── run_kto.sh └── training_configs │ ├── zero2_pf.yaml │ └── zero3_pf.yaml ├── assets └── main_result.png ├── inference ├── .DS_Store ├── README.md ├── data │ ├── gsm8k │ │ ├── test.jsonl │ │ └── train.jsonl │ └── math │ │ ├── test.jsonl │ │ └── train.jsonl ├── eval │ ├── evaluate.py │ └── grader.py ├── infer_data │ ├── annotate_data.py │ ├── get_dpo_dataset.py │ └── infer_eval.py ├── scripts │ ├── eval.sh │ ├── infer.sh │ ├── iter_infer_to_collect_data.sh │ └── register_server.sh └── utils │ ├── annotate_data.py │ ├── data_loader.py │ ├── filter_data.py │ ├── parser.py │ ├── python_executor.py │ └── utils.py └── useful_codes ├── annotate_data.py ├── interpolate_model.py ├── merge.py ├── push_model.py └── set_padding_token.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiXiongUST/Building-Math-Agents-with-Multi-Turn-Iterative-Preference-Learning/ae1e6422eac4ad04ef5ab1c2bbffb4b4497c70ab/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |
3 | Building Math Agents with Multi-Turn Iterative Preference Learning 4 |

5 | 6 |

7 | TL;DR: this is the repo for "Building Math Agents with Multi-Turn Iterative Preference Learning". 8 |

9 | 10 | We consider math problem solving with a Python interpreter: the model can write Python code, ask the external environment to execute it, and receive the execution result before the LLM makes its next decision. 11 | 12 |

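The interaction loop described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's implementation: `run_python`, `solve`, and `scripted_policy` are hypothetical names, the policy is a scripted stand-in for the LLM, and `exec` is used without any sandboxing.

```python
import contextlib
import io
import re

FENCE = "`" * 3  # Markdown code fence, built programmatically for readability
CODE_BLOCK = re.compile(FENCE + r"python\n(.*?)" + FENCE, re.DOTALL)


def run_python(code: str) -> str:
    """Execute code and return its captured stdout (illustration only, no sandbox)."""
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            exec(code, {})
    except Exception as exc:  # report the error back to the model as feedback
        return f"{type(exc).__name__}: {exc}"
    return buffer.getvalue().strip()


def solve(question, policy, max_turns=4):
    """Alternate between the policy (assistant) and the interpreter (user)."""
    messages = [{"role": "user", "content": question}]
    for _ in range(max_turns):
        reply = policy(messages)
        messages.append({"role": "assistant", "content": reply})
        match = CODE_BLOCK.search(reply)
        if match is None:  # no code to run: treat the reply as the final answer
            break
        output = run_python(match.group(1))
        feedback = f"{FENCE}output\n{output}\n{FENCE}"
        messages.append({"role": "user", "content": feedback})
    return messages


def scripted_policy(messages):
    """Hypothetical stand-in for the LLM: first writes code, then answers."""
    if len(messages) == 1:
        return f"Let's compute it.\n{FENCE}python\nprint(9 * 11 - 7 * 7)\n{FENCE}"
    return "The answer is \\boxed{50}."


if __name__ == "__main__":
    for message in solve("How much more did the adults pay?", scripted_policy):
        print(message["role"], "->", message["content"])
```

The execution result is fed back as a user turn, which matches the chat format used for SFT in this repo; a real system would replace `exec` with an isolated executor and a timeout.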
13 | 14 |
15 | Figure 1: Main evaluation results on the MATH and GSM8K datasets. 16 |

17 | 18 | ## Structure 19 | 20 | The main pipeline is divided into three steps: 21 | 22 | 23 | - [`SFT`](./SFT/) to train the SFT model. 24 | - [`Inference`](./inference/) to generate new data and evaluate the model. 25 | - [`Multi-turn Alignment Algorithms`](./alignment_algorithms/) to conduct the multi-turn DPO/KTO training. 26 | 27 | 28 | We recommend three separate environments for **sft**, **inference**, and **alignment_train**. Please refer to the corresponding part of this project for detailed installation instructions. 29 | 30 | ## Collection 31 | 32 | - [SFT Dataset: RLHF4MATH/SFT_510K](https://huggingface.co/datasets/RLHF4MATH/SFT_510K), which is a subset of nvidia/OpenMathInstruct-1 33 | - [Prompt](https://huggingface.co/datasets/RLHF4MATH/prompt_iter1): RLHF4MATH/prompt_iter1, RLHF4MATH/prompt_iter2, RLHF4MATH/prompt_iter3 34 | - [SFT Model: RLHF4MATH/Gemma-7B-it-SFT3epoch](https://huggingface.co/RLHF4MATH/Gemma-7B-it-SFT3epoch) 35 | - [Aligned Model: RLHF4MATH/Gemma-7B-it-M-DPO](https://huggingface.co/RLHF4MATH/Gemma-7B-it-M-DPO) 36 | 37 | ## Acknowledgement 38 | 39 | The authors would like to thank the great open-source communities, including the Huggingface TRL team, the Axolotl team, and the Tora project, for sharing their models and code. 
40 | 41 | ## Citation 42 | 43 | If you find the content of this repo useful, please consider citing it as follows: 44 | 45 | ```bibtex 46 | @article{xiong2024building, 47 | title={Building Math Agents with Multi-Turn Iterative Preference Learning}, 48 | author={Xiong, Wei and Shi, Chengshuai and Shen, Jiaming and Rosenberg, Aviv and Qin, Zhen and Calandriello, Daniele and Khalman, Misha and Joshi, Rishabh and Piot, Bilal and Saleh, Mohammad and others}, 49 | journal={arXiv preprint arXiv:2409.02392}, 50 | year={2024} 51 | } 52 | ``` 53 | -------------------------------------------------------------------------------- /SFT/README.md: -------------------------------------------------------------------------------- 1 | # SFT for Multi-round Chat 2 | 3 | We conduct supervised fine-tuning here for multi-turn chat. The data should be processed into the standard format (see [RLHF4MATH/SFT_510K](https://huggingface.co/datasets/RLHF4MATH/SFT_510K) for an example). 4 | 5 | We formulate the problem in a multi-turn chat format where the "user" prompts the assistant (LLM) to solve a math problem and provides the external feedback. An example is as follows. 6 | 7 | - **User**: Nine adults went to a play with seven children. Adult tickets are $11 each and children's tickets are $7 each. How many dollars more did the adults' tickets cost in total than the children's tickets in total? 8 | - **Assistant**: Let's solve this problem using Python code.\n```python\nadult_price = 11\nchild_price = 7\nadult_count = 9\nchild_count = 7\nadult_ticket_total = adult_price * adult_count\nchild_ticket_total = child_price * child_count\nadult_ticket_total - child_ticket_total\n``` 9 | - **User**: ```output\n50\n``` 10 | - **Assistant**: Thus the adult's ticket total was \\boxed{50} dollars more than the children's tickets total. 11 | 12 | We mask all user turns so that they do not contribute to the gradient. 
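As an illustration of this masking, here is a minimal sketch, assuming the common convention that label `-100` is ignored by the loss (the default `ignore_index` of PyTorch's cross-entropy). `build_labels` and `toy_tokenize` are hypothetical helpers, not part of this repo; in the actual pipeline the masking is presumably handled by the training framework, since the configs in this directory set `train_on_inputs: false`.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def toy_tokenize(text):
    """Stand-in tokenizer (assumption): one fake id per whitespace-separated word."""
    return [sum(map(ord, word)) % 1000 for word in text.split()]


def build_labels(turns, tokenize):
    """Concatenate all turns; keep labels only for assistant tokens."""
    input_ids, labels = [], []
    for role, text in turns:
        ids = tokenize(text)
        input_ids.extend(ids)
        if role == "assistant":
            labels.extend(ids)  # assistant tokens contribute to the loss
        else:
            labels.extend([IGNORE_INDEX] * len(ids))  # user/tool tokens are masked
    return input_ids, labels


if __name__ == "__main__":
    conversation = [
        ("user", "solve 2 + 2"),
        ("assistant", "print(2 + 2)"),
        ("user", "4"),
        ("assistant", "The answer is 4"),
    ]
    ids, labels = build_labels(conversation, toy_tokenize)
    for token_id, label in zip(ids, labels):
        print(token_id, label)
```

The model still attends to the masked tokens as context; they are only excluded from the training objective, which is why interpreter outputs can appear in the sequence without being learned as targets.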
13 | 14 | 15 | ## Installation instructions 16 | 17 | Before starting, please make sure your Linux machine has [nvidia-cuda-toolkit](https://developer.nvidia.com/cuda-toolkit) installed. 18 | 19 | ```shell 20 | sudo apt install nvidia-driver-530 # (530 for 12.1+) 21 | wget https://repo.anaconda.com/archive/Anaconda3-2023.07-0-Linux-x86_64.sh 22 | bash Anaconda3-2023.07-0-Linux-x86_64.sh # install conda 23 | conda install nvidia/label/cuda-12.2.0::cuda-nvcc 24 | ``` 25 | 26 | Now we set up the Python environment. 27 | 28 | ```shell 29 | conda create -n sft python=3.10.9 30 | conda activate sft 31 | 32 | # The tested CUDA versions are 12.1 and 12.2. You may need to update the torch version based on your CUDA version. 33 | pip3 install torch==2.1.2 torchvision torchaudio 34 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 35 | 36 | ## Get axolotl for general model 37 | git clone https://github.com/OpenAccess-AI-Collective/axolotl 38 | cd axolotl 39 | git checkout 55cc214c767741e83ee7b346e5e13e6c03b7b9fa 40 | pip install -e . 41 | 42 | ## Get FastChat 43 | git clone https://github.com/lm-sys/FastChat.git 44 | cd FastChat 45 | pip install -e . 46 | 47 | pip install deepspeed 48 | 49 | # You also need to install wandb to record the training and log in with your Hugging Face account to access Gemma. 50 | 51 | pip install wandb 52 | wandb login 53 | 54 | huggingface-cli login 55 | ``` 56 | ## Running the Code 57 | 58 | Run the code with Gemma. 59 | 60 | ```shell 61 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" torchrun --nproc_per_node 8 --master_port 20001 -m axolotl.cli.train gemma-7b-it.yaml 62 | ``` 63 | 64 | You can also modify the learning rate, batch size, output path, etc., either on the command line or by editing the arguments in gemma-7b-it.yaml. 65 | 66 | If you encounter an out-of-memory issue, try the following. 
Run the code with Gemma-7b-it with DeepSpeed stage 3 and gradient checkpointing (set in the config). 67 | 68 | ```shell 69 | CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node 4 --master_port 20001 -m axolotl.cli.train gemma-7b-it.yaml --deepspeed ./sft_configs/sft_ds3.json 70 | ``` 71 | 72 | **REMARK: with DeepSpeed stage 3, the final model saving does not work normally. We set `save_strategy` to `epoch` so that a normal model is stored at the end of each epoch, just before training finishes. If you modify the save strategy, you should set `save_steps` to the total number of training steps minus 1 so that the trainer saves a model for you just before finishing the training.** 73 | 74 | 75 | Finally, for models without an official padding token (like Mistral), you may need to set the padding token with ../useful_codes/set_padding_token.py first. 76 | 77 | -------------------------------------------------------------------------------- /SFT/code-gemma-7B-it.yaml: -------------------------------------------------------------------------------- 1 | base_model: google/codegemma-1.1-7b-it 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHF4MATH/SFT_510K 11 | conversation: gemma 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 50 17 | val_set_size: 0.0 18 | output_dir: ./models/gemma-7b-it_bz64_lr5e-6_pack 19 | wandb_project: huggingface 20 | #wandb_entity: domain-generalization 21 | wandb_watch: 22 | wandb_name: "codegemma-7b-it_bs64_lr5e-6_sft_pack4096" 23 | #_response_only 24 | wandb_log_model: 25 | 26 | train_on_inputs: false 27 | 28 | save_safetensors: true 29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type 30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared 31 | 32 | 33 | dataset_processes: 48 34 | #torch_compile: true 35 | sequence_len: 
4096 36 | sample_packing: true 37 | pad_to_sequence_len: true 38 | 39 | trust_remote_code: True 40 | adapter: 41 | lora_model_dir: 42 | 43 | 44 | 45 | gradient_checkpointing: true 46 | 47 | #warmup_ratio: 0.1 48 | gradient_accumulation_steps: 2 49 | micro_batch_size: 4 50 | num_epochs: 4 51 | optimizer: paged_adamw_32bit 52 | #adiamw_torch_fused 53 | lr_scheduler: cosine 54 | learning_rate: 5.e-6 55 | 56 | 57 | weight_decay: 0.0 58 | max_grad_norm: 1.0 59 | 60 | 61 | group_by_length: false 62 | bf16: auto 63 | fp16: false 64 | tf32: true 65 | 66 | early_stopping_patience: 67 | local_rank: 68 | logging_steps: 2 69 | xformers_attention: 70 | flash_attention: true 71 | 72 | 73 | eval_steps: 74 | eval_table_size: 75 | eval_table_max_new_tokens: 76 | #save_steps: 100 77 | save_strategy: "epoch" 78 | save_total_limit: 4 79 | debug: 80 | 81 | 82 | ddp: #true 83 | deepspeed: #deepspeed/zero1.json # multi-gpu only 84 | 85 | fsdp: 86 | fsdp_config: 87 | special_tokens: 88 | -------------------------------------------------------------------------------- /SFT/gemma-2-9b-it.yaml: -------------------------------------------------------------------------------- 1 | base_model: google/gemma-2-9b-it 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHF4MATH/SFT_510K 11 | conversation: gemma 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 100 17 | val_set_size: 0.0 18 | output_dir: ./models/gemma_9b_bz64_lr5e6_pack4096 19 | wandb_project: huggingface 20 | #wandb_entity: sft 21 | wandb_watch: 22 | wandb_name: "gemma_9b_bz64_lr5e6_pack4096" 23 | #_response_only 24 | wandb_log_model: 25 | 26 | train_on_inputs: false 27 | 28 | save_safetensors: true 29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type 30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared 31 | 32 | 33 | 
dataset_processes: 48 34 | #torch_compile: true 35 | sequence_len: 4096 36 | sample_packing: true 37 | pad_to_sequence_len: true 38 | 39 | trust_remote_code: True 40 | adapter: 41 | lora_model_dir: 42 | gradient_checkpointing: true 43 | 44 | #warmup_ratio: 0.1 45 | gradient_accumulation_steps: 8 46 | micro_batch_size: 1 47 | num_epochs: 4 48 | optimizer: paged_adamw_32bit 49 | lr_scheduler: cosine 50 | learning_rate: 5e-6 51 | 52 | weight_decay: 0.0 53 | max_grad_norm: 1.0 54 | 55 | 56 | group_by_length: false 57 | bf16: auto 58 | fp16: false 59 | tf32: true 60 | 61 | early_stopping_patience: 62 | local_rank: 63 | logging_steps: 2 64 | xformers_attention: 65 | flash_attention: true 66 | 67 | 68 | eval_steps: 69 | eval_table_size: 70 | eval_table_max_new_tokens: 71 | save_steps: 9999 72 | save_strategy: "steps" 73 | save_total_limit: 4 74 | debug: 75 | 76 | 77 | ddp: #true 78 | deepspeed: #deepspeed/zero1.json # multi-gpu only 79 | 80 | fsdp: 81 | fsdp_config: 82 | special_tokens: 83 | -------------------------------------------------------------------------------- /SFT/gemma-7b-it.yaml: -------------------------------------------------------------------------------- 1 | base_model: google/gemma-1.1-7b-it 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHF4MATH/SFT_510K 11 | conversation: gemma 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 50 17 | val_set_size: 0.0 18 | output_dir: ./pm_models/code_gemma-7b-it_bz64_lr5e-6_non_pack 19 | wandb_project: huggingface 20 | #wandb_entity: domain-generalization 21 | wandb_watch: 22 | wandb_name: "codegemma-7b-it_bs64_lr5e-6_sft_pack4096" 23 | #_response_only 24 | wandb_log_model: 25 | 26 | train_on_inputs: false 27 | 28 | save_safetensors: true 29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type 30 | dataset_prepared_path: 
~/data/preference-models/last_run_prepared 31 | 32 | 33 | dataset_processes: 48 34 | #torch_compile: true 35 | sequence_len: 4096 36 | sample_packing: true 37 | pad_to_sequence_len: true 38 | 39 | trust_remote_code: True 40 | adapter: 41 | lora_model_dir: 42 | 43 | 44 | 45 | gradient_checkpointing: true 46 | 47 | #warmup_ratio: 0.1 48 | gradient_accumulation_steps: 2 49 | micro_batch_size: 4 50 | num_epochs: 4 51 | optimizer: paged_adamw_32bit 52 | #adiamw_torch_fused 53 | lr_scheduler: cosine 54 | learning_rate: 5.e-6 55 | 56 | 57 | weight_decay: 0.0 58 | max_grad_norm: 1.0 59 | 60 | 61 | group_by_length: false 62 | bf16: auto 63 | fp16: false 64 | tf32: true 65 | 66 | early_stopping_patience: 67 | local_rank: 68 | logging_steps: 2 69 | xformers_attention: 70 | flash_attention: true 71 | 72 | 73 | eval_steps: 74 | eval_table_size: 75 | eval_table_max_new_tokens: 76 | #save_steps: 100 77 | save_strategy: "epoch" 78 | save_total_limit: 4 79 | debug: 80 | 81 | 82 | ddp: #true 83 | deepspeed: #deepspeed/zero1.json # multi-gpu only 84 | 85 | fsdp: 86 | fsdp_config: 87 | special_tokens: 88 | -------------------------------------------------------------------------------- /SFT/mistral-v0.3.yaml: -------------------------------------------------------------------------------- 1 | base_model: ./models/mistral_with_pad 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHF4MATH/SFT_510K 11 | conversation: mistral 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 50 17 | val_set_size: 0.0 18 | output_dir: ./models/mistral_bz64_1e5_pack4096_1200k 19 | wandb_project: huggingface 20 | #wandb_entity: sft 21 | wandb_watch: 22 | wandb_name: "mistral_bz64_1e5_pack4096_510k" 23 | #_response_only 24 | wandb_log_model: 25 | 26 | train_on_inputs: false 27 | 28 | save_safetensors: true 29 | 
#noisy_embedding_alpha: 10.0 # default for sharegpt type 30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared 31 | 32 | 33 | dataset_processes: 48 34 | #torch_compile: true 35 | sequence_len: 4096 36 | sample_packing: true 37 | pad_to_sequence_len: true 38 | 39 | trust_remote_code: True 40 | adapter: 41 | lora_model_dir: 42 | 43 | 44 | 45 | 46 | gradient_checkpointing: true 47 | 48 | #warmup_ratio: 0.1 49 | gradient_accumulation_steps: 2 50 | micro_batch_size: 4 51 | num_epochs: 3 52 | optimizer: paged_adamw_32bit 53 | lr_scheduler: cosine 54 | learning_rate: 1.e-5 55 | 56 | weight_decay: 0.0 57 | max_grad_norm: 1.0 58 | 59 | 60 | group_by_length: false 61 | bf16: auto 62 | fp16: false 63 | tf32: true 64 | 65 | early_stopping_patience: 66 | local_rank: 67 | logging_steps: 2 68 | xformers_attention: 69 | flash_attention: true 70 | eval_steps: 71 | eval_table_size: 72 | eval_table_max_new_tokens: 73 | #save_steps: 99999 74 | save_strategy: "epoch" 75 | save_total_limit: 3 76 | debug: 77 | 78 | 79 | ddp: #true 80 | deepspeed: 81 | 82 | fsdp: 83 | fsdp_config: 84 | special_tokens: 85 | 86 | 87 | -------------------------------------------------------------------------------- /SFT/sft_configs/sft_ds2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu" 6 | }, 7 | "contiguous_gradients": true, 8 | "overlap_comm": true 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "fp16": { 14 | "enabled": "auto", 15 | "auto_cast": false, 16 | "loss_scale": 0, 17 | "initial_scale_power": 32, 18 | "loss_scale_window": 1000, 19 | "hysteresis": 2, 20 | "min_loss_scale": 1 21 | }, 22 | "gradient_accumulation_steps": "auto", 23 | "gradient_clipping": "auto", 24 | "train_batch_size": "auto", 25 | "train_micro_batch_size_per_gpu": "auto", 26 | "wall_clock_breakdown": false 27 | } 28 | 
-------------------------------------------------------------------------------- /SFT/sft_configs/sft_ds3.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 0, 7 | "reduce_bucket_size": "auto", 8 | "stage3_prefetch_bucket_size": "auto", 9 | "stage3_param_persistence_threshold": "auto", 10 | "stage3_max_live_parameters": 0, 11 | "stage3_max_reuse_distance": 0, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "bf16": { 15 | "enabled": true 16 | }, 17 | "fp16": { 18 | "enabled": "auto", 19 | "auto_cast": false, 20 | "loss_scale": 0, 21 | "initial_scale_power": 32, 22 | "loss_scale_window": 1000, 23 | "hysteresis": 2, 24 | "min_loss_scale": 1 25 | }, 26 | "gradient_accumulation_steps": "auto", 27 | "gradient_clipping": "auto", 28 | "train_batch_size": "auto", 29 | "train_micro_batch_size_per_gpu": "auto", 30 | "wall_clock_breakdown": false 31 | } 32 | -------------------------------------------------------------------------------- /alignment_algorithms/README.md: -------------------------------------------------------------------------------- 1 | # DPO/KTO Training with Multi-turn Data 2 | 3 | The implementations of DPO and KTO are adapted from the open-source packages [TRL](https://github.com/huggingface/trl) and [RLHFlow](https://github.com/RLHFlow/Online-RLHF). Compared to the original DPO/KTO, we only need to modify the mask of the samples to mask out all the external tokens. The current implementation supports Gemma and Mistral models. You can read the "get_new_mask" function in dpo_trainer or kto_trainer to get the idea and easily implement it for other LLMs. 4 | 5 | 6 | ## 1 Installation instructions 7 | 8 | **Note that the numpy version should be `numpy<2.0`. 
`Numpy 2.0` will cause unexpected issues!!!** 9 | 10 | 11 | Before starting, please make sure your Linux machine has nvidia-cuda-toolkit installed. See the SFT part for guidance. 12 | 13 | 14 | **Training Environment** 15 | 16 | ```sh 17 | conda create -n alignment_train python=3.10.9 18 | conda activate alignment_train 19 | 20 | git clone https://github.com/huggingface/alignment-handbook.git 21 | cd ./alignment-handbook/ 22 | git checkout d17fd7cd3b71c6a7bf7af34d8dc73135bb7ea8e9 23 | pip3 install torch==2.1.2 torchvision torchaudio 24 | python -m pip install . 25 | pip install flash-attn==2.6.3 26 | pip install accelerate==0.33.0 27 | 28 | pip install huggingface-hub==0.24.7 29 | pip install wandb 30 | wandb login 31 | huggingface-cli login 32 | ``` 33 | 34 | ## 2 Hacking the DPO Trainer and KTO Trainer 35 | 36 | ### 2.1 Hack DPO Trainer 37 | 38 | The code is based on RLHFlow/Online-RLHF but we need to hack the trainer to implement some additional functions. We highlight the modified parts with ############## MODIFICATION. 39 | 40 | ```sh 41 | # Step 1: find the original DPO trainer 42 | cd anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/ 43 | 44 | # Step 2: delete the old one 45 | rm dpo_trainer.py 46 | 47 | # Step 3: use the modified one in this repo. The following command needs to be modified to use the correct path 48 | mv dpo_train/dpo_trainer.py anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py 49 | ``` 50 | 51 | ### 2.2 Hack KTO Trainer 52 | 53 | The code is based on RLHFlow/Online-RLHF but we need to hack the KTO trainer to implement some additional functions. We highlight the modified parts with ############## MODIFICATION. 54 | 55 | ```sh 56 | # Step 1: find the original KTO trainer 57 | cd anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/ 58 | 59 | # Step 2: delete the old one 60 | rm kto_trainer.py 61 | 62 | # Step 3: use the modified one in this repo. 
The following command needs to be modified to use the correct path 63 | mv kto_train/kto_trainer.py anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/kto_trainer.py 64 | 65 | # Step 4: modify the KTO config according to your GPU resources, for example by setting the following fields. 66 | vim ./trl/trainer/kto_config.py 67 | # max_length: Optional[int] = 2048 68 | # max_prompt_length: Optional[int] = 1024 69 | # max_completion_length: Optional[int] = 2048 70 | ``` 71 | 72 | ### 2.3 Fix Import Error 73 | 74 | For transformers > 4.38.2, you will encounter an import issue related to the following function in anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/core.py. You can comment out the import from transformers and paste the following hard-coded version into core.py. Note that core.py must then also import `TopKLogitsWarper` and `TopPLogitsWarper` from transformers, which this function uses. 75 | 76 | ```python 77 | def top_k_top_p_filtering( 78 | logits: torch.FloatTensor, 79 | top_k: int = 0, 80 | top_p: float = 1.0, 81 | filter_value: float = -float("Inf"), 82 | min_tokens_to_keep: int = 1, 83 | ) -> torch.FloatTensor: 84 | """ 85 | Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. 86 | 87 | Args: 88 | logits: logits distribution shape (batch size, vocabulary size) 89 | top_k (`int`, *optional*, defaults to 0): 90 | If > 0, only keep the top k tokens with highest probability (top-k filtering) 91 | top_p (`float`, *optional*, defaults to 1.0): 92 | If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus 93 | filtering is described in Holtzman et al. (https://huggingface.co/papers/1904.09751) 94 | min_tokens_to_keep (`int`, *optional*, defaults to 1): 95 | Minimum number of tokens we keep per batch example in the output. 
96 | 97 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 98 | """ 99 | 100 | if top_k > 0: 101 | logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( 102 | None, logits 103 | ) 104 | 105 | if 0 <= top_p <= 1.0: 106 | logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( 107 | None, logits 108 | ) 109 | 110 | return logits 111 | ``` 112 | 113 | 114 | ## 3 Running the Code 115 | 116 | ### 3.1 DPO 117 | Before running the code, modify `num_processes: 8` in ./training_configs/zero2_pf.yaml; the number 8 means that you will use 8 GPUs. Also modify the parameters, models, and datasets provided in run_dpo.py. 118 | 119 | ```shell 120 | accelerate launch --config_file ./training_configs/zero2_pf.yaml run_dpo.py ./training_configs/training.yaml 121 | 122 | ``` 123 | 124 | ### 3.2 KTO 125 | 126 | ```shell 127 | bash run_kto.sh 128 | ``` 129 | 130 | If you encounter an out-of-memory issue, run the code with Gemma-7b-it using zero3_pf.yaml instead. You can also reduce the max length of the data. 
131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /alignment_algorithms/dpo.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union 3 | import copy 4 | import torch 5 | import torch.nn.functional as F 6 | from datasets import Dataset 7 | # from peft import AutoPeftModelForCausalLM, LoraConfig 8 | from torch import nn 9 | from torch.nn.utils.rnn import pad_sequence 10 | from transformers import ( 11 | DataCollator, 12 | PreTrainedModel, 13 | PreTrainedTokenizerBase, 14 | TrainerCallback, 15 | TrainingArguments, 16 | ) 17 | from transformers.trainer_callback import TrainerCallback 18 | from transformers.trainer_utils import EvalLoopOutput 19 | from trl import DPOTrainer 20 | 21 | # Define and parse arguments. 22 | 23 | ############## MODIFICATION 24 | def get_new_mask(input_ids, old_labels, model='gemma'): 25 | # Mask the user turns (token ids are lists) to create new labels; supports Gemma and Mistral 26 | labels = copy.deepcopy(old_labels) 27 | start = False 28 | if 'gemma' in model.lower(): 29 | for j in range(len(input_ids)): 30 | if input_ids[j:j+3] == [106, 1645, 108]: 31 | start = True 32 | labels[j:j+3] = [-100] * 3 # slice assignment on a list needs an iterable 33 | if input_ids[j:j+2] == [107, 108] and start: 34 | labels[j] = -100 35 | labels[j+1] = -100 36 | start = False 37 | if start: 38 | labels[j] = -100 39 | elif 'mistral' in model.lower(): 40 | for j in range(len(input_ids)): 41 | if input_ids[j] == 3: 42 | start = True 43 | labels[j] = -100 # mask the label; do not overwrite input_ids 44 | if input_ids[j] == 4 and start: 45 | labels[j] = -100 46 | start = False 47 | if start: 48 | labels[j] = -100 49 | else: 50 | raise NotImplementedError(model) 51 | return labels 52 | ############## MODIFICATION END 53 | 54 | @dataclass 55 | class PreferenceDataCollatorWithPadding: 56 | tokenizer: PreTrainedTokenizerBase 57 | model: Optional[PreTrainedModel] = None 58 | padding: 
Union[bool, str] = True 59 | max_length: Optional[int] = None 60 | max_prompt_length: Optional[int] = None 61 | label_pad_token_id: int = -100 62 | padding_value: int = 0 63 | truncation_mode: str = "keep_end" 64 | is_encoder_decoder: Optional[bool] = False 65 | max_target_length: Optional[int] = None 66 | mask_prompt: Optional[bool] = False 67 | mask_user_turn: Optional[bool] = False 68 | model_name: Optional[str] = 'gemma' 69 | 70 | def tokenize_batch_element( 71 | self, 72 | prompt: str, 73 | chosen: str, 74 | rejected: str, 75 | ) -> Dict: 76 | """Tokenize a single batch element. 77 | 78 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 79 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 80 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 81 | 82 | We also create the labels for the chosen/rejected responses, which are of length equal to 83 | the sum of the length of the prompt and the chosen/rejected response, with 84 | label_pad_token_id for the prompt tokens. 
85 | """ 86 | batch = {} 87 | 88 | if not self.is_encoder_decoder: 89 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) 90 | rejected_tokens = self.tokenizer(rejected, add_special_tokens=False) 91 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) 92 | 93 | eos_token_id = self.tokenizer.eos_token_id 94 | # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0) 95 | eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id] 96 | # attention mask these indices to eos_token_id 97 | if self.mask_prompt: 98 | new_attention_mask = [0 for i, p in enumerate(prompt_tokens["attention_mask"])] 99 | print("I mask the prompt") 100 | else: 101 | new_attention_mask = [ 102 | 0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"]) 103 | ] 104 | prompt_tokens["attention_mask"] = new_attention_mask 105 | 106 | # do the same for chosen and rejected 107 | eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id] 108 | new_attention_mask_c = [ 109 | 0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"]) 110 | ] 111 | chosen_tokens["attention_mask"] = new_attention_mask_c 112 | 113 | eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id] 114 | new_attention_mask_r = [ 115 | 0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"]) 116 | ] 117 | rejected_tokens["attention_mask"] = new_attention_mask_r 118 | 119 | # add EOS token to end of prompt 120 | 121 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) 122 | chosen_tokens["attention_mask"].append(1) 123 | 124 | rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) 125 | rejected_tokens["attention_mask"].append(1) 126 | 127 | longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) 128 | 
129 | # if combined sequence is too long, truncate the prompt 130 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 131 | if self.truncation_mode == "keep_start": 132 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} 133 | elif self.truncation_mode == "keep_end": 134 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} 135 | else: 136 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") 137 | 138 | # if that's still too long, truncate the response 139 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 140 | chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} 141 | rejected_tokens = { 142 | k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items() 143 | } 144 | 145 | # Create labels 146 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} 147 | rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} 148 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] 149 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 150 | prompt_tokens["input_ids"] 151 | ) 152 | rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] 153 | rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 154 | prompt_tokens["input_ids"] 155 | ) 156 | ############ MODIFICATION 157 | if self.mask_user_turn: 158 | new_chosen_sequence_labels= get_new_mask(chosen_sequence_tokens['input_ids'], chosen_sequence_tokens['labels'], model=self.model_name) 159 | new_rej_sequence_labels = get_new_mask(rejected_sequence_tokens['input_ids'], rejected_sequence_tokens['labels'], model=self.model_name) 160 | chosen_sequence_tokens["labels"] = new_chosen_sequence_labels 161 | 
rejected_sequence_tokens["labels"] = new_rej_sequence_labels 162 | ############ 163 | for k, toks in { 164 | "chosen": chosen_sequence_tokens, 165 | "rejected": rejected_sequence_tokens, 166 | "prompt": prompt_tokens, 167 | }.items(): 168 | for type_key, tokens in toks.items(): 169 | if type_key == "token_type_ids": 170 | continue 171 | batch[f"{k}_{type_key}"] = tokens 172 | 173 | else: 174 | raise NotImplementedError 175 | 176 | batch["prompt"] = prompt 177 | batch["chosen"] = prompt + chosen 178 | batch["rejected"] = prompt + rejected 179 | batch["chosen_response_only"] = chosen 180 | batch["rejected_response_only"] = rejected 181 | return batch 182 | 183 | def collate(self, batch): 184 | # first, pad everything to the same length 185 | padded_batch = {} 186 | for k in batch[0].keys(): 187 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): 188 | if self.is_encoder_decoder: 189 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 190 | 191 | if (k.startswith("prompt")) and (k.endswith("input_ids")): 192 | padding_value = self.tokenizer.pad_token_id 193 | elif k.endswith("_attention_mask"): 194 | padding_value = 0 195 | elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): 196 | padding_value = self.label_pad_token_id 197 | else: 198 | raise ValueError(f"Unexpected key in batch '{k}'") 199 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 200 | else: 201 | # adapted from https://stackoverflow.com/questions/73256206 202 | if "prompt" in k: 203 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 204 | else: 205 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 206 | if k.endswith("_input_ids"): 207 | padding_value = self.tokenizer.pad_token_id 208 | elif k.endswith("_labels"): 209 | padding_value = self.label_pad_token_id 210 | elif k.endswith("_attention_mask"): 211 | padding_value = self.padding_value 212 | else: 213 | raise ValueError(f"Unexpected key 
in batch '{k}'") 214 | 215 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 216 | # for the prompt, flip back so padding is on left side 217 | if "prompt" in k: 218 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 219 | else: 220 | padded_batch[k] = [ex[k] for ex in batch] 221 | 222 | return padded_batch 223 | 224 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 225 | tokenized_batch = [] 226 | 227 | for feature in features: 228 | prompt = feature["prompt"] 229 | chosen = feature["chosen"] 230 | rejected = feature["rejected"] 231 | 232 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected) 233 | batch_element["margin"] = feature["margin"] 234 | tokenized_batch.append(batch_element) 235 | 236 | # return collated batch 237 | return self.collate(tokenized_batch) 238 | 239 | 240 | class PreferenceTrainer(DPOTrainer): 241 | def __init__( 242 | self, 243 | model: Union[PreTrainedModel, nn.Module] = None, 244 | ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, 245 | beta: float = 0.1, 246 | loss_type: Literal["sigmoid", "hinge", "cross_entropy", "kl", "rev_kl", "raft"] = "rev_kl", 247 | args: TrainingArguments = None, 248 | data_collator: Optional[DataCollator] = None, 249 | label_pad_token_id: int = -100, 250 | padding_value: int = 0, 251 | truncation_mode: str = "keep_end", 252 | train_dataset: Optional[Dataset] = None, 253 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 254 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 255 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 256 | callbacks: Optional[List[TrainerCallback]] = None, 257 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 258 | None, 259 | None, 260 | ), 261 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 262 | max_length: Optional[int] = None, 263 | max_prompt_length: Optional[int] = None, 
264 | max_target_length: Optional[int] = None, 265 | peft_config: Optional[Dict] = None, 266 | is_encoder_decoder: Optional[bool] = None, 267 | disable_dropout: bool = True, 268 | generate_during_eval: bool = False, 269 | compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, 270 | mask_prompt: Optional[bool] = False, 271 | len_penalty: float = 0, 272 | nll_coefficient: float = 0, 273 | masking_user_turn: bool = True, 274 | ): 275 | ################# MODIFICATION 276 | # if nll_coefficient > 0, then dpo_loss + self.nll_coefficient * nll_loss on preferred response 277 | self.nll_coefficient = nll_coefficient 278 | # whether we mask the user turn or not for implementing m-dpo, if not, it is a regular single-turn DPO 279 | self.masking_user_turn = masking_user_turn 280 | ############## MODIFICATION END 281 | 282 | 283 | if data_collator is None: 284 | data_collator = PreferenceDataCollatorWithPadding( 285 | tokenizer, 286 | max_length=max_length, 287 | max_prompt_length=max_prompt_length, 288 | label_pad_token_id=label_pad_token_id, 289 | padding_value=padding_value, 290 | truncation_mode=truncation_mode, 291 | is_encoder_decoder=False, 292 | max_target_length=max_target_length, 293 | mask_prompt=mask_prompt, 294 | mask_user_turn = self.masking_user_turn, 295 | ) 296 | super().__init__( 297 | model=model, 298 | ref_model=ref_model, 299 | beta=beta, 300 | loss_type=loss_type, 301 | args=args, 302 | data_collator=data_collator, 303 | label_pad_token_id=label_pad_token_id, 304 | padding_value=padding_value, 305 | truncation_mode=truncation_mode, 306 | train_dataset=train_dataset, 307 | eval_dataset=eval_dataset, 308 | tokenizer=tokenizer, 309 | model_init=model_init, 310 | callbacks=callbacks, 311 | optimizers=optimizers, 312 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 313 | max_length=max_length, 314 | max_prompt_length=max_prompt_length, 315 | max_target_length=max_target_length, 316 | peft_config=peft_config, 317 | 
is_encoder_decoder=is_encoder_decoder, 318 | disable_dropout=disable_dropout, 319 | generate_during_eval=generate_during_eval, 320 | compute_metrics=compute_metrics, 321 | ) 322 | self.use_dpo_data_collator = True 323 | self.len_penalty = len_penalty 324 | 325 | def dpo_loss( 326 | self, 327 | policy_chosen_logps: torch.FloatTensor, 328 | policy_rejected_logps: torch.FloatTensor, 329 | reference_chosen_logps: torch.FloatTensor, 330 | reference_rejected_logps: torch.FloatTensor, 331 | reference_free: bool = False, 332 | margin: Optional[torch.FloatTensor] = None, 333 | len_penalty: float = 0, 334 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 335 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 336 | 337 | Args: 338 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 339 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 340 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 341 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 342 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 343 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 344 | 345 | Returns: 346 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 347 | The losses tensor contains the DPO loss for each example in the batch. 348 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 
349 | """ 350 | pi_logratios = policy_chosen_logps - policy_rejected_logps 351 | ref_logratios = reference_chosen_logps - reference_rejected_logps + len_penalty 352 | 353 | if reference_free: 354 | ref_logratios = 0 355 | 356 | if self.loss_type == "sigmoid": 357 | logits = pi_logratios - ref_logratios 358 | losses = -F.logsigmoid(self.beta * logits) 359 | elif self.loss_type == "hinge": 360 | logits = pi_logratios - ref_logratios 361 | losses = torch.relu(1 - self.beta * logits) 362 | elif self.loss_type == "cross_entropy": 363 | logits = policy_chosen_logps - reference_chosen_logps 364 | losses = -F.logsigmoid(self.beta * logits) 365 | elif self.loss_type == "raft": 366 | losses = -policy_chosen_logps # F.logsigmoid(self.beta * logits) 367 | elif self.loss_type == "ipo": 368 | logits = pi_logratios - ref_logratios 369 | # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. 370 | losses = (logits - 1 / (2 * self.beta)) ** 2 371 | elif self.loss_type == "kl": 372 | logits = pi_logratios - ref_logratios 373 | p = F.sigmoid(self.beta * logits) 374 | p = torch.minimum(p, torch.ones_like(p) * 0.999) 375 | p_gt = torch.exp(margin) / (1 + torch.exp(margin) + 1e-3) 376 | losses = p * (torch.log(p) - torch.log(p_gt)) + (1 - p) * (torch.log(1 - p) - torch.log(1 - p_gt)) 377 | elif self.loss_type == "tv": 378 | logits = pi_logratios - ref_logratios 379 | p = F.sigmoid(self.beta * logits) 380 | p_gt = torch.exp(margin) / (1 + torch.exp(margin)) 381 | losses = torch.abs(p - p_gt) 382 | elif self.loss_type == "hellinger": 383 | logits = pi_logratios - ref_logratios 384 | p = F.sigmoid(self.beta * logits) 385 | p = torch.minimum(p, torch.ones_like(p) * 0.999) 386 | p_gt = torch.exp(margin) / (1 + torch.exp(margin)) 387 | losses = 0.5 * ((p**0.5 - p_gt**0.5) ** 2 + ((1 - p) ** 0.5 - (1 - p_gt) ** 0.5) ** 2) 388 | elif self.loss_type == "rev_kl": 389 | logits = pi_logratios - ref_logratios 390 | logp = 
F.logsigmoid(self.beta * logits) 391 | logp_neg = F.logsigmoid(-self.beta * logits) 392 | p_gt = F.sigmoid(margin) 393 | losses = -p_gt * (logp) - (1 - p_gt) * logp_neg 394 | else: 395 | raise ValueError(f"Unknown loss type: {self.loss_type}.") 396 | 397 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 398 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() 399 | 400 | return losses, chosen_rewards, rejected_rewards 401 | 402 | def get_batch_loss_metrics( 403 | self, 404 | model, 405 | batch: Dict[str, Union[List, torch.LongTensor]], 406 | train_eval: Literal["train", "eval"] = "train", 407 | ): 408 | return self.get_batch_metrics(model, batch, train_eval) 409 | 410 | def get_batch_metrics( 411 | self, 412 | model, 413 | batch: Dict[str, Union[List, torch.LongTensor]], 414 | train_eval: Literal["train", "eval"] = "train", 415 | ): 416 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" 417 | metrics = {} 418 | ( 419 | policy_chosen_logps, 420 | policy_rejected_logps, 421 | policy_chosen_logits, 422 | policy_rejected_logits, 423 | policy_nll_loss, 424 | ) = self.concatenated_forward(model, batch) 425 | with torch.no_grad(): 426 | if self.ref_model is None: 427 | with self.accelerator.unwrap_model(self.model).disable_adapter(): 428 | ( 429 | reference_chosen_logps, 430 | reference_rejected_logps, 431 | _, 432 | _, 433 | _, 434 | ) = self.concatenated_forward(self.model, batch) 435 | else: 436 | ( 437 | reference_chosen_logps, 438 | reference_rejected_logps, 439 | _, 440 | _, 441 | _, 442 | ) = self.concatenated_forward(self.ref_model, batch) 443 | if self.len_penalty > 0: 444 | chosen_len = batch["chosen_input_ids"].shape[1] * self.len_penalty 445 | rejected_len = batch["rejected_input_ids"].shape[1] * self.len_penalty 446 | len_penalty = chosen_len - rejected_len 447 | else: 448 | chosen_len = 1 449 | rejected_len = 1 450 | len_penalty = 0 451 | 
452 | margin = torch.tensor(batch["margin"], dtype=policy_chosen_logps.dtype).to(self.accelerator.device) 453 | losses, chosen_rewards, rejected_rewards = self.dpo_loss( 454 | policy_chosen_logps, 455 | policy_rejected_logps, 456 | reference_chosen_logps, 457 | reference_rejected_logps, 458 | margin=margin, 459 | len_penalty=len_penalty, 460 | ) 461 | ############## MODIFICATION 462 | if self.nll_coefficient > 0: 463 | losses = losses + self.nll_coefficient * policy_nll_loss 464 | else: 465 | pass 466 | ############## MODIFICATION END 467 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 468 | 469 | prefix = "eval_" if train_eval == "eval" else "" 470 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() 471 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() 472 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() 473 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() 474 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() 475 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() 476 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() 477 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() 478 | ######## 479 | metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().cpu().mean() 480 | ######## 481 | return losses.mean(), metrics 482 | -------------------------------------------------------------------------------- /alignment_algorithms/run_dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import torch 6 | from datasets import Dataset, load_dataset 7 | from dpo import PreferenceTrainer 8 | from transformers import ( 9 | AutoModelForCausalLM, 10 | AutoTokenizer, 11 | HfArgumentParser, 12 | TrainingArguments, 13 
| ) 14 | 15 | 16 | @dataclass 17 | class ScriptArguments: 18 | """ 19 | The arguments for the DPO training script. 20 | """ 21 | 22 | # data parameters, i.e., the KL penalty in the paper 23 | beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) 24 | 25 | # training parameters 26 | model_name_or_path: Optional[str] = field( 27 | default="RLHF4MATH/Gemma-7B-it-SFT3epoch", 28 | metadata={"help": "the location of the model name or path"}, 29 | ) 30 | ref_model: Optional[str] = field( 31 | default="RLHF4MATH/Gemma-7B-it-SFT3epoch", 32 | metadata={"help": "the location of the SFT model name or path"}, 33 | ) 34 | train_dir: Optional[str] = field( 35 | default="RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs", 36 | metadata={"help": "the location of the dataset name or path"}, 37 | ) 38 | eval_dir: Optional[str] = field( 39 | default="RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs", 40 | metadata={"help": "the location of the evalset name or path"}, 41 | ) 42 | learning_rate: Optional[float] = field(default=4e-7, metadata={"help": "optimizer learning rate"}) 43 | lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) 44 | warmup_steps: Optional[int] = field(default=50, metadata={"help": "the number of warmup steps"}) 45 | weight_decay: Optional[float] = field(default=0.01, metadata={"help": "the weight decay"}) 46 | optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) 47 | 48 | per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"}) 49 | per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) 50 | gradient_accumulation_steps: Optional[int] = field( 51 | default=4, metadata={"help": "the number of gradient accumulation steps"} 52 | ) 53 | gradient_checkpointing: Optional[bool] = field( 54 | default=True, metadata={"help": 
"whether to use gradient checkpointing"} 55 | ) 56 | 57 | eos_padding: Optional[bool] = field(default=True, metadata={"help": "whether to pad with eos token"}) 58 | lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) 59 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) 60 | lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) 61 | 62 | margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"}) 63 | 64 | max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"}) 65 | max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length"}) 66 | max_steps: Optional[int] = field(default=4000, metadata={"help": "max number of training steps"}) 67 | num_train_epochs: Optional[int] = field(default=2, metadata={"help": "max number of training epochs"}) 68 | logging_steps: Optional[int] = field(default=2, metadata={"help": "the logging frequency"}) 69 | save_strategy: Optional[str] = field(default="steps", metadata={"help": "the saving strategy"}) 70 | save_steps: Optional[int] = field(default=25, metadata={"help": "the saving frequency"}) 71 | eval_steps: Optional[int] = field(default=300, metadata={"help": "the evaluation frequency"}) 72 | run_name: Optional[str] = field(default="mdpo_iter1_gemma7b_lr4e7_bz32", metadata={"help": "the run name"}) 73 | loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"}) 74 | output_dir: Optional[str] = field( 75 | default="./mdpo_iter1_gemma7b_lr4e7_bz32", metadata={"help": "the output directory"} 76 | ) 77 | log_freq: Optional[int] = field(default=2, metadata={"help": "the logging frequency"}) 78 | 79 | # instrumentation 80 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) 81 | 82 | max_training_samples: Optional[int] = 
field(default=-1, metadata={"help": "the maximum sample size"}) 83 | 84 | choose_type: Optional[str] = field(default="max_random", metadata={"help": "the choose type"}) 85 | 86 | report_to: Optional[str] = field( 87 | default="wandb", 88 | metadata={ 89 | "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' 90 | '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' 91 | 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' 92 | }, 93 | ) 94 | # debug argument for distributed training 95 | ignore_bias_buffers: Optional[bool] = field( 96 | default=False, 97 | metadata={ 98 | "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See " 99 | "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" 100 | }, 101 | ) 102 | eot_token: Optional[str] = field(default="", metadata={"help": "the end of text token"}) 103 | mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"}) 104 | len_penalty: Optional[float] = field(default=0, metadata={"help": "the length penalty"}) 105 | 106 | masking_user_turn: Optional[bool] = field(default=True, metadata={"help": "mask user turn"}) 107 | nll_coefficient: Optional[float] = field(default=0, metadata={"help": "the coefficient of NLL loss"}) 108 | 109 | 110 | def prepare_data( 111 | tokenizer, 112 | data_dir: str = "xxx", 113 | sanity_check: bool = False, 114 | cache_dir: str = None, 115 | num_proc=24, 116 | margin_scale=1, 117 | choose_type="random", 118 | eot_token="", 119 | length_penalty=0, 120 | ) -> Dataset: 121 | """Prepare the dataset for DPO training. The input dataset is expected to be in the standard format with keys `chosen` and `rejected`, each holding a list of chat messages like the example below. 122 | The margin is currently unused (a constant placeholder is stored) and may be activated in future research. 
123 | 124 | [ { "content": "If a 40-foot tree is casting a 10-foot shadow, and Andrea is casting a 15-inch shadow at the same time, how tall is Andrea in inches?", "role": "user" }, 125 | { "content": "The shadow of the tree is 10 feet which is 120 inches.\nSo let's set *tree height* = 40 feet = 40 * 12 inches\n*tree shadow* = 10 feet = 120 inches\n*Andrea's shadow* = 15 inches\nFrom the similar triangles, we can find Andrea's height.\n```python\ntree_height = 40 * 12 # tree is 40 feet which is 40 * 12 inches\ntree_shadow = 10 * 12 # tree shadow is 10 feet = 120 inches\nandrea_shadow = 15 # Andrea's shadow is 15 inches\n\n# Find Andrea's height using similar triangles\nandrea_height = andrea_shadow * (tree_height / tree_shadow)\nandrea_height\n```", "role": "assistant" }, 126 | { "content": "```output\n60.0\n```", "role": "user" }, 127 | { "content": "So Andrea is $\\boxed{60}$ inches tall.", "role": "assistant" } ] 128 | """ 129 | ds = load_dataset(data_dir, split="train") 130 | ds = ds.shuffle(seed=42) 131 | print(ds) 132 | 133 | pos = [] 134 | neg = [] 135 | prompts = [] 136 | margin = [] 137 | for sample in ds: 138 | chosen = sample["chosen"] 139 | rejected = sample["rejected"] 140 | prompt = tokenizer.apply_chat_template([chosen[0]], tokenize=False, add_generation_prompt=True) 141 | prompt2 = tokenizer.apply_chat_template([rejected[0]], tokenize=False, add_generation_prompt=True) 142 | if prompt != prompt2: 143 | continue 144 | 145 | # assert prompt == prompt2 146 | chosen_str = tokenizer.apply_chat_template(chosen, tokenize=False).replace(prompt, "") 147 | rejected_str = tokenizer.apply_chat_template(rejected, tokenize=False).replace(prompt, "") 148 | prompts.append(prompt) 149 | pos.append(chosen_str) 150 | neg.append(rejected_str) 151 | margin.append(0.5) # not used so far 152 | dataset = Dataset.from_dict({"prompt": prompts, "chosen": pos, "rejected": neg, "margin": margin}) 153 | if sanity_check: 154 | dataset = dataset.select(range(min(len(dataset), 
100))) 155 | 156 | return dataset 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = HfArgumentParser(ScriptArguments) 161 | script_args = parser.parse_args_into_dataclasses()[0] 162 | 163 | # 1. load a pretrained model 164 | model = AutoModelForCausalLM.from_pretrained( 165 | script_args.model_name_or_path, 166 | use_flash_attention_2=True, 167 | torch_dtype=torch.float16, 168 | ) 169 | model.config.use_cache = False 170 | 171 | if script_args.ignore_bias_buffers: 172 | # torch distributed hack 173 | model._ddp_params_and_buffers_to_ignore = [ 174 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool 175 | ] 176 | 177 | if script_args.ref_model: 178 | ref_name = script_args.ref_model 179 | else: 180 | ref_name = script_args.model_name_or_path 181 | 182 | model_ref = AutoModelForCausalLM.from_pretrained( 183 | ref_name, 184 | torch_dtype=torch.bfloat16, 185 | use_flash_attention_2=True, 186 | ) 187 | tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path) 188 | 189 | # 2. Load the paired dataset 190 | train_dataset = prepare_data( 191 | tokenizer, 192 | data_dir=script_args.train_dir, 193 | margin_scale=script_args.margin_scale, 194 | sanity_check=script_args.sanity_check, 195 | choose_type=script_args.choose_type, 196 | eot_token=script_args.eot_token, 197 | length_penalty=script_args.len_penalty, 198 | ) 199 | print(train_dataset) 200 | print(train_dataset[0]) 201 | if script_args.max_training_samples > 0: 202 | train_dataset = train_dataset.select(range(script_args.max_training_samples)) 203 | 204 | # 3. Load evaluation dataset 205 | eval_dataset = prepare_data( 206 | tokenizer, 207 | data_dir=script_args.eval_dir, 208 | sanity_check=True, 209 | margin_scale=script_args.margin_scale, 210 | eot_token=script_args.eot_token, 211 | ) 212 | 213 | # 4. 
initialize training arguments: 214 | 215 | training_args = TrainingArguments( 216 | per_device_train_batch_size=script_args.per_device_train_batch_size, 217 | per_device_eval_batch_size=script_args.per_device_eval_batch_size, 218 | # max_steps=script_args.max_steps, 219 | num_train_epochs=script_args.num_train_epochs, 220 | save_strategy=script_args.save_strategy, 221 | logging_steps=script_args.logging_steps, 222 | save_steps=script_args.save_steps, 223 | gradient_accumulation_steps=script_args.gradient_accumulation_steps, 224 | gradient_checkpointing=script_args.gradient_checkpointing, 225 | learning_rate=script_args.learning_rate, 226 | evaluation_strategy="steps", 227 | eval_steps=script_args.eval_steps, 228 | output_dir=script_args.output_dir, 229 | # report_to=script_args.report_to, 230 | lr_scheduler_type=script_args.lr_scheduler_type, 231 | warmup_steps=script_args.warmup_steps, 232 | # optim=script_args.optimizer_type, 233 | bf16=True, 234 | remove_unused_columns=False, 235 | run_name=script_args.run_name, 236 | save_only_model=True, 237 | ) 238 | print(training_args) 239 | 240 | # 5. initialize the DPO trainer 241 | 242 | dpo_trainer = PreferenceTrainer( 243 | model, 244 | model_ref, 245 | args=training_args, 246 | beta=script_args.beta, 247 | train_dataset=train_dataset, 248 | eval_dataset=eval_dataset, 249 | tokenizer=tokenizer, 250 | loss_type=script_args.loss_type, 251 | max_prompt_length=script_args.max_prompt_length, 252 | max_length=script_args.max_length, 253 | mask_prompt=script_args.mask_prompt, 254 | len_penalty=script_args.len_penalty, 255 | nll_coefficient=script_args.nll_coefficient, 256 | masking_user_turn=script_args.masking_user_turn, 257 | ) 258 | print("begin to train") 259 | 260 | # 6. train 261 | dpo_trainer.train() 262 | dpo_trainer.save_model(script_args.output_dir) 263 | 264 | # 7. 
save 265 | output_dir = os.path.join(script_args.output_dir, "final_checkpoint") 266 | dpo_trainer.model.save_pretrained(output_dir) 267 | -------------------------------------------------------------------------------- /alignment_algorithms/run_dpo.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file ./training_configs/zero2_pf.yaml run_dpo.py \ 2 | --model_name_or_path RLHF4MATH/Gemma-7B-it-SFT3epoch \ 3 | --ref_model RLHF4MATH/Gemma-7B-it-SFT3epoch \ 4 | --per_device_train_batch_size 1 \ 5 | --num_train_epochs 1 \ 6 | --train_dir RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs \ 7 | --eval_dir RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs \ 8 | --learning_rate 2e-7 \ 9 | --lr_scheduler_type=cosine \ 10 | --gradient_accumulation_steps 4 \ 11 | --logging_steps 2 \ 12 | --eval_steps 10000 \ 13 | --output_dir=./mdpo_iter1_gemma7b_lr2e7_bz32 \ 14 | --warmup_ratio 0.1 \ 15 | --report_to wandb \ 16 | --bf16 \ 17 | --save_strategy=steps \ 18 | --save_steps=50 \ 19 | -------------------------------------------------------------------------------- /alignment_algorithms/run_kto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run the KTO training script with the commands below. 
In general, the optimal configuration for KTO will be similar to that of DPO. 17 | 18 | # Full training: 19 | python examples/scripts/kto.py \ 20 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 21 | --per_device_train_batch_size 16 \ 22 | --num_train_epochs 1 \ 23 | --learning_rate 1e-5 \ 24 | --lr_scheduler_type=cosine \ 25 | --gradient_accumulation_steps 1 \ 26 | --logging_steps 10 \ 27 | --eval_steps 500 \ 28 | --output_dir=kto-aligned-model \ 29 | --warmup_ratio 0.1 \ 30 | --report_to wandb \ 31 | --bf16 \ 32 | --logging_first_step 33 | 34 | # QLoRA: 35 | python examples/scripts/kto.py \ 36 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 37 | --per_device_train_batch_size 8 \ 38 | --num_train_epochs 1 \ 39 | --learning_rate 1e-4 \ 40 | --lr_scheduler_type=cosine \ 41 | --gradient_accumulation_steps 1 \ 42 | --logging_steps 10 \ 43 | --eval_steps 500 \ 44 | --output_dir=kto-aligned-model-lora \ 45 | --warmup_ratio 0.1 \ 46 | --report_to wandb \ 47 | --bf16 \ 48 | --logging_first_step \ 49 | --use_peft \ 50 | --load_in_4bit \ 51 | --lora_target_modules=all-linear \ 52 | --lora_r=16 \ 53 | --lora_alpha=16 54 | """ 55 | 56 | from dataclasses import dataclass 57 | 58 | from datasets import load_dataset 59 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 60 | 61 | from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format 62 | 63 | 64 | # Define and parse arguments. 65 | @dataclass 66 | class ScriptArguments: 67 | """ 68 | The arguments for the KTO training script. 
69 | """ 70 | 71 | dataset_name: str = "trl-lib/kto-mix-14k" 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig)) 76 | script_args, kto_args, model_args = parser.parse_args_into_dataclasses() 77 | 78 | # Load a pretrained model 79 | model = AutoModelForCausalLM.from_pretrained( 80 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 81 | ) 82 | ref_model = AutoModelForCausalLM.from_pretrained( 83 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 84 | ) 85 | 86 | tokenizer = AutoTokenizer.from_pretrained( 87 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 88 | ) 89 | if tokenizer.pad_token is None: 90 | tokenizer.pad_token = tokenizer.eos_token 91 | 92 | # If we are aligning a base model, we use ChatML as the default template 93 | if tokenizer.chat_template is None: 94 | #model, tokenizer = setup_chat_format(model, tokenizer) 95 | print("No chat template is provided. 
Please set up the model appropriately first.") 96 | 97 | # Load the dataset 98 | from datasets import Dataset, concatenate_datasets 99 | 100 | def get_chosen(example): 101 | chosen = example['chosen'] 102 | example['prompt'] = tokenizer.apply_chat_template([chosen[0]], tokenize=False, add_generation_prompt=True) 103 | example['completion'] = tokenizer.apply_chat_template(chosen, tokenize=False).replace(example['prompt'], "") 104 | example['label'] = True 105 | return example 106 | 107 | def get_rej(example): 108 | rejected = example['rejected'] 109 | example['prompt'] = tokenizer.apply_chat_template([rejected[0]], tokenize=False, add_generation_prompt=True) 110 | example['completion'] = tokenizer.apply_chat_template(rejected, tokenize=False).replace(example['prompt'], "") 111 | example['label'] = False 112 | return example 113 | 114 | ds = load_dataset(script_args.dataset_name, split='train') 115 | ds1 = ds.map(get_chosen, num_proc=32) 116 | ds2 = ds.map(get_rej, num_proc=32) 117 | 118 | dataset = concatenate_datasets([ds1, ds2]) 119 | dataset = dataset.shuffle(seed=42) 120 | 121 | # Initialize the KTO trainer 122 | kto_trainer = KTOTrainer( 123 | model, 124 | ref_model, 125 | args=kto_args, 126 | train_dataset=dataset, 127 | eval_dataset=dataset, 128 | tokenizer=tokenizer, 129 | peft_config=get_peft_config(model_args), 130 | ) 131 | 132 | # Train and push the model to the Hub 133 | kto_trainer.train() 134 | kto_trainer.save_model(kto_args.output_dir) 135 | kto_trainer.push_to_hub() 136 | -------------------------------------------------------------------------------- /alignment_algorithms/run_kto.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file ./training_configs/zero2_pf.yaml run_kto.py \ 2 | --model_name_or_path RLHF4MATH/Gemma-7B-it-SFT3epoch \ 3 | --per_device_train_batch_size 1 \ 4 | --num_train_epochs 1 \ 5 | --learning_rate 2e-7 \ 6 | --lr_scheduler_type=cosine \ 7 | 
--gradient_accumulation_steps 4 \ 8 | --logging_steps 2 \ 9 | --eval_steps 1000 \ 10 | --output_dir=./mkto_iter1_gemma7b_lr2e7_bz32 \ 11 | --warmup_ratio 0.1 \ 12 | --report_to wandb \ 13 | --bf16 \ 14 | --logging_first_step \ 15 | --dataset_name=RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs \ 16 | --save_strategy=steps \ 17 | --save_steps=50 \ 18 | --save_only_model=True \ 19 | -------------------------------------------------------------------------------- /alignment_algorithms/training_configs/zero2_pf.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: bf16 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /alignment_algorithms/training_configs/zero3_pf.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | 
-------------------------------------------------------------------------------- /assets/main_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiXiongUST/Building-Math-Agents-with-Multi-Turn-Iterative-Preference-Learning/ae1e6422eac4ad04ef5ab1c2bbffb4b4497c70ab/assets/main_result.png -------------------------------------------------------------------------------- /inference/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeiXiongUST/Building-Math-Agents-with-Multi-Turn-Iterative-Preference-Learning/ae1e6422eac4ad04ef5ab1c2bbffb4b4497c70ab/inference/.DS_Store -------------------------------------------------------------------------------- /inference/README.md: -------------------------------------------------------------------------------- 1 | ## Inference for MATH Problems with an External Python Interpreter 2 | 3 | In this repo, we implement MATH problem solving with an external Python tool, using vLLM to accelerate inference. The codebase is largely based on the ToRA project. 4 | 5 | 6 | ## 1 Installation instructions 7 | 8 | **Note that the NumPy version must satisfy `numpy<2.0`. NumPy 2.0 will cause unexpected issues!** 9 | 10 | 11 | Before starting, please make sure your Linux machine has nvidia-cuda-toolkit installed. See the SFT part for guidance. 
12 | 13 | 14 | ```sh 15 | conda create -n infer python=3.10.9 16 | conda activate infer 17 | # Check `nvcc -V`; if nvcc is missing, you can run the following command 18 | # conda install nvidia/label/cuda-12.2.0::cuda-nvcc 19 | 20 | pip install datasets 21 | 22 | # The following is tested with CUDA 12.0-12.2 and CUDA 12.6 23 | # For llama-3, mistral, gemma-1, 1.1, 2, and deepseek, you can use the following vllm version 24 | pip install vllm==0.5.4 25 | 26 | pip install accelerate==0.33.0 27 | pip install deepspeed==0.14.5 28 | pip install numpy==1.26.4 # The numpy version must satisfy `numpy<2.0`; NumPy 2.0 causes unexpected issues! 29 | 30 | 31 | huggingface-cli login 32 | pip install pebble 33 | pip install timeout_decorator 34 | pip install ipython 35 | pip install sympy==1.12 36 | pip install antlr4-python3-runtime==4.11 # The versions of sympy and antlr4 must not be changed! 37 | ``` 38 | 39 | ## 2 The General Process of Inference 40 | 41 | The current code is implemented specifically for Gemma, Mistral, deepseek, and Llama3 (mainly in terms of the prompt format). You should include gemma, mistral, deepseek, or llama3 in the model name so that the code can select the prompt format. 42 | 43 | 44 | **Step 1** To start with, we format a prompt (problem) in the following Gemma format: 45 | 46 | ```python 47 | prompt = "user\nEvaluate $\\left\\lceil3\\left(6-\\frac12\\right)\\right\\rceil$.\nmodel" 48 | ``` 49 | 50 | **Step 2** The model receives the prompt and generates some reasoning steps and/or writes some code. 51 | 52 | ```python 53 | response = "To evaluate the expression, we will use Python's sympy library.\npython\\nfrom sympy import ceiling, Rational\\n\\n# Evaluate the expression\\nexpression_result = ceiling(3 * (6 - Rational(1, 2)))\\n\\nexpression_result\\n" 54 | ``` 55 | 56 | After getting the response, we first check whether the model has output the final result in the form \\boxed{FINAL RESULT}. 
If not, we then check whether the model has output Python code that needs to be executed, in the form "\`\`\`python SOME CODE \`\`\`". If so, we run the code and get the result. 57 | 58 | In this example, the execution result is "17". Therefore, we obtain the new prompt by adding the observation to the history: 59 | 60 | 61 | ```python 62 | prompt2 = "user\nEvaluate $\\left\\lceil3\\left(6-\\frac12\\right)\\right\\rceil$.\nmodel\nTo evaluate the expression, we will use Python's sympy library.\npython\\nfrom sympy import ceiling, Rational\\n\\n# Evaluate the expression\\nexpression_result = ceiling(3 * (6 - Rational(1, 2)))\\n\\nexpression_result\\n\nuser\noutput\\n17\\n\nmodel\n" 63 | ``` 64 | 65 | A new step then begins, and we stop when the model either outputs the final answer or reaches the maximal number of tool calls. 66 | 67 | ## 3 Running the Generation Code 68 | 69 | To run the generation, the first step is to register the model as a server so that we can query it. You should first modify scripts/register_server.sh according to your GPU setup. Then, run the following command with your model. 70 | 71 | ```sh 72 | bash scripts/register_server.sh google/gemma-1.1-7b-it 73 | ``` 74 | 75 | Then, you can run the scripts/infer_eval.sh, which will generate trajectories with the model and evaluate the output. 76 | 77 | ```sh 78 | bash scripts/infer_eval.sh gemma_7b 79 | ``` 80 | 81 | The iter_infer_to_collect_data.sh script additionally uses a for loop to generate trajectories iteratively. The model name should contain mistral, gemma, llama3, or deepseek so that the code can select the prompt format. 
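
The detect-and-execute loop described in Section 2 can be sketched as follows. This is a minimal, model-agnostic sketch under stated assumptions, not the repository's implementation (that lives in infer_data/infer_eval.py). The names `next_action`, `rollout`, `generate`, and `execute` are hypothetical. The model-specific chat-template tokens are omitted. The observation is simply wrapped in an `output` fence, and the default of 6 tool calls mirrors the `--horizon` default.

```python
import re

FENCE = "`" * 3  # triple backtick, built programmatically to keep this block readable

# Hypothetical pattern for a fenced Python program inside a model response.
CODE_BLOCK = re.compile(FENCE + r"python\n(.*?)" + FENCE, re.DOTALL)


def next_action(response):
    """Decide what to do with one model response.

    Returns ("finish", None) if a boxed final answer is present or no code
    was written, and ("execute", code) if there is Python code to run.
    """
    if "\\boxed" in response:
        return ("finish", None)
    match = CODE_BLOCK.search(response)
    if match:
        return ("execute", match.group(1).strip())
    return ("finish", None)


def rollout(generate, execute, prompt, horizon=6):
    """Alternate between generation and tool execution for at most `horizon` calls.

    `generate(history) -> str` and `execute(code) -> str` are supplied by the
    caller (assumptions for illustration, e.g. an HTTP query and a sandbox).
    """
    history = prompt
    for _ in range(horizon):
        response = generate(history)
        history += response
        action, code = next_action(response)
        if action != "execute":
            break
        observation = execute(code)
        # Append the observation to the history before the next turn.
        history += "\n" + FENCE + "output\n" + observation + "\n" + FENCE + "\n"
    return history
```

In practice `generate` would query the registered model server and `execute` would call a sandboxed Python executor; both are passed in as plain callables so the control flow can be tried out with stubs.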
82 | 83 | 84 | ## 4 Annotate Data 85 | 86 | ```sh 87 | python -um infer_data.annotate_data --data_name gsm8k --prompt_type tora --file_path ./collect_data/gemma_7b/gsm8k/train_tora_7473_seed1_t0.0_s0_e7473_09-22_16-22.jsonl --output_dir test_output.jsonl 88 | ``` 89 | -------------------------------------------------------------------------------- /inference/eval/evaluate.py: -------------------------------------------------------------------------------- 1 | # The evaluator is adapted from the ToRA project 2 | # https://github.com/microsoft/ToRA 3 | # ToRA authors: Zhibin Gou and Zhihong Shao and Yeyun Gong and yelong shen and Yujiu Yang and Minlie Huang and Nan Duan and Weizhu Chen 4 | 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | from pebble import ProcessPool 9 | from concurrent.futures import TimeoutError 10 | 11 | from eval.grader import * 12 | from utils.parser import * 13 | from utils.utils import load_jsonl 14 | from utils.python_executor import PythonExecutor 15 | 16 | 17 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False): 18 | assert samples or file_path, "samples or file_path must be provided" 19 | if not samples: 20 | samples = list(load_jsonl(file_path)) 21 | # dedup by idx 22 | if 'idx' in samples[0]: 23 | samples = {sample['idx']: sample for sample in samples}.values() 24 | samples = sorted(samples, key=lambda x: x['idx']) 25 | else: 26 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)] 27 | 28 | if max_num_samples: 29 | print(f"max_num_samples: {max_num_samples} / {len(samples)}") 30 | samples = samples[:max_num_samples] 31 | 32 | # parse gt 33 | for sample in samples: 34 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name) 35 | 36 | # execute 37 | if ('pred' not in samples[0]) or execute: 38 | if "pal" in prompt_type: 39 | executor = PythonExecutor(get_answer_expr="solution()") 40 | else: 41 | executor = 
PythonExecutor(get_answer_from_stdout=True) 42 | 43 | for sample in tqdm(samples, desc="Execute"): 44 | sample['pred'] = [] 45 | sample['report'] = [] 46 | for code in sample['code']: 47 | pred, report = run_execute(executor, code, prompt_type, execute=True) 48 | sample['pred'].append(pred) 49 | sample['report'].append(report) 50 | 51 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']] 52 | 53 | scores = [] 54 | timeout_cnt = 0 55 | 56 | with ProcessPool() as pool: 57 | future = pool.map(math_equal_process, params, timeout=10) 58 | iterator = future.result() 59 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar: 60 | while True: 61 | try: 62 | result = next(iterator) 63 | scores.append(result) 64 | except StopIteration: 65 | break 66 | except TimeoutError as error: 67 | print(error) 68 | scores.append(False) 69 | timeout_cnt += 1 70 | except Exception as error: 71 | print(error.traceback) 72 | exit() 73 | progress_bar.update(1) 74 | 75 | idx = 0 76 | score_mat = [] 77 | for sample in samples: 78 | sample['score'] = scores[idx: idx+len(sample['pred'])] 79 | assert len(sample['score']) == len(sample['pred']) 80 | score_mat.append(sample['score']) 81 | idx += len(sample['pred']) 82 | 83 | max_len = max([len(s) for s in score_mat]) 84 | 85 | for i, s in enumerate(score_mat): 86 | if len(s) < max_len: 87 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad 88 | 89 | # output mean of each column of scores 90 | col_means= np.array(score_mat).mean(axis=0) 91 | mean_score = list(np.round(col_means * 100, decimals=1)) 92 | 93 | result_str = f"Num samples: {len(samples)}\n" \ 94 | f"Num scores: {len(scores)}\n" \ 95 | f"Timeout samples: {timeout_cnt}\n" \ 96 | f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \ 97 | f"Mean score: {mean_score}\n" 98 | 99 | # each type score 100 | if "type" in samples[0]: 101 | type_scores = {} 102 | for sample in samples: 103 | if sample['type'] not in 
type_scores: 104 | type_scores[sample['type']] = [] 105 | type_scores[sample['type']].append(sample['score'][-1]) 106 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()} 107 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])} 108 | result_str += f"Type scores: {type_scores}\n" 109 | 110 | print(result_str) 111 | return result_str 112 | 113 | 114 | def parse_args(): 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--data_name", type=str, default="math") 117 | parser.add_argument("--prompt_type", type=str, default="tora") 118 | parser.add_argument("--file_path", type=str, default=None, required=True) 119 | parser.add_argument("--max_num_samples", type=int, default=None) 120 | parser.add_argument("--execute", action="store_true") 121 | args = parser.parse_args() 122 | return args 123 | 124 | if __name__ == "__main__": 125 | args = parse_args() 126 | evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path, 127 | max_num_samples=args.max_num_samples, execute=args.execute) 128 | -------------------------------------------------------------------------------- /inference/eval/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: 3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC 4 | - https://github.com/openai/prm800k 5 | """ 6 | import multiprocessing 7 | from math import isclose 8 | from typing import Union 9 | 10 | from sympy import simplify, N 11 | from sympy.parsing.sympy_parser import parse_expr 12 | from sympy.parsing.latex import parse_latex 13 | 14 | 15 | def is_digit(s): 16 | try: 17 | float(str(s).replace(",", "")) 18 | return True 19 | except ValueError: 20 | return False 21 | 22 | def math_equal(prediction: Union[bool, float, str], 23 | reference: Union[float, 
str], 24 | include_percentage: bool = True, 25 | is_close: bool = True, 26 | timeout: bool = False, 27 | ) -> bool: 28 | """ 29 | Exact match of math if and only if: 30 | 1. numerical equal: both can convert to float and are equal 31 | 2. symbolic equal: both can convert to sympy expression and are equal 32 | """ 33 | try: # 1. numerical equal 34 | if is_digit(prediction) and is_digit(reference): 35 | prediction = float(str(prediction).replace(",", "")) 36 | reference = float(str(reference).replace(",", "")) 37 | # number questions 38 | if include_percentage: 39 | gt_result = [reference / 100, reference, reference * 100] 40 | else: 41 | gt_result = [reference] 42 | for item in gt_result: 43 | try: 44 | if is_close: 45 | if isclose(item, prediction, rel_tol=1e-4): 46 | return True 47 | else: 48 | if item == prediction: 49 | return True 50 | except Exception: 51 | continue 52 | return False 53 | except: 54 | pass 55 | 56 | if not prediction and prediction not in [0, False]: 57 | return False 58 | 59 | # 2. symbolic equal 60 | reference = str(reference).strip() 61 | prediction = str(prediction).strip() 62 | 63 | ## deal with [], (), {} 64 | pred_str, ref_str = prediction, reference 65 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \ 66 | (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): 67 | pred_str = pred_str.strip("[]()") 68 | ref_str = ref_str.strip("[]()") 69 | for s in ['{', "}", "(", ")"]: 70 | ref_str = ref_str.replace(s, "") 71 | pred_str = pred_str.replace(s, "") 72 | if pred_str == ref_str: 73 | return True 74 | 75 | ## [a, b] vs. 
[c, d], return a==c and b==d 76 | if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \ 77 | (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")): 78 | pred_parts = prediction[1:-1].split(",") 79 | ref_parts = reference[1:-1].split(",") 80 | if len(pred_parts) == len(ref_parts): 81 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 82 | return True 83 | 84 | # symbolic equal with sympy 85 | if timeout: 86 | if call_with_timeout(symbolic_equal_process, prediction, reference): 87 | return True 88 | else: 89 | if symbolic_equal(prediction, reference): 90 | return True 91 | 92 | return False 93 | 94 | 95 | def math_equal_process(param): 96 | return math_equal(param[-2], param[-1]) 97 | 98 | 99 | def symbolic_equal(a, b): 100 | def _parse(s): 101 | for f in [parse_latex, parse_expr]: 102 | try: 103 | return f(s) 104 | except: 105 | pass 106 | return s 107 | a = _parse(a) 108 | b = _parse(b) 109 | 110 | try: 111 | if simplify(a-b) == 0: 112 | return True 113 | except: 114 | pass 115 | 116 | try: 117 | if isclose(N(a), N(b), rel_tol=1e-3): 118 | return True 119 | except: 120 | pass 121 | return False 122 | 123 | 124 | def symbolic_equal_process(a, b, output_queue): 125 | result = symbolic_equal(a, b) 126 | output_queue.put(result) 127 | 128 | 129 | def call_with_timeout(func, *args, timeout=1, **kwargs): 130 | output_queue = multiprocessing.Queue() 131 | process_args = args + (output_queue,) 132 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 133 | process.start() 134 | process.join(timeout) 135 | 136 | if process.is_alive(): 137 | process.terminate() 138 | process.join() 139 | return False 140 | 141 | return output_queue.get() 142 | 143 | 144 | def _test_math_equal(): 145 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}")) 146 | # 
print(math_equal("(1,4.5)", "(1,\\frac{9}{2})")) 147 | print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)) 148 | 149 | if __name__ == "__main__": 150 | _test_math_equal() 151 | -------------------------------------------------------------------------------- /inference/infer_data/annotate_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script evaluates the math data to check its correctness. 3 | """ 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | from pebble import ProcessPool 8 | from concurrent.futures import TimeoutError 9 | 10 | from eval.grader import * 11 | from utils.parser import * 12 | from utils.utils import load_jsonl 13 | from utils.python_executor import PythonExecutor 14 | 15 | 16 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, execute=False): 17 | assert samples or file_path, "samples or file_path must be provided" 18 | if not samples: 19 | samples = list(load_jsonl(file_path)) 20 | # dedup by idx 21 | if 'idx' in samples[0]: 22 | samples = {sample['idx']: sample for sample in samples}.values() 23 | samples = sorted(samples, key=lambda x: x['idx']) 24 | else: 25 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)] 26 | 27 | # parse gt if not in the dataset 28 | if 'gt' in samples[0]: 29 | pass 30 | else: 31 | for sample in samples: 32 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name) 33 | 34 | # execute 35 | if ('pred' not in samples[0]) or execute: 36 | if "pal" in prompt_type: 37 | executor = PythonExecutor(get_answer_expr="solution()") 38 | else: 39 | executor = PythonExecutor(get_answer_from_stdout=True) 40 | 41 | for sample in tqdm(samples, desc="Execute"): 42 | sample['pred'] = [] 43 | sample['report'] = [] 44 | for code in sample['code']: 45 | pred, report = run_execute(executor, code, prompt_type, execute=True) 46 | sample['pred'].append(pred) 47 | 
sample['report'].append(report) 48 | 49 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']] 50 | 51 | scores = [] 52 | timeout_cnt = 0 53 | 54 | with ProcessPool() as pool: 55 | future = pool.map(math_equal_process, params, timeout=10) 56 | iterator = future.result() 57 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar: 58 | while True: 59 | try: 60 | result = next(iterator) 61 | scores.append(result) 62 | except StopIteration: 63 | break 64 | except TimeoutError as error: 65 | print(error) 66 | scores.append(False) 67 | timeout_cnt += 1 68 | except Exception as error: 69 | print(error.traceback) 70 | exit() 71 | progress_bar.update(1) 72 | 73 | idx = 0 74 | score_mat = [] 75 | for sample in samples: 76 | sample['score'] = scores[idx: idx+len(sample['pred'])] 77 | assert len(sample['score']) == len(sample['pred']) 78 | score_mat.append(sample['score']) 79 | idx += len(sample['pred']) 80 | 81 | max_len = max([len(s) for s in score_mat]) 82 | 83 | for i, s in enumerate(score_mat): 84 | if len(s) < max_len: 85 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad 86 | 87 | # output mean of each column of scores 88 | col_means= np.array(score_mat).mean(axis=0) 89 | mean_score = list(np.round(col_means * 100, decimals=1)) 90 | 91 | result_str = f"Num samples: {len(samples)}\n" \ 92 | f"Num scores: {len(scores)}\n" \ 93 | f"Timeout samples: {timeout_cnt}\n" \ 94 | f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \ 95 | f"Mean score: {mean_score}\n" 96 | 97 | # each type score 98 | if "type" in samples[0]: 99 | type_scores = {} 100 | for sample in samples: 101 | if sample['type'] not in type_scores: 102 | type_scores[sample['type']] = [] 103 | type_scores[sample['type']].append(sample['score'][-1]) 104 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()} 105 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: 
item[0])} 106 | result_str += f"Type scores: {type_scores}\n" 107 | 108 | print(result_str) 109 | return result_str, samples 110 | 111 | 112 | def parse_args(): 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument("--data_name", type=str, default="math") 115 | parser.add_argument("--prompt_type", type=str, default="tora") 116 | parser.add_argument("--file_path", type=str, default=None, required=True) 117 | parser.add_argument("--max_num_samples", type=int, default=None) 118 | parser.add_argument("--execute", action="store_true") 119 | parser.add_argument("--output_dir", type=str, default=None, required=True) 120 | args = parser.parse_args() 121 | return args 122 | 123 | # data_name='gsm8k' 124 | # prompt_type = 'cot' / 'tora' 125 | # 126 | args = parse_args() 127 | eval_result, all_samples = evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path, execute=args.execute) 128 | 129 | with open(args.output_dir, "w", encoding="utf8") as f: 130 | for i in range(len(all_samples)): 131 | json.dump(all_samples[i], f, ensure_ascii=False) 132 | f.write('\n') 133 | 134 | -------------------------------------------------------------------------------- /inference/infer_data/get_dpo_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Dataset 2 | 3 | import random 4 | import json 5 | import re 6 | # Define the file path and output dir 7 | data_files = 'here.jsonl' 8 | output_dir = "huggingface/dataset_name" 9 | 10 | # Load the dataset 11 | ds0 = load_dataset('json', data_files=data_files, split='train') 12 | 13 | def parse_conversation(example): 14 | # Split the data into turns based on the start_of_turn and end_of_turn markers 15 | data = example["my_solu"][0] 16 | turns = re.split(r"<start_of_turn>|<end_of_turn>", data) 17 | 18 | # Clean and filter out empty entries 19 | turns = [turn.strip() for turn in turns if turn.strip() and not turn.startswith("")] 20 | 21 | # Create a list to hold 
the parsed conversation in the desired format 22 | conversation = [] 23 | 24 | # Process each turn, assigning the correct role and content 25 | for turn in turns: 26 | if turn.startswith("user\n"): 27 | # Extract content after the role identifier 28 | content = turn[5:].strip() 29 | conversation.append({"role": "user", "content": content}) 30 | elif turn.startswith("model\n"): 31 | content = turn[6:].strip() 32 | conversation.append({"role": "assistant", "content": content}) 33 | 34 | return {"messages": conversation} 35 | 36 | # we first transform the data into standard format 37 | ds1 = ds0.map(parse_conversation, num_proc=32) 38 | 39 | 40 | def filter_example(example): 41 | old_messages = example["messages"] 42 | 43 | if len(old_messages) < 4: 44 | return False 45 | 46 | if len(old_messages) % 2 != 0: 47 | return False 48 | 49 | all_mes_len = len(old_messages) 50 | # if the model makes mistake but predict the correct answer 51 | if "error" in old_messages[-2]["content"].lower(): 52 | return False 53 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower(): 54 | return False 55 | 56 | k = 0 57 | 58 | for mes in old_messages: 59 | if k % 2 != 0 and k < all_mes_len - 1: 60 | if "python" not in mes["content"]: 61 | return False 62 | k += 1 63 | # env error 64 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower(): 65 | return False 66 | 67 | return True 68 | 69 | ds2 = ds1.filter(filter_example, num_proc=32) 70 | 71 | 72 | # Function to de-duplicate and group entries 73 | def deduplicate_and_group(dataset): 74 | unique_entries = {} 75 | for entry in dataset: 76 | idx = entry['idx'] 77 | solu = entry['my_solu'][0] # tuples are hashable and can be used as dictionary keys 78 | if idx not in unique_entries: 79 | unique_entries[idx] = {} 80 | if solu not in unique_entries[idx]: 81 | unique_entries[idx][solu] = entry 82 | return unique_entries 83 | 84 | # Group by 'idx' and de-duplicate by 'my_solu' 85 
| grouped_data = deduplicate_and_group(ds2) 86 | 87 | # Select one sample with scores [True] and one with scores [False] 88 | def select_samples(groups): 89 | selected_pairs = [] 90 | for idx, solutions in groups.items(): 91 | true_samples = [sol for sol in solutions.values() if sol['score'][0] == True] 92 | false_samples = [sol for sol in solutions.values() if sol['score'][0] == False] 93 | if true_samples and false_samples: 94 | selected_true = random.choice(true_samples) 95 | selected_false = random.choice(false_samples) 96 | selected_pairs.append((selected_true, selected_false)) 97 | return selected_pairs 98 | 99 | # Apply the selection function 100 | selected_samples = select_samples(grouped_data) 101 | 102 | 103 | # get the dpo dataset 104 | all_samples = [] 105 | for pair in selected_samples: 106 | all_samples.append( 107 | { 108 | "gt": pair[0]["gt"], 109 | "chosen": pair[0]["messages"], 110 | "rejected": pair[1]["messages"],} 111 | ) 112 | 113 | dict_data = { 114 | "rejected": [d['rejected'] for d in all_samples], 115 | "chosen": [d['chosen'] for d in all_samples], 116 | "gt": [d['gt'] for d in all_samples], 117 | } 118 | 119 | final_ds = Dataset.from_dict(dict_data) 120 | final_ds.push_to_hub(output_dir) 121 | -------------------------------------------------------------------------------- /inference/infer_data/infer_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is adapted from the ToRA project and supports multi-round vLLM inference. 3 | The inference is formulated as a multi-turn chat, and the model should first be registered as a server by scripts/register_server.sh. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | import random 9 | import time 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | from datetime import datetime 12 | 13 | import requests 14 | from eval.evaluate import evaluate 15 | from tqdm import tqdm 16 | from utils.data_loader import load_data 17 | from utils.parser import * 18 | from utils.python_executor import PythonExecutor 19 | from utils.utils import construct_prompt, load_jsonl, save_jsonl, set_seed 20 | from vllm import LLM, SamplingParams 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--data_name", default="gsm8k", type=str) 26 | parser.add_argument("--data_dir", default="./data", type=str) 27 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str) 28 | parser.add_argument("--output_dir", default="./output", type=str) 29 | parser.add_argument("--prompt_type", default="tora", type=str) 30 | parser.add_argument("--split", default="test", type=str) 31 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data 32 | parser.add_argument("--seed", default=0, type=int) 33 | parser.add_argument("--start", default=0, type=int) 34 | parser.add_argument("--end", default=-1, type=int) 35 | parser.add_argument("--temperature", default=0, type=float) 36 | parser.add_argument("--n_sampling", default=1, type=int) 37 | parser.add_argument("--top_p", default=1, type=float) 38 | parser.add_argument("--max_tokens_per_call", default=1024, type=int) 39 | parser.add_argument("--shuffle", action="store_true") 40 | parser.add_argument("--ports", action='append', default=[]) 41 | parser.add_argument("--horizon", default=6, type=int) # the maximal number of tool calls 42 | parser.add_argument("--eval", default=False, type=bool) 43 | 44 | args = parser.parse_args() 45 | args.top_p = 1 if args.temperature == 0 else args.top_p # top_p must be 1 when using greedy sampling (vllm) 46 | return args 47 | 48 | 49 | def 
prepare_data(args): 50 | examples = load_data(args.data_name, args.split, args.data_dir) 51 | 52 | # sample `num_test_sample` from dataset 53 | if args.num_test_sample > 0: 54 | examples = random.sample(examples, args.num_test_sample) 55 | elif args.num_test_sample == -1: 56 | args.num_test_sample = len(examples) 57 | 58 | # shuffle 59 | if args.shuffle: 60 | random.seed(datetime.now().timestamp()) 61 | random.shuffle(examples) 62 | 63 | # select start and end 64 | if args.end == -1: 65 | args.end = len(examples) 66 | examples = examples[args.start : args.end] 67 | 68 | # get out_file name 69 | dt_string = datetime.now().strftime("%m-%d_%H-%M") 70 | model_name = "/".join(args.model_name_or_path.split("/")[-2:]) 71 | out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}" 72 | out_file = f"{args.output_dir}/{model_name}/{args.data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl" 73 | os.makedirs(f"{args.output_dir}/{model_name}/{args.data_name}", exist_ok=True) 74 | 75 | # load all processed samples 76 | # find the files in e.g. 
./output/gemma2/math/ 77 | processed_files = [ 78 | f 79 | for f in os.listdir(f"{args.output_dir}/{model_name}/{args.data_name}/") 80 | if f.endswith(".jsonl") and f.startswith(out_file_prefix) 81 | ] 82 | processed_samples = [] 83 | for f in processed_files: 84 | processed_samples.extend(list(load_jsonl(f"{args.output_dir}/{model_name}/{args.data_name}/{f}"))) 85 | 86 | # deduplicate 87 | processed_samples = {sample["idx"]: sample for sample in processed_samples} 88 | processed_idxs = list(processed_samples.keys()) 89 | processed_samples = list(processed_samples.values()) 90 | total_examples = len(examples) 91 | # if an example has already been processed with the same seed, temperature, and model, we skip it 92 | examples = [example for example in examples if example["idx"] not in processed_idxs] 93 | print(f"Idx {args.start} - {args.end}: Remain {len(examples)}/{total_examples} samples.") 94 | if len(examples) == 0: 95 | pass 96 | else: 97 | print(examples[0]) 98 | return examples, processed_samples, out_file 99 | 100 | 101 | def main(args): 102 | ports = args.ports 103 | examples, processed_samples, out_file = prepare_data(args) 104 | # init python executor 105 | executor = PythonExecutor(get_answer_from_stdout=True) 106 | print(args.prompt_type) 107 | 108 | SamplingParams.seed = args.seed 109 | # select the stop tokens according to the model's prompt format 110 | if "gemma" in args.model_name_or_path: 111 | stop_tokens = ["", "", "```output", ""] 112 | elif "mistral" in args.model_name_or_path: 113 | stop_tokens = ["", "", "[INST]", "```output"] 114 | elif "deepseek" in args.model_name_or_path: 115 | stop_tokens = ["<|end▁of▁sentence|>", "User", "```output"] 116 | elif "llama3" in args.model_name_or_path: 117 | stop_tokens = ["<|eot_id|>", "<|start_header_id|>user", "```output"] 118 | else: 119 | raise NotImplementedError(args.prompt_type + " and " + args.model_name_or_path) 120 | default_args = { 121 | "use_beam_search": False, 122 | "n": 1, 123 | "temperature": args.temperature, 
124 | "max_tokens": 1024, 125 | "seed": args.seed, 126 | "top_p": 1.0, 127 | "top_k": -1, 128 | "stop": stop_tokens, 129 | } 130 | 131 | def query_model(prompt, args, port): 132 | json = { 133 | **args, 134 | "prompt": prompt, 135 | } 136 | response = requests.post(url="http://localhost" + ":" + str(port) + "/generate", json=json) 137 | response_json = response.json() 138 | return [response_json["text"][i][len(prompt) :] for i in range(len(response_json["text"]))] 139 | 140 | samples = [] 141 | 142 | for example in tqdm(examples, total=len(examples)): 143 | idx = example["idx"] 144 | 145 | # parse question and answer 146 | example["question"] = parse_question(example, args.data_name) 147 | gt_cot, gt_ans = parse_ground_truth(example, args.data_name) 148 | 149 | full_prompt = construct_prompt(args, example) 150 | 151 | sample = {"idx": idx, "question": example["question"], "gt_cot": gt_cot, "gt": gt_ans, "prompt": full_prompt} 152 | # add remain fields 153 | for key in [ 154 | "level", 155 | "type", 156 | "unit", 157 | "solution_type", 158 | "choices", 159 | "solution", 160 | "ques_type", 161 | "ans_type", 162 | "answer_type", 163 | "dataset", 164 | "subfield", 165 | "filed", 166 | "theorem", 167 | "answer", 168 | ]: 169 | if key in example: 170 | sample[key] = example[key] 171 | samples.append(sample) 172 | 173 | print("dataset:", args.data_name, "samples:", len(samples)) 174 | if len(samples) > 0: 175 | print("-" * 50) 176 | print("sample:", samples[0]["prompt"]) 177 | print("-" * 50) 178 | 179 | # repeat H times 180 | remain_prompts = [sample["prompt"] for sample in samples for _ in range(args.n_sampling)] 181 | remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)] 182 | all_gts = [sample["gt"] for sample in samples for _ in range(args.n_sampling)] 183 | 184 | tmp_idx = list(range(len(all_gts))) 185 | all_gts = dict(zip(tmp_idx, all_gts)) 186 | 187 | end_prompts = [] 188 | 189 | max_func_call = 1 if args.prompt_type == "cot" else args.horizon 
190 | 191 | # start inference, measure time use 192 | start_time = time.time() 193 | print("The maximal number of function calls is ", max_func_call) 194 | for epoch in range(max_func_call): 195 | print("=" * 50, "Epoch", epoch) 196 | current_prompts = remain_prompts 197 | # if all the queries meet the stop criteria, break 198 | if len(current_prompts) == 0: 199 | break 200 | 201 | # get all outputs, each prompt is (idx, prompt_content) 202 | prompts = [item[1] for item in current_prompts] 203 | with ThreadPoolExecutor(512) as executor2: 204 | result = [ 205 | executor2.submit(query_model, prompts[i], default_args, ports[i % len(ports)]) 206 | for i in range(len(prompts)) 207 | ] 208 | # use tqdm to show progress 209 | for _ in tqdm(as_completed(result), total=len(result)): 210 | pass 211 | 212 | outputs = [r.result()[0] for r in result] 213 | 214 | # print(len(outputs), len(current_prompts)) 215 | 216 | if len(outputs) != len(current_prompts): 217 | raise ValueError("VLLM has some problem, the generated responses are fewer than the queries.") 218 | 219 | # process all outputs 220 | remain_prompts = [] 221 | remain_codes = [] 222 | 223 | for (i, query), output in zip(current_prompts, outputs): 224 | output = output.rstrip() 225 | # append the y_s to the current state (history) 226 | query += output 227 | if args.prompt_type == "cot": 228 | # for cot, the prompt ends after one round 229 | end_prompts.append((i, query)) 230 | elif "boxed" not in output and "```python" in output: #output.endswith("```"): 231 | # the model has not output the final answer, and there is code that needs to be executed 232 | program = extract_program(query) 233 | remain_prompts.append((i, query)) 234 | remain_codes.append(program) 235 | else: 236 | end_prompts.append((i, query)) 237 | 238 | # execute the codes and get the results 239 | # note that the order of remain_codes is the same as remain_prompts 240 | remain_results = executor.batch_apply(remain_codes) 241 | for k in range(len(remain_prompts)): 242 
| i, query = remain_prompts[k] 243 | res, report = remain_results[k] 244 | exec_result = res if res else report 245 | # we add the observation to the history 246 | if "gemma" in args.model_name_or_path: 247 | exec_result = f"\nuser\n```output\n{exec_result}\n```\nmodel\n" 248 | elif "mistral" in args.model_name_or_path: 249 | exec_result = f" [INST] ```output\n{exec_result}\n``` [/INST]" 250 | elif "deepseek" in args.model_name_or_path: 251 | #exec_result = f"<|end▁of▁sentence|>User: ```output\n{exec_result}\n```\n\nAssistant:" 252 | #for deepseek, we directly append the observation as the training of deepseek 253 | exec_result = f"\n```output\n{exec_result}\n```\n" 254 | elif "llama3" in args.model_name_or_path: 255 | exec_result = f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n```output\n{exec_result}\n```<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 256 | else: 257 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path) 258 | 259 | query += exec_result 260 | 261 | if epoch == max_func_call - 1: 262 | query += "\nReach max function call limit." 263 | remain_prompts[k] = (i, query) 264 | 265 | # unsolved samples 266 | print("Unsolved samples:", len(remain_prompts)) 267 | end_prompts.extend(remain_prompts) 268 | # sort by idx 269 | end_prompts = sorted(end_prompts, key=lambda x: x[0]) 270 | 271 | if "gemma" in args.model_name_or_path: 272 | ans_split = "model\n" 273 | elif "mistral" in args.model_name_or_path: 274 | ans_split = "[/INST]" 275 | elif "deepseek" in args.model_name_or_path: 276 | ans_split = "\n\nAssistant:" 277 | elif "llama3" in args.model_name_or_path: 278 | ans_split = "<|start_header_id|>user<|end_header_id|>\n\n" 279 | else: 280 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path) 281 | 282 | codes = [prompt.split(ans_split)[-1].strip() for _, prompt in end_prompts] 283 | 284 | # extract preds, run_execute will extract the code needed to run... 
285 | # for tora, we only extract the final answer but do not run the code 286 | results = [run_execute(executor, code, args.prompt_type) for code in codes] 287 | 288 | time_use = time.time() - start_time 289 | tmp_to_store = [z.split("---")[-1].strip() for _, z in end_prompts] 290 | # put results back to examples 291 | all_samples = [] 292 | for i, sample in enumerate(samples): 293 | code = codes[i * args.n_sampling : (i + 1) * args.n_sampling] 294 | result = results[i * args.n_sampling : (i + 1) * args.n_sampling] 295 | preds = [item[0] for item in result] 296 | reports = [item[1] for item in result] 297 | response_tmp = tmp_to_store[i * args.n_sampling : (i + 1) * args.n_sampling] 298 | sample.pop("prompt") 299 | sample.update({"my_solu": response_tmp, "code": code, "pred": preds, "report": reports}) 300 | all_samples.append(sample) 301 | 302 | # add processed samples 303 | all_samples.extend(processed_samples) 304 | save_jsonl(all_samples, out_file) 305 | 306 | # Evaluate the result 307 | if args.eval: 308 | result_str = evaluate(samples=all_samples, data_name=args.data_name, prompt_type=args.prompt_type, execute=True) 309 | result_str += f"\nTime use: {time_use:.2f}s" 310 | time_str = f"{int(time_use // 60)}:{int(time_use % 60):02d}" 311 | result_str += f"\nTime use: {time_str}" 312 | 313 | with open(out_file.replace(".jsonl", f"_{args.prompt_type}.metrics"), "w") as f: 314 | f.write(result_str) 315 | 316 | 317 | if __name__ == "__main__": 318 | args = parse_args() 319 | set_seed(args.seed) 320 | main(args) 321 | -------------------------------------------------------------------------------- /inference/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -eq 0 ]; then 3 | echo "Usage: $0 " 4 | exit 1 5 | fi 6 | MODEL_NAME_OR_PATH=$1 7 | 8 | # DATA_LIST = ['math', 'gsm8k', 'gsm-hard', 'svamp', 'tabmwp', 'asdiv', 'mawps'] 9 | 10 | DATA_NAME="gsm8k" 11 | 12 | OUTPUT_DIR="./output1" 13 | 14 | 
SPLIT="test" 15 | PROMPT_TYPE="tora" 16 | NUM_TEST_SAMPLE=-1 17 | 18 | 19 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \ 20 | python -um infer_data.infer_eval \ 21 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 22 | --data_name ${DATA_NAME} \ 23 | --output_dir ${OUTPUT_DIR} \ 24 | --split ${SPLIT} \ 25 | --prompt_type ${PROMPT_TYPE} \ 26 | --num_test_sample ${NUM_TEST_SAMPLE} \ 27 | --seed 1 \ 28 | --temperature 0 \ 29 | --n_sampling 1 \ 30 | --top_p 1 \ 31 | --start 0 \ 32 | --end -1 \ 33 | --horizon 4 \ 34 | --ports "8000" \ 35 | --ports "8001" \ 36 | --ports "8002" \ 37 | --ports "8003" \ 38 | --ports "8004" \ 39 | --ports "8005" \ 40 | --ports "8006" \ 41 | --ports "8007" \ 42 | --eval eval \ 43 | 44 | 45 | 46 | DATA_NAME="math" 47 | 48 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \ 49 | python -um infer_data.infer_eval \ 50 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 51 | --data_name ${DATA_NAME} \ 52 | --output_dir ${OUTPUT_DIR} \ 53 | --split ${SPLIT} \ 54 | --prompt_type ${PROMPT_TYPE} \ 55 | --num_test_sample ${NUM_TEST_SAMPLE} \ 56 | --seed 1 \ 57 | --temperature 0 \ 58 | --n_sampling 1 \ 59 | --top_p 1 \ 60 | --start 0 \ 61 | --end -1 \ 62 | --horizon 4 \ 63 | --ports "8000" \ 64 | --ports "8001" \ 65 | --ports "8002" \ 66 | --ports "8003" \ 67 | --ports "8004" \ 68 | --ports "8005" \ 69 | --ports "8006" \ 70 | --ports "8007" \ 71 | --eval eval \ 72 | -------------------------------------------------------------------------------- /inference/scripts/infer.sh: -------------------------------------------------------------------------------- 1 | if [ $# -eq 0 ]; then 2 | echo "Usage: $0 " 3 | exit 1 4 | fi 5 | MODEL_NAME_OR_PATH=$1 6 | 7 | #DATA_LIST = ['math', 'gsm8k'] 8 | 9 | DATA_NAME="gsm8k" 10 | 11 | OUTPUT_DIR="./collect_data" 12 | 13 | SPLIT="train" 14 | PROMPT_TYPE="tora" 15 | NUM_TEST_SAMPLE=-1 16 | 17 | 18 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \ 19 | python -um infer_data.infer_eval \ 20 | --model_name_or_path 
${MODEL_NAME_OR_PATH} \ 21 | --data_name ${DATA_NAME} \ 22 | --output_dir ${OUTPUT_DIR} \ 23 | --split ${SPLIT} \ 24 | --prompt_type ${PROMPT_TYPE} \ 25 | --num_test_sample ${NUM_TEST_SAMPLE} \ 26 | --seed 1 \ 27 | --temperature 0 \ 28 | --n_sampling 1 \ 29 | --top_p 1 \ 30 | --start 0 \ 31 | --end -1 \ 32 | --horizon 6 \ 33 | --ports "8000" \ 34 | --ports "8001" \ 35 | --ports "8002" \ 36 | --ports "8003" \ 37 | --ports "8004" \ 38 | --ports "8005" \ 39 | --ports "8006" \ 40 | --ports "8007" \ 41 | 42 | 43 | -------------------------------------------------------------------------------- /inference/scripts/iter_infer_to_collect_data.sh: -------------------------------------------------------------------------------- 1 | if [ $# -eq 0 ]; then 2 | echo "Usage: $0 " 3 | exit 1 4 | fi 5 | MODEL_NAME_OR_PATH=$1 6 | 7 | #DATA_NAME="math" 8 | DATA_NAME="gsm8k" 9 | OUTPUT_DIR="./iter_collect_data" 10 | 11 | SPLIT="train" 12 | PROMPT_TYPE="tora" 13 | NUM_TEST_SAMPLE=-1 14 | 15 | for ((i=0; i<=8000; i+=1)) 16 | do 17 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \ 18 | python -um infer_data.infer_eval \ 19 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 20 | --data_name ${DATA_NAME} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --split ${SPLIT} \ 23 | --prompt_type ${PROMPT_TYPE} \ 24 | --num_test_sample ${NUM_TEST_SAMPLE} \ 25 | --seed 1 \ 26 | --temperature 0 \ 27 | --n_sampling 1 \ 28 | --top_p 1 \ 29 | --start 0 \ 30 | --end -1 \ 31 | --horizon 6 \ 32 | --ports "8000" \ 33 | --ports "8001" \ 34 | --ports "8002" \ 35 | --ports "8003" \ 36 | --ports "8004" \ 37 | --ports "8005" \ 38 | --ports "8006" \ 39 | --ports "8007" 40 | done 41 | -------------------------------------------------------------------------------- /inference/scripts/register_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # check whether model path is provided 4 | if [ $# -eq 0 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | 10 
| MODEL_PATH=$1 11 | 12 | # generate 8 server instances 13 | for i in {0..7} 14 | do 15 | CUDA_VISIBLE_DEVICES=$i python -m vllm.entrypoints.api_server \ 16 | --model $MODEL_PATH \ 17 | --gpu-memory-utilization=0.9 \ 18 | --max-num-seqs=200 \ 19 | --host 127.0.0.1 --tensor-parallel-size 1 \ 20 | --port $((8000+i)) \ 21 | & 22 | done 23 | -------------------------------------------------------------------------------- /inference/utils/annotate_data.py: -------------------------------------------------------------------------------- 1 | # The evaluator is adapted from the ToRA project 2 | # https://github.com/microsoft/ToRA 3 | # ToRA authors: Zhibin Gou and Zhihong Shao and Yeyun Gong and yelong shen and Yujiu Yang and Minlie Huang and Nan Duan and Weizhu Chen 4 | import json 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | from pebble import ProcessPool 9 | from concurrent.futures import TimeoutError 10 | 11 | from eval.grader import * 12 | from utils.parser import * 13 | from utils.utils import load_jsonl 14 | from utils.python_executor import PythonExecutor 15 | from datasets import load_dataset, Dataset, DatasetDict 16 | 17 | def evaluate(data_name, output_dir=None): 18 | if ".json" in data_name: 19 | ds = load_dataset("json", data_files=data_name, split='train', field='instances').shuffle(seed=42) 20 | else: 21 | ds = load_dataset(data_name, split="train").shuffle(seed=42) 22 | 23 | samples = [sample for sample in ds] 24 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']] 25 | 26 | scores = [] 27 | timeout_cnt = 0 28 | 29 | with ProcessPool() as pool: 30 | future = pool.map(math_equal_process, params, timeout=10) 31 | iterator = future.result() 32 | with tqdm(total=len(params), desc="Evaluate") as progress_bar: 33 | while True: 34 | try: 35 | result = next(iterator) 36 | scores.append(result) 37 | except StopIteration: 38 | break 39 | except TimeoutError as error: 40 | print(error) 41 | scores.append(False) 42
| timeout_cnt += 1 43 | except Exception as error: 44 | print(error.traceback) 45 | exit() 46 | progress_bar.update(1) 47 | 48 | idx = 0 49 | score_mat = [] 50 | for sample in samples: 51 | sample['score'] = scores[idx: idx+len(sample['pred'])] 52 | assert len(sample['score']) == len(sample['pred']) 53 | idx += len(sample['pred']) # inside the loop: advance to the next sample's slice of scores 54 | if ".json" in output_dir: 55 | all_data = [sample for sample in samples] 56 | output_eval_dataset = {} 57 | output_eval_dataset["type"] = "text_only" 58 | output_eval_dataset["instances"] = all_data 59 | print("Collected", len(all_data), "samples") 60 | with open(output_dir, "w", encoding="utf8") as f: 61 | json.dump(output_eval_dataset, f, ensure_ascii=False) 62 | else: 63 | all_data = [sample for sample in samples] 64 | dict_data = { 65 | "idx": [d['idx'] for d in all_data], 66 | "gt": [d['gt'] for d in all_data], 67 | "level": [d['level'] for d in all_data], 68 | "type": [d['type'] for d in all_data], 69 | "messages": [d['messages'] for d in all_data], 70 | "pred": [d['pred'] for d in all_data], 71 | "score": [d['score'] for d in all_data], 72 | } 73 | 74 | dataset = Dataset.from_dict(dict_data) 75 | DatasetDict({'train': dataset}).push_to_hub(output_dir) 76 | 77 | 78 | def parse_args(): 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--data_name", type=str, default="math") 81 | parser.add_argument("--output_dir", type=str, default=None, required=True) 82 | args = parser.parse_args() 83 | return args 84 | 85 | if __name__ == "__main__": 86 | args = parse_args() 87 | evaluate(data_name=args.data_name, output_dir=args.output_dir) 88 | -------------------------------------------------------------------------------- /inference/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from datasets import load_dataset, Dataset, concatenate_datasets 5 | from utils.utils import load_jsonl, lower_keys 6 | 7 | def load_data(data_name, split,
data_dir='./data'): 8 | data_file = f"{data_dir}/{data_name}/{split}.jsonl" 9 | if os.path.exists(data_file): 10 | examples = list(load_jsonl(data_file)) 11 | else: 12 | if data_name == "math": 13 | dataset = load_dataset("competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp") 14 | elif "RLHF4MATH/prompt_iter" in data_name: 15 | # if we use the pre-processed prompts 16 | dataset = load_dataset(data_name, split=split) 17 | elif data_name == "theorem-qa": 18 | dataset = load_dataset("wenhu/TheoremQA", split=split) 19 | elif data_name == "gsm8k": 20 | dataset = load_dataset(data_name, split=split) 21 | elif data_name == "gsm-hard": 22 | dataset = load_dataset("reasoning-machines/gsm-hard", split="train") 23 | elif data_name == "svamp": 24 | # evaluate on training set + test set 25 | dataset = load_dataset("ChilleD/SVAMP", split="train") 26 | dataset = concatenate_datasets([dataset, load_dataset("ChilleD/SVAMP", split="test")]) 27 | elif data_name == "asdiv": 28 | dataset = load_dataset("EleutherAI/asdiv", split="validation") 29 | dataset = dataset.filter(lambda x: ";" not in x['answer']) # remove multi-answer examples 30 | elif data_name == "mawps": 31 | examples = [] 32 | # four sub-tasks 33 | for data_name in ["singleeq", "singleop", "addsub", "multiarith"]: 34 | sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl")) 35 | for example in sub_examples: 36 | example['type'] = data_name 37 | examples.extend(sub_examples) 38 | dataset = Dataset.from_list(examples) 39 | elif data_name == "finqa": 40 | dataset = load_dataset("dreamerdeo/finqa", split=split, name="main") 41 | dataset = dataset.select(random.sample(range(len(dataset)), 1000)) 42 | elif data_name == "tabmwp": 43 | examples = [] 44 | with open(f"{data_dir}/tabmwp/tabmwp_{split}.json", "r") as f: 45 | data_dict = json.load(f) 46 | examples.extend(data_dict.values()) 47 | dataset = Dataset.from_list(examples) 48 | dataset = dataset.select(random.sample(range(len(dataset)), 
1000)) 49 | elif data_name == "bbh": 50 | examples = [] 51 | for data_name in ["reasoning_about_colored_objects", "penguins_in_a_table",\ 52 | "date_understanding", "repeat_copy_logic", "object_counting"]: 53 | with open(f"{data_dir}/bbh/bbh/{data_name}.json", "r") as f: 54 | sub_examples = json.load(f)["examples"] 55 | for example in sub_examples: 56 | example['type'] = data_name 57 | examples.extend(sub_examples) 58 | dataset = Dataset.from_list(examples) 59 | else: 60 | raise NotImplementedError(data_name) 61 | 62 | examples = list(dataset) 63 | examples = [lower_keys(example) for example in examples] 64 | dataset = Dataset.from_list(examples) 65 | os.makedirs(f"{data_dir}/{data_name}", exist_ok=True) 66 | dataset.to_json(data_file) 67 | 68 | # add 'idx' in the first column 69 | if 'idx' not in examples[0]: 70 | examples = [{'idx': i, **example} for i, example in enumerate(examples)] 71 | 72 | # sort by idx 73 | examples = sorted(examples, key=lambda x: x['idx']) 74 | return examples 75 | -------------------------------------------------------------------------------- /inference/utils/filter_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datasets import Dataset, concatenate_datasets, load_dataset 3 | from transformers import AutoTokenizer, HfArgumentParser 4 | import argparse 5 | import re 6 | # Create an ArgumentParser object 7 | parser = argparse.ArgumentParser(description="Process some strings.") 8 | 9 | parser.add_argument("--data_dir", type=str, help="The dataset address", default=None) 10 | parser.add_argument("--output_dir", type=str, help="The output address", default=None) 11 | parser.add_argument("--model_name", type=str, help="The model used to collect the data", default='gemma') 12 | 13 | # Parse command line arguments 14 | args = parser.parse_args() 15 | 16 | 17 | # Step 1: load dataset 18 | if ".json" in args.data_dir: 19 | ds = load_dataset("json",
data_files=args.data_dir, split='train', field='instances').shuffle(seed=42) 20 | else: 21 | ds = load_dataset(args.data_dir, split="train").shuffle(seed=42) 22 | 23 | # You may want to merge different datasets... 24 | #ds = concatenate_datasets([ds1, ds2]) 25 | 26 | # Step 2: we split the trajectory into the standard multi-turn format 27 | 28 | def parse_conversation(example, model_name='gemma'): 29 | # Split the data into turns based on the start_of_turn and end_of_turn markers 30 | 31 | if 'gemma' in model_name: 32 | data = example["my_solu"][0] 33 | turns = re.split(r"<start_of_turn>|<end_of_turn>", data) # Gemma turn markers (restored; assumed from the comment above) 34 | elif 'mistral' in model_name: 35 | data = example["my_solu"][0].replace("<s>", "").replace("</s>", "") # assumes Mistral's <s> / </s> sequence markers 36 | turns = re.split(r"\[INST\]|\[/INST\]", data) 37 | else: 38 | raise NotImplementedError(model_name) 39 | 40 | # Clean and filter out empty entries 41 | turns = [turn.strip() for turn in turns if turn.strip() and not turn.startswith("<bos>")] # assumes Gemma's <bos> token 42 | 43 | # Create a list to hold the parsed conversation in the desired format 44 | conversation = [] 45 | 46 | if 'gemma' in model_name: 47 | for turn in turns: 48 | if turn.startswith("user\n"): 49 | # Extract content after the role identifier 50 | content = turn[5:].strip() 51 | conversation.append({"role": "user", "content": content}) 52 | elif turn.startswith("model\n"): 53 | content = turn[6:].strip() 54 | conversation.append({"role": "assistant", "content": content}) 55 | 56 | elif 'mistral' in model_name: 57 | j = 0 58 | for turn in turns: 59 | if j % 2 == 0: 60 | content = turn.strip() 61 | conversation.append({"role": "user", "content": content}) 62 | j += 1 63 | else: 64 | content = turn.strip() 65 | conversation.append({"role": "assistant", "content": content}) 66 | j += 1 67 | else: 68 | raise NotImplementedError(model_name) 69 | 70 | 71 | return {"messages": conversation} 72 | 73 | 74 | ds_new = ds.map(parse_conversation, fn_kwargs={"model_name": args.model_name}, num_proc=32) 75 | 76 | # Step 3: we filter the examples which are with ood rounds, make mistake in the second last
round but still give a guess of the result 77 | 78 | 79 | def filter_example1(example): 80 | old_messages = example["messages"] 81 | 82 | if len(old_messages) < 4: 83 | return False 84 | 85 | if len(old_messages) % 2 != 0: 86 | return False 87 | 88 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower(): 89 | return False 90 | 91 | for mes in old_messages: 92 | # the code interpreter destroy the conda environment from time to time, we delete the samples collected when the env is wrong 93 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower(): 94 | return False 95 | if "```output\n[]" in mes["content"]: 96 | return False 97 | 98 | if "traitlets" in mes['content'] and 'error' in mes['content'].lower(): 99 | return False 100 | if "sympy.core.numbers" in mes['content'] and 'error' in mes['content'].lower(): 101 | return False 102 | if 'sympy.tensor.tensor' in mes['content'] and 'error' in mes['content']: 103 | return False 104 | if 'minlex() got an' in mes['content']: 105 | return False 106 | if 'No module named' in mes['content'] and 'sympy.' 
in mes['content']: 107 | return False 108 | if 'object is not subscriptable' in mes['content'].lower(): 109 | return False 110 | 111 | # We delete the samples that reach the max function call limit 112 | # this does not influence the final result but can significantly accelerate the training process 113 | if 'Reach max function call' in mes['content']: 114 | return False 115 | 116 | return True 117 | 118 | ds_new = ds_new.filter(filter_example1, num_proc=32) 119 | 120 | 121 | # Step 4: we delete the samples that are too long 122 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 123 | 124 | def filter_too_long_pred(example): 125 | z = len(tokenizer.apply_chat_template(example["messages"], tokenize=True)) 126 | if z > 2048: 127 | return False 128 | return True 129 | 130 | ds_new = ds_new.filter(filter_too_long_pred, num_proc=32) 131 | 132 | 133 | # Step 5: output the filtered dataset 134 | 135 | # we delete the columns that are unnecessary 136 | columns_to_keep = ["idx", "gt", "level", "type", "messages", "pred"] 137 | ds_new = ds_new.remove_columns([col for col in ds_new.column_names if col not in columns_to_keep]) 138 | 139 | 140 | if ".json" in args.output_dir: 141 | all_data = [sample for sample in ds_new] 142 | output_eval_dataset = {} 143 | output_eval_dataset["type"] = "text_only" 144 | output_eval_dataset["instances"] = all_data 145 | print("Collected", len(all_data), "samples") 146 | with open(args.output_dir, "w", encoding="utf8") as f: 147 | json.dump(output_eval_dataset, f, ensure_ascii=False) 148 | else: 149 | ds_new.push_to_hub(args.output_dir) 150 | 151 | 152 | -------------------------------------------------------------------------------- /inference/utils/parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Dict 3 | 4 | 5 | def _fix_fracs(string): 6 | substrs = string.split("\\frac") 7 | new_str = substrs[0] 8 | if len(substrs) > 1: 9 | substrs = substrs[1:] 10 | for
substr in substrs: 11 | new_str += "\\frac" 12 | if len(substr) > 0 and substr[0] == "{": 13 | new_str += substr 14 | else: 15 | try: 16 | assert len(substr) >= 2 17 | except: 18 | return string 19 | a = substr[0] 20 | b = substr[1] 21 | if b != "{": 22 | if len(substr) > 2: 23 | post_substr = substr[2:] 24 | new_str += "{" + a + "}{" + b + "}" + post_substr 25 | else: 26 | new_str += "{" + a + "}{" + b + "}" 27 | else: 28 | if len(substr) > 2: 29 | post_substr = substr[2:] 30 | new_str += "{" + a + "}" + b + post_substr 31 | else: 32 | new_str += "{" + a + "}" + b 33 | string = new_str 34 | return string 35 | 36 | 37 | def _fix_a_slash_b(string): 38 | if len(string.split("/")) != 2: 39 | return string 40 | a = string.split("/")[0] 41 | b = string.split("/")[1] 42 | try: 43 | if "sqrt" not in a: 44 | a = int(a) 45 | if "sqrt" not in b: 46 | b = int(b) 47 | assert string == "{}/{}".format(a, b) 48 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 49 | return new_string 50 | except: 51 | return string 52 | 53 | 54 | def _fix_sqrt(string): 55 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 56 | return _string 57 | 58 | 59 | def strip_string(string): 60 | string = str(string).strip() 61 | # linebreaks 62 | string = string.replace("\n", "") 63 | 64 | # right "." 
65 | string = string.rstrip(".") 66 | 67 | # remove inverse spaces 68 | string = string.replace("\\!", "") 69 | string = string.replace("\\ ", "") 70 | 71 | # replace \\ with \ 72 | string = string.replace("\\\\", "\\") 73 | string = string.replace("\\\\", "\\") 74 | 75 | # replace tfrac and dfrac with frac 76 | string = string.replace("tfrac", "frac") 77 | string = string.replace("dfrac", "frac") 78 | 79 | # remove \left and \right 80 | string = string.replace("\\left", "") 81 | string = string.replace("\\right", "") 82 | 83 | # Remove unit: miles, dollars if after is not none 84 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 85 | if _string != "" and _string != string: 86 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 87 | string = _string 88 | 89 | # Remove circ (degrees) 90 | string = string.replace("^{\\circ}", "") 91 | string = string.replace("^\\circ", "") 92 | 93 | # remove dollar signs 94 | string = string.replace("\\$", "") 95 | string = string.replace("$", "") 96 | 97 | string = string.replace("\\text", "") 98 | string = string.replace("x\\in", "") 99 | 100 | # remove percentage 101 | string = string.replace("\\%", "") 102 | string = string.replace("\%", "") 103 | string = string.replace("%", "") 104 | 105 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 106 | string = string.replace(" .", " 0.") 107 | string = string.replace("{.", "{0.") 108 | 109 | # cdot 110 | string = string.replace("\\cdot", "") 111 | 112 | # inf 113 | string = string.replace("infinity", "\\infty") 114 | if "\\infty" not in string: 115 | string = string.replace("inf", "\\infty") 116 | string = string.replace("+\\infty", "\\infty") 117 | 118 | # and 119 | string = string.replace("and", "") 120 | string = string.replace("\\mathbf", "") 121 | 122 | # use regex to remove \mbox{...} 123 | string = re.sub(r"\\mbox{.*?}", "", string) 124 | 125 | # quote (str.replace returns a new string, so the result must be assigned back) 126 | string = string.replace("'", "") 127 | string = string.replace('"', "") 128 | 129 | # i, j 130 | if "j" in string and "i" not in string: 131 | string = string.replace("j", "i") 132 | 133 | # replace a.000b where b is not number or b is end, with ab, use regex 134 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) 135 | string = re.sub(r"(\d+)\.0+$", r"\1", string) 136 | 137 | # if empty, return empty string 138 | if len(string) == 0: 139 | return string 140 | if string[0] == ".": 141 | string = "0" + string 142 | 143 | # to consider: get rid of e.g. "k = " or "q = " at beginning 144 | if len(string.split("=")) == 2: 145 | if len(string.split("=")[0]) <= 2: 146 | string = string.split("=")[1] 147 | 148 | string = _fix_sqrt(string) 149 | string = string.replace(" ", "") 150 | 151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
Also does a/b --> \\frac{a}{b} 152 | string = _fix_fracs(string) 153 | 154 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 155 | string = _fix_a_slash_b(string) 156 | 157 | return string 158 | 159 | 160 | def extract_answer(pred_str): 161 | if "boxed" in pred_str: 162 | ans = pred_str.split("boxed")[-1] 163 | if len(ans) == 0: 164 | return "" 165 | elif ans[0] == "{": 166 | stack = 1 167 | a = "" 168 | for c in ans[1:]: 169 | if c == "{": 170 | stack += 1 171 | a += c 172 | elif c == "}": 173 | stack -= 1 174 | if stack == 0: 175 | break 176 | a += c 177 | else: 178 | a += c 179 | else: 180 | a = ans.split("$")[0].strip() 181 | pred = a 182 | elif "he answer is" in pred_str: 183 | pred = pred_str.split("he answer is")[-1].strip() 184 | elif extract_program_output(pred_str) != "": 185 | # fall back to program 186 | pred = extract_program_output(pred_str) 187 | else: # use the last number 188 | pattern = "-?\d*\.?\d+" 189 | pred = re.findall(pattern, pred_str.replace(",", "")) 190 | if len(pred) >= 1: 191 | pred = pred[-1] 192 | else: 193 | pred = "" 194 | 195 | # multiple line 196 | pred = pred.split("\n")[0] 197 | if pred != "" and pred[0] == ":": 198 | pred = pred[1:] 199 | if pred != "" and pred[-1] == ".": 200 | pred = pred[:-1] 201 | if pred != "" and pred[-1] == "/": 202 | pred = pred[:-1] 203 | pred = strip_string(pred) 204 | return pred 205 | 206 | 207 | ''' 208 | def extract_program(result: str, last_only=True): 209 | """ 210 | extract the program after "```python", and before "```" 211 | """ 212 | program = "" 213 | start = False 214 | for line in result.split("\n"): 215 | if line.startswith("```python"): 216 | if last_only: 217 | program = "" # only extract the last program 218 | else: 219 | program += "\n# ========\n" 220 | start = True 221 | elif line.startswith("```"): 222 | start = False 223 | elif start: 224 | program += line + "\n" 225 | return program 226 | ''' 227 | 228 | 229 | def 
extract_program(result: str, last_only=False): 230 | """ 231 | extract the program after "```python", and before "```" 232 | """ 233 | all_program = [] 234 | start = False 235 | program = "" 236 | for line in result.split("\n"): 237 | # if line.startswith("```python"): 238 | if "```python" in line: 239 | program = "" 240 | # if last_only: 241 | # program = "" # only extract the last program 242 | # else: 243 | # program += "\n# ========\n" 244 | 245 | start = True 246 | elif line.startswith("```"): 247 | start = False 248 | all_program.append(program) 249 | program = "" 250 | elif start: 251 | program += line + "\n" 252 | return all_program 253 | 254 | 255 | def extract_program_output(pred_str): 256 | """ 257 | extract output between the last ```output\n...\n``` 258 | """ 259 | if "```output" not in pred_str: 260 | return "" 261 | if "```output" in pred_str: 262 | pred_str = pred_str.split("```output")[-1] 263 | if "```" in pred_str: 264 | pred_str = pred_str.split("```")[0] 265 | output = pred_str.strip() 266 | return output 267 | 268 | 269 | def parse_ground_truth(example: Dict[str, Any], data_name): 270 | if "gt_cot" in example: 271 | return example["gt_cot"], strip_string(example["gt"]) 272 | 273 | # parse ground truth 274 | if data_name in ["math", "ocw"]: 275 | gt_cot = example["solution"] 276 | gt_ans = extract_answer(gt_cot) 277 | elif data_name == "gsm8k": 278 | gt_cot, gt_ans = example["answer"].split("####") 279 | elif data_name == "gsm-hard": 280 | gt_cot, gt_ans = example["code"], example["target"] 281 | elif data_name == "svamp": 282 | gt_cot, gt_ans = example["Equation"], example["Answer"] 283 | elif data_name == "asdiv": 284 | gt_cot = example["formula"] 285 | gt_ans = re.sub(r"\(.*?\)", "", example["answer"]) 286 | elif data_name == "mawps": 287 | gt_cot, gt_ans = None, example["target"] 288 | elif data_name == "tabmwp": 289 | gt_cot = example["solution"] 290 | gt_ans = example["answer"] 291 | if example["ans_type"] in ["integer_number", 
"decimal_number"]: 292 | if "/" in gt_ans: 293 | gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1]) 294 | elif "," in gt_ans: 295 | gt_ans = float(gt_ans.replace(",", "")) 296 | elif "%" in gt_ans: 297 | gt_ans = float(gt_ans.split("%")[0]) / 100 298 | else: 299 | gt_ans = float(gt_ans) 300 | elif data_name == "bbh": 301 | gt_cot, gt_ans = None, example["target"] 302 | else: 303 | raise NotImplementedError(data_name) 304 | # post process 305 | gt_cot = str(gt_cot).strip() 306 | gt_ans = strip_string(gt_ans) 307 | return gt_cot, gt_ans 308 | 309 | 310 | def parse_question(example, data_name): 311 | question = "" 312 | if data_name == "asdiv": 313 | question = f"{example['body'].strip()} {example['question'].strip()}" 314 | elif data_name == "svamp": 315 | body = example["Body"].strip() 316 | if not body.endswith("."): 317 | body = body + "." 318 | question = f'{body} {example["Question"].strip()}' 319 | elif data_name == "tabmwp": 320 | title_str = f'regarding "{example["table_title"]}" ' if example["table_title"] else "" 321 | question = f"Read the following table {title_str}and answer a question:\n" 322 | question += f'{example["table"]}\n{example["question"]}' 323 | if example["choices"]: 324 | question += f' Please select from the following options: {example["choices"]}' 325 | else: 326 | for key in ["question", "problem", "Question", "input"]: 327 | if key in example: 328 | question = example[key] 329 | break 330 | assert question != "" 331 | return question.strip() 332 | 333 | 334 | def run_execute(executor, result, prompt_type, execute=False): 335 | if not result or result == "error": 336 | return None, None 337 | report = None 338 | 339 | if "program_only" in prompt_type: 340 | prediction = extract_program_output(result) 341 | elif prompt_type in ["pot", "pal"] and execute: 342 | code = extract_program(result) 343 | prediction, report = executor.apply(code) 344 | else: 345 | prediction = extract_answer(result) 346 | 347 | prediction = 
strip_string(prediction) 348 | return prediction, report 349 | -------------------------------------------------------------------------------- /inference/utils/python_executor.py: -------------------------------------------------------------------------------- 1 | ############## 2 | # Modified from ToRA and Math Instruct project. 3 | ############### 4 | import copy 5 | import datetime 6 | import io 7 | import json 8 | import os 9 | import pickle 10 | import re 11 | import traceback 12 | from concurrent.futures import TimeoutError 13 | from contextlib import redirect_stdout 14 | from functools import partial 15 | from typing import Any, Dict, Optional 16 | 17 | import dateutil.relativedelta 18 | import multiprocess 19 | import regex 20 | from multiprocess import Pool 21 | from pebble import ProcessPool 22 | from timeout_decorator import timeout 23 | from tqdm import tqdm 24 | 25 | NOT_EXECUTED = "" 26 | EXECUTION_ERROR = "Execution error:" 27 | SYNTAX_ERROR = "Syntax error:" 28 | RESULT_NOT_DEFINED_ERROR = "Result is not defined" 29 | TIMEOUT_ERROR = "timeout" 30 | UNDEFINED_ERROR = "Undefined error:" 31 | ERROR_PREFIXES = (EXECUTION_ERROR, SYNTAX_ERROR, RESULT_NOT_DEFINED_ERROR, TIMEOUT_ERROR, UNDEFINED_ERROR) 32 | 33 | 34 | def remove_ansi_escape_codes(text): 35 | """ 36 | Remove ANSI escape codes from the given text. 37 | 38 | Args: 39 | text (str): The text from which to remove ANSI escape codes. 40 | 41 | Returns: 42 | str: The text with ANSI escape codes removed. 43 | """ 44 | # ANSI escape codes start with the escape character followed by '[' 45 | # and end with a lowercase or uppercase letter. 46 | if not text: 47 | return text 48 | ansi_escape = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]") 49 | return ansi_escape.sub("", text) 50 | 51 | 52 | def extract_after_traceback(text): 53 | """ 54 | Extract and return the content of the text after the first occurrence of "Traceback". 
55 | 56 | Args: 57 | text (str): The text from which to extract content after "Traceback". 58 | 59 | Returns: 60 | str: The content after "Traceback", or the whole text if "Traceback" is not found. 61 | """ 62 | # Split the text at the first occurrence of "Traceback" 63 | parts = text.split("Traceback", 1) 64 | 65 | # Check if the split actually found "Traceback" and split the text 66 | if len(parts) > 1: 67 | # Return the part after "Traceback" 68 | return "Traceback" + parts[1] 69 | else: 70 | # Return the original text if "Traceback" is not found 71 | return text 72 | 73 | 74 | def get_error(txt): 75 | tmp = extract_after_traceback(remove_ansi_escape_codes(txt)) 76 | return tmp.split("\n\n")[-1] 77 | 78 | 79 | class GenericRuntime: 80 | GLOBAL_DICT = {} 81 | LOCAL_DICT = None 82 | HEADERS = [] 83 | 84 | def __init__(self): 85 | self._global_vars = copy.copy(self.GLOBAL_DICT) 86 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None 87 | 88 | for c in self.HEADERS: 89 | self.exec_code(c) 90 | 91 | def exec_code(self, code_piece: str) -> None: 92 | # if the code contains input() or os.system(), return Error 93 | if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(r"(\s|^)?os.system\(", code_piece): 94 | raise RuntimeError() 95 | # exec is a built-in python function to execute python code 96 | # _global_vars is a dict containing the global variables that can be used and modified by the code_piece 97 | exec(code_piece, self._global_vars) 98 | 99 | def eval_code(self, expr: str) -> Any: 100 | """ 101 | # Evaluate a simple expression 102 | result = evaluator.eval_code("3 + 4") 103 | print(result) # Output: 7 104 | 105 | # Define a variable in the global context and use it in an expression 106 | evaluator._global_vars['x'] = 10 107 | result = evaluator.eval_code("x * 2") 108 | print(result) # Output: 20 109 | 110 | # Modify a variable in the global context through evaluation 111 | evaluator.eval_code("x = x + 5") 112 | 
print(evaluator._global_vars['x']) # Output: 15 113 | """ 114 | return eval(expr, self._global_vars) 115 | 116 | def inject(self, var_dict: Dict[str, Any]) -> None: 117 | for k, v in var_dict.items(): 118 | self._global_vars[k] = v 119 | 120 | @property 121 | def answer(self): 122 | return self._global_vars["answer"] 123 | 124 | 125 | class DateRuntime(GenericRuntime): 126 | GLOBAL_DICT = { 127 | "datetime": datetime.datetime, 128 | "timedelta": dateutil.relativedelta.relativedelta, 129 | "relativedelta": dateutil.relativedelta.relativedelta, 130 | } 131 | 132 | 133 | class CustomDict(dict): 134 | def __iter__(self): 135 | return list(super().__iter__()).__iter__() 136 | 137 | 138 | class ColorObjectRuntime(GenericRuntime): 139 | GLOBAL_DICT = {"dict": CustomDict} 140 | 141 | 142 | class PythonExecutor: 143 | def __init__( 144 | self, 145 | runtime: Optional[Any] = None, 146 | get_answer_symbol: Optional[str] = None, 147 | get_answer_expr: Optional[str] = None, 148 | get_answer_from_stdout: bool = False, 149 | timeout_length: int = 20, 150 | ) -> None: 151 | self.runtime = runtime if runtime else GenericRuntime() 152 | self.answer_symbol = get_answer_symbol 153 | self.answer_expr = get_answer_expr 154 | self.get_answer_from_stdout = get_answer_from_stdout 155 | self.pool = Pool(multiprocess.cpu_count()) 156 | self.timeout_length = timeout_length 157 | 158 | def process_generation_to_code(self, gens: str): 159 | return [g.split("\n") for g in gens] 160 | 161 | @staticmethod 162 | def execute( 163 | code, 164 | get_answer_from_stdout=None, 165 | runtime=None, 166 | answer_symbol=None, 167 | answer_expr=None, 168 | timeout_length=10, 169 | ): 170 | try: 171 | if get_answer_from_stdout: 172 | # io to the memory 173 | program_io = io.StringIO() 174 | # redirect_stdout: move all the standard output to the program_io 175 | with redirect_stdout(program_io): 176 | # run the code for at most timeout_length seconds and get all the output to program_io 177 | 
timeout(timeout_length)(runtime.exec_code)("\n".join(code)) 178 | # rewind to the beginning of the captured output 179 | program_io.seek(0) 180 | result = program_io.read() 181 | elif answer_symbol: 182 | timeout(timeout_length)(runtime.exec_code)("\n".join(code)) 183 | result = runtime._global_vars[answer_symbol] 184 | elif answer_expr: 185 | timeout(timeout_length)(runtime.exec_code)("\n".join(code)) 186 | # evaluate answer_expr against the globals that the executed code may have defined or modified 187 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr) 188 | else: 189 | timeout(timeout_length)(runtime.exec_code)("\n".join(code[:-1])) 190 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 191 | report = "Done" 192 | # str(result) 193 | pickle.dumps(result) # serialization check 194 | except: 195 | report = traceback.format_exc().split("\n")[-2] 196 | result = json.dumps({"result": "", "error_message": report}) 197 | return result, report 198 | 199 | def apply(self, code): 200 | return self.batch_apply([code])[0] 201 | 202 | def batch_apply(self, batch_code_seq): 203 | # Each code sequence is wrapped in a Jupyter-style (IPython) harness and then executed. 204 | # The observation is the captured standard output of the newest snippet. 205 | # In other words, the model can write multiple rounds of code: every round is executed, but only the output of the last round is captured and returned. 
206 | all_processed_codes = [] 207 | for code_seq in batch_code_seq: 208 | 209 | z = """ 210 | import traceback 211 | import json 212 | import os 213 | import warnings 214 | warnings.filterwarnings('ignore') 215 | os.environ['OPENBLAS_NUM_THREADS'] = '16' 216 | 217 | from IPython.core.interactiveshell import InteractiveShell 218 | from IPython.utils import io 219 | code_snippets = [] 220 | """ 221 | for code_snippet in code_seq: 222 | # z += f'\ncode_snippets.append("""{code_snippet}""")\n' 223 | escaped_code_snippet = code_snippet.replace('"""', '\\"\\"\\"') 224 | z += f'\ncode_snippets.append("""{escaped_code_snippet}""")\n' 225 | z += f""" 226 | try: 227 | shell = InteractiveShell() 228 | for tmp_code in code_snippets: 229 | with io.capture_output() as captured: 230 | exec_result = shell.run_cell(tmp_code) 231 | output = f"{{captured.stdout}}{{captured.stderr}}".strip().replace("Out[1]: ", "") 232 | error_message = '' 233 | if exec_result.error_in_exec is not None: 234 | error_message = f"{EXECUTION_ERROR} {{str(exec_result.error_in_exec)}}" 235 | elif exec_result.error_before_exec is not None: 236 | # full traceback will be part of output 237 | error_message = f"{SYNTAX_ERROR} {{str(exec_result.error_before_exec)}}" 238 | elif output == "": 239 | error_message = "{RESULT_NOT_DEFINED_ERROR}" 240 | to_return = {{"result": output, "error_message": error_message}} 241 | except Exception: 242 | # removing useless prefix from traceback 243 | to_return = {{ 244 | "result": None, 245 | "error_message": "{UNDEFINED_ERROR}" + "\\n".join(traceback.format_exc().split("\\n")[3:]), 246 | }} 247 | print(json.dumps(to_return)) 248 | """ 249 | all_processed_codes.append(z) 250 | my_results = self.old_batch_apply(all_processed_codes) 251 | # Extract the old result 252 | batch_results = [] 253 | 254 | for prediction in my_results: 255 | if prediction[0]: 256 | if "Timeout Error" in prediction[0]: 257 | batch_results.append(("Timeout Error", "")) 258 | continue 259 | try: 260 | 
dict_data = json.loads(prediction[0]) 261 | except: 262 | match = re.search(r'"error_message":\s*"([^"]*)"', prediction[0]) 263 | if match: 264 | batch_results.append(("", match.group(1))) 265 | else: 266 | batch_results.append( 267 | ( 268 | "There exists some error in your code. Please rewrite the code and solve the problem.", 269 | "There exists some error in your code. Please rewrite the code and solve the problem.", 270 | ) 271 | ) 272 | continue 273 | else: 274 | batch_results.append( 275 | ( 276 | "There exists some error in your code. Please rewrite the code and solve the problem.", 277 | "There exists some error in your code. Please rewrite the code and solve the problem.", 278 | ) 279 | ) 280 | continue 281 | 282 | try: 283 | dict_data["error_message"] 284 | except: 285 | batch_results.append( 286 | ( 287 | "There exists some error in your code. Please rewrite the code and solve the problem.", 288 | "There exists some error in your code. Please rewrite the code and solve the problem.", 289 | ) 290 | ) 291 | continue 292 | if dict_data["error_message"]: 293 | if dict_data["result"]: 294 | batch_results.append((get_error(dict_data["result"]), "")) 295 | else: 296 | batch_results.append((dict_data["error_message"], "")) 297 | else: 298 | batch_results.append((dict_data["result"], "")) 299 | return batch_results 300 | 301 | @staticmethod 302 | def truncate(s, max_length=100): 303 | half = max_length // 2 304 | if len(s) > max_length: 305 | s = s[:half] + "..." 
+ s[-half:] 306 | return s 307 | 308 | def old_batch_apply(self, batch_code): 309 | 310 | all_code_snippets = self.process_generation_to_code(batch_code) 311 | 312 | timeout_cnt = 0 313 | all_exec_results = [] 314 | with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool: 315 | executor = partial( 316 | self.execute, 317 | get_answer_from_stdout=self.get_answer_from_stdout, 318 | runtime=self.runtime, 319 | answer_symbol=self.answer_symbol, 320 | answer_expr=self.answer_expr, 321 | timeout_length=self.timeout_length, # this timeout not work 322 | ) 323 | future = pool.map(executor, all_code_snippets, timeout=self.timeout_length) 324 | iterator = future.result() 325 | 326 | if len(all_code_snippets) > 100: 327 | progress_bar = tqdm(total=len(all_code_snippets), desc="Execute") 328 | else: 329 | progress_bar = None 330 | 331 | while True: 332 | try: 333 | result = next(iterator) 334 | all_exec_results.append(result) 335 | except StopIteration: 336 | break 337 | except TimeoutError as error: 338 | all_exec_results.append(("Timeout Error", "Timeout Error")) 339 | timeout_cnt += 1 340 | except Exception as error: 341 | # print(error) 342 | exit() 343 | if progress_bar is not None: 344 | progress_bar.update(1) 345 | 346 | if progress_bar is not None: 347 | progress_bar.close() 348 | 349 | batch_results = [] 350 | for code, (res, report) in zip(all_code_snippets, all_exec_results): 351 | # post processing 352 | res, report = str(res).strip(), str(report).strip() 353 | # res, report = self.truncate(res), self.truncate(report) 354 | batch_results.append((res.strip().replace("Out[1]: ", ""), report)) 355 | return batch_results 356 | -------------------------------------------------------------------------------- /inference/utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from pathlib import Path 5 | from typing import Any, Iterable, Union 6 | 7 | 
import numpy as np 8 | 9 | 10 | def set_seed(seed: int = 42) -> None: 11 | np.random.seed(seed) 12 | random.seed(seed) 13 | os.environ["PYTHONHASHSEED"] = str(seed) 14 | print(f"Random seed set as {seed}") 15 | 16 | 17 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]: 18 | with open(file, "r", encoding="utf-8") as f: 19 | for line in f: 20 | try: 21 | yield json.loads(line) 22 | except json.JSONDecodeError: 23 | print("Error in loading:", line) 24 | exit() 25 | 26 | 27 | def save_jsonl(samples, save_path): 28 | # ensure the parent folder exists; fall back to "." when save_path has no directory part 29 | folder = os.path.dirname(save_path) or "." 30 | os.makedirs(folder, exist_ok=True) 31 | 32 | with open(save_path, "w", encoding="utf-8") as f: 33 | for sample in samples: 34 | f.write(json.dumps(sample) + "\n") 35 | print("Saved to", save_path) 36 | 37 | 38 | def lower_keys(example): 39 | new_example = {} 40 | for key, value in example.items(): 41 | if key != key.lower(): 42 | new_key = key.lower() 43 | new_example[new_key] = value 44 | else: 45 | new_example[key] = value 46 | return new_example 47 | 48 | 49 | def construct_prompt(args, example): 50 | if args.prompt_type == "tora": 51 | if "gemma" in args.model_name_or_path: 52 | full_prompt = f"<bos><start_of_turn>user\n{example['question']}<end_of_turn>\n<start_of_turn>model\n" 53 | elif "mistral" in args.model_name_or_path: 54 | full_prompt = f"<s> [INST] {example['question']} [/INST]" 55 | elif "deepseek" in args.model_name_or_path: 56 | full_prompt = f"User: {example['question']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{{}}.\n\nAssistant: " 57 | elif "llama3" in args.model_name_or_path: 58 | full_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{example['question']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{{}}.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 59 | else: 60 | raise NotImplementedError(args.prompt_type + " and " + args.model_name_or_path) 61 | elif 
args.prompt_type == "cot": 62 | if "gemma" in args.model_name_or_path: 63 | full_prompt = f"<bos><start_of_turn>user\n{example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<end_of_turn>\n<start_of_turn>model\n" 64 | elif "mistral" in args.model_name_or_path: 65 | full_prompt = f"<s> [INST] {example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}. [/INST]" 66 | elif "deepseek" in args.model_name_or_path: 67 | full_prompt = f"User: {example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant: " 68 | elif "llama3" in args.model_name_or_path: 69 | full_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 70 | else: 71 | raise NotImplementedError(args.prompt_type + " and " + args.model_name_or_path) 72 | 73 | return full_prompt 74 | 75 | 76 | key_map = { 77 | "gt": "Ground Truth", 78 | "pred": "Prediction", 79 | "gt_cot": "Reference CoT", 80 | "score": "Score", 81 | } 82 | 83 | 84 | def show_sample(sample, print_all_preds=False): 85 | print("==" * 20) 86 | for key in ["idx", "type", "level", "dataset"]: 87 | if key in sample: 88 | # capitalize 89 | print("{}: {}".format(key[0].upper() + key[1:], sample[key])) 90 | print("Question:", repr(sample["question"])) 91 | if "code" in sample: 92 | if print_all_preds: 93 | for code in sample["code"]: 94 | print("-" * 20) 95 | print("code:", code) 96 | print("Execution:", sample["report"]) 97 | else: 98 | print("Solution:\n", sample["code"][0]) 99 | print("Execution:", sample["report"][0]) 100 | if "pred" in sample: 101 | print("Prediction:", repr(sample["pred"][0])) 102 | for key in ["gt", "score", "unit", "gt_cot"]: 103 | if key in sample: 104 | _key = key_map.get(key, key) 105 | print("{}: {}".format(_key, repr(sample[key]))) 106 | print() 107 | 
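A minimal usage sketch for the JSONL helpers in `utils.py`, under stated assumptions: the three helpers are re-declared here in simplified form so the snippet runs on its own, and `roundtrip`, the `demo/samples.jsonl` path, and the sample record are hypothetical names used only for illustration. It shows records being written one JSON object per line, read back lazily, and key-normalised.

```python
import json
import os
import tempfile


def save_jsonl(samples, save_path):
    # Simplified copy of the helper above: one JSON object per line.
    folder = os.path.dirname(save_path) or "."
    os.makedirs(folder, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample) + "\n")


def load_jsonl(path):
    # Generator: records are parsed lazily, one line at a time.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)


def lower_keys(example):
    # Normalise keys such as "Question" to "question".
    return {key.lower(): value for key, value in example.items()}


def roundtrip(samples):
    # Hypothetical helper: write to a throwaway directory, read back, normalise keys.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "demo", "samples.jsonl")
        save_jsonl(samples, path)
        return [lower_keys(sample) for sample in load_jsonl(path)]


if __name__ == "__main__":
    print(roundtrip([{"Question": "1+1?", "idx": 0}]))
```

Because `load_jsonl` is a generator, the file must still exist while the records are consumed, which is why the list is built inside the `with` block.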
-------------------------------------------------------------------------------- /useful_codes/annotate_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from collections import defaultdict 3 | import multiprocessing 4 | from math import isclose 5 | from typing import Union 6 | 7 | from datasets import Dataset, concatenate_datasets 8 | 9 | from typing import Union 10 | 11 | from sympy import N, simplify 12 | from sympy.parsing.latex import parse_latex 13 | from sympy.parsing.sympy_parser import parse_expr 14 | 15 | ds1 = load_dataset("1231czx/7b_sft_510k_3epoch_gen_data_iter1", split="train").shuffle(seed=42) 16 | ds2 = load_dataset("1231czx/7b_sft_510k_1epoch_gen_data_iter1", split="train").shuffle(seed=42) 17 | 18 | ds = concatenate_datasets([ds1, ds2]) # .select(range(100000)) 19 | # ds = ds.select(range(5000)) 20 | N_pair = 1 21 | data_comp = "1231czx/7B_iter1_dpo_N1_random_pair" 22 | data_sft = "1231czx/7B_iter1_sft_N1" 23 | 24 | 25 | 26 | def symbolic_equal(a, b): 27 | def _parse(s): 28 | for f in [parse_latex, parse_expr]: 29 | try: 30 | return f(s) 31 | except: 32 | pass 33 | return s 34 | 35 | a = _parse(a) 36 | b = _parse(b) 37 | 38 | try: 39 | if simplify(a - b) == 0: 40 | return True 41 | except: 42 | pass 43 | 44 | try: 45 | if isclose(N(a), N(b), rel_tol=1e-3): 46 | return True 47 | except: 48 | pass 49 | return False 50 | 51 | 52 | def symbolic_equal_process(a, b, output_queue): 53 | result = symbolic_equal(a, b) 54 | output_queue.put(result) 55 | 56 | 57 | def call_with_timeout(func, *args, timeout=10, **kwargs): 58 | output_queue = multiprocessing.Queue() 59 | process_args = args + (output_queue,) 60 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 61 | process.start() 62 | process.join(timeout) 63 | 64 | if process.is_alive(): 65 | process.terminate() 66 | process.join() 67 | # return "time out" 68 | return False 69 | return 
output_queue.get() 70 | 71 | 72 | def math_equal( 73 | prediction: Union[bool, float, str], 74 | reference: Union[float, str], 75 | include_percentage: bool = True, 76 | is_close: bool = True, 77 | timeout: bool = True, 78 | ) -> bool: 79 | """ 80 | Exact match of math if and only if: 81 | 1. numerical equal: both can convert to float and are equal 82 | 2. symbolic equal: both can convert to sympy expression and are equal 83 | """ 84 | try: # 1. numerical equal 85 | if is_digit(prediction) and is_digit(reference): 86 | prediction = float(str(prediction).replace(",", "")) 87 | reference = float(str(reference).replace(",", "")) 88 | # number questions 89 | if include_percentage: 90 | gt_result = [reference / 100, reference, reference * 100] 91 | else: 92 | gt_result = [reference] 93 | for item in gt_result: 94 | try: 95 | if is_close: 96 | if isclose(item, prediction, rel_tol=1e-4): 97 | return True 98 | else: 99 | if item == prediction: 100 | return True 101 | except Exception: 102 | continue 103 | return False 104 | except: 105 | pass 106 | if not prediction and prediction not in [0, False]: 107 | return False 108 | 109 | # 2. symbolic equal 110 | reference = str(reference).strip() 111 | prediction = str(prediction).strip() 112 | 113 | ## deal with [], (), {} 114 | pred_str, ref_str = prediction, reference 115 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( 116 | prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") 117 | ): 118 | pred_str = pred_str.strip("[]()") 119 | ref_str = ref_str.strip("[]()") 120 | for s in ["{", "}", "(", ")"]: 121 | ref_str = ref_str.replace(s, "") 122 | pred_str = pred_str.replace(s, "") 123 | if pred_str == ref_str: 124 | return True 125 | 126 | ## [a, b] vs. 
[c, d], return a==c and b==d 127 | if ( 128 | (prediction.startswith("[") and prediction.endswith("]")) 129 | and (reference.startswith("[") and reference.endswith("]")) 130 | or (prediction.startswith("(") and prediction.endswith(")")) 131 | and (reference.startswith("(") and reference.endswith(")")) 132 | ): 133 | pred_parts = prediction[1:-1].split(",") 134 | ref_parts = reference[1:-1].split(",") 135 | if len(pred_parts) == len(ref_parts): 136 | if all( 137 | [math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))] 138 | ): 139 | return True 140 | 141 | # symbolic equal with sympy 142 | if timeout: 143 | if call_with_timeout(symbolic_equal_process, prediction, reference): 144 | return True 145 | # elif tmp == 'time out': 146 | # return "time out" 147 | else: 148 | if symbolic_equal(prediction, reference): 149 | return True 150 | 151 | return False 152 | 153 | 154 | def _fix_a_slash_b(string): 155 | if len(string.split("/")) != 2: 156 | return string 157 | a = string.split("/")[0] 158 | b = string.split("/")[1] 159 | try: 160 | if "sqrt" not in a: 161 | a = int(a) 162 | if "sqrt" not in b: 163 | b = int(b) 164 | assert string == "{}/{}".format(a, b) 165 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 166 | return new_string 167 | except: 168 | return string 169 | 170 | 171 | def _fix_fracs(string): 172 | substrs = string.split("\\frac") 173 | new_str = substrs[0] 174 | if len(substrs) > 1: 175 | substrs = substrs[1:] 176 | for substr in substrs: 177 | new_str += "\\frac" 178 | if len(substr) > 0 and substr[0] == "{": 179 | new_str += substr 180 | else: 181 | try: 182 | assert len(substr) >= 2 183 | except: 184 | return string 185 | a = substr[0] 186 | b = substr[1] 187 | if b != "{": 188 | if len(substr) > 2: 189 | post_substr = substr[2:] 190 | new_str += "{" + a + "}{" + b + "}" + post_substr 191 | else: 192 | new_str += "{" + a + "}{" + b + "}" 193 | else: 194 | if len(substr) > 2: 195 | post_substr = 
substr[2:] 196 | new_str += "{" + a + "}" + b + post_substr 197 | else: 198 | new_str += "{" + a + "}" + b 199 | string = new_str 200 | return string 201 | 202 | 203 | def _fix_sqrt(string): 204 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 205 | return _string 206 | 207 | 208 | def strip_string(string): 209 | string = str(string).strip() 210 | # linebreaks 211 | string = string.replace("\n", "") 212 | 213 | # right "." 214 | string = string.rstrip(".") 215 | 216 | # remove inverse spaces 217 | string = string.replace("\\!", "") 218 | string = string.replace("\\ ", "") 219 | 220 | # replace \\ with \ 221 | string = string.replace("\\\\", "\\") 222 | string = string.replace("\\\\", "\\") 223 | 224 | # replace tfrac and dfrac with frac 225 | string = string.replace("tfrac", "frac") 226 | string = string.replace("dfrac", "frac") 227 | 228 | # remove \left and \right 229 | string = string.replace("\\left", "") 230 | string = string.replace("\\right", "") 231 | 232 | # Remove unit: miles, dollars if after is not none 233 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 234 | if _string != "" and _string != string: 235 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 236 | string = _string 237 | 238 | # Remove circ (degrees) 239 | string = string.replace("^{\\circ}", "") 240 | string = string.replace("^\\circ", "") 241 | 242 | # remove dollar signs 243 | string = string.replace("\\$", "") 244 | string = string.replace("$", "") 245 | 246 | string = string.replace("\\text", "") 247 | string = string.replace("x\\in", "") 248 | 249 | # remove percentage 250 | string = string.replace("\\%", "") 251 | string = string.replace("\%", "") 252 | string = string.replace("%", "") 253 | 254 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 255 | string = string.replace(" .", " 0.") 256 | string = string.replace("{.", "{0.") 257 | 258 | # cdot 259 | string = string.replace("\\cdot", "") 260 | 261 | # inf 262 | string = string.replace("infinity", "\\infty") 263 | if "\\infty" not in string: 264 | string = string.replace("inf", "\\infty") 265 | string = string.replace("+\\infty", "\\infty") 266 | 267 | # and 268 | string = string.replace("and", "") 269 | string = string.replace("\\mathbf", "") 270 | 271 | # use regex to remove \mbox{...} 272 | string = re.sub(r"\\mbox{.*?}", "", string) 273 | 274 | # remove quotes (str.replace returns a new string, so the result must be assigned) 275 | string = string.replace("'", "") 276 | string = string.replace('"', "") 277 | 278 | # i, j 279 | if "j" in string and "i" not in string: 280 | string = string.replace("j", "i") 281 | 282 | # replace a.000b where b is not number or b is end, with ab, use regex 283 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) 284 | string = re.sub(r"(\d+)\.0+$", r"\1", string) 285 | 286 | # if empty, return empty string 287 | if len(string) == 0: 288 | return string 289 | if string[0] == ".": 290 | string = "0" + string 291 | 292 | # to consider: get rid of e.g. "k = " or "q = " at beginning 293 | if len(string.split("=")) == 2: 294 | if len(string.split("=")[0]) <= 2: 295 | string = string.split("=")[1] 296 | 297 | string = _fix_sqrt(string) 298 | string = string.replace(" ", "") 299 | 300 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} 301 | string = _fix_fracs(string) 302 | 303 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 304 | string = _fix_a_slash_b(string) 305 | 306 | return string 307 | 308 | 309 | import re 310 | 311 | 312 | def check1(a, b): 313 | 314 | try: 315 | a = parse_latex(a) 316 | b = parse_latex(b) 317 | if abs(float(a) - float(b)) < 0.01: 318 | return True 319 | except: 320 | pass 321 | 322 | return False 323 | 324 | 325 | print(check1("0.25", "\\frac{1}{4}")) 326 | 327 | 328 | def parse_ground_truth(example): 329 | cnt = 0 330 | if example["type"] in [ 331 | "gpt-3.5-turbo", 332 | "MATH_FOBAR", 333 | "GSM_Rephrased", 334 | "MATH_SV", 335 | "MATH_Rephrased", 336 | "GSM_SV", 337 | "GSM_FOBAR", 338 | ]: 339 | pattern = r"The answer is: (.*?)$" 340 | match = re.search(pattern, example["solution"]) 341 | gt_ans = "delete" # default, so gt_ans is always bound: the sample is dropped when no answer is found 342 | if match: 343 | gt_ans = match.group(1) 344 | # print("The prediction is:", prediction) 345 | else: 346 | print("No prediction found for gpt-3.5-turbo.") 347 | 348 | check_ans = extract_answer(example["solution"]) 349 | if check_ans != gt_ans: 350 | # if math_equal(gt_ans, check_ans): 351 | # pass 352 | if check1(gt_ans, check_ans): 353 | pass 354 | else: 355 | # print(example['type'], gt_ans, check_ans, "\n") 356 | cnt += 1 357 | gt_ans = "delete" 358 | elif example["type"] == "gsm8k": 359 | gt_ans = example["solution"].split("####")[-1] 360 | 361 | elif example["type"] == "math": 362 | gt_ans = extract_answer(example["solution"]) 363 | 364 | else: 365 | print(example["type"]) 366 | raise NotImplementedError() 367 | gt_ans = strip_string(gt_ans) 368 | # if gt_ans.startswith('\frac'): 369 | # gt_ans = '\' + gt_ans 370 | return {"gt": gt_ans} 371 | 372 | 373 | def extract_answer(pred_str): 374 | if "boxed" in pred_str: 375 | ans = pred_str.split("boxed")[-1] 376 | if len(ans) == 0: 377 | return "" 378 | elif ans[0] == "{": 379 | stack = 1 380 | a = "" 381 | for c in ans[1:]: 382 | if c == 
"{": 383 | stack += 1 384 | a += c 385 | elif c == "}": 386 | stack -= 1 387 | if stack == 0: 388 | break 389 | a += c 390 | else: 391 | a += c 392 | else: 393 | a = ans.split("$")[0].strip() 394 | pred = a 395 | elif "he answer is" in pred_str: 396 | pred = pred_str.split("he answer is")[-1].strip() 397 | elif extract_program_output(pred_str) != "": 398 | # fall back to program 399 | pred = extract_program_output(pred_str) 400 | else: # use the last number 401 | pattern = "-?\d*\.?\d+" 402 | pred = re.findall(pattern, pred_str.replace(",", "")) 403 | if len(pred) >= 1: 404 | pred = pred[-1] 405 | else: 406 | pred = "" 407 | 408 | # multiple line 409 | pred = pred.split("\n")[0] 410 | if pred != "" and pred[0] == ":": 411 | pred = pred[1:] 412 | if pred != "" and pred[-1] == ".": 413 | pred = pred[:-1] 414 | if pred != "" and pred[-1] == "/": 415 | pred = pred[:-1] 416 | pred = strip_string(pred) 417 | return pred 418 | 419 | 420 | from sympy.parsing.latex import parse_latex 421 | 422 | # Example LaTeX string 423 | 424 | z = 0 425 | # ds = load_dataset('1231czx/7b_sft_510k_3epoch_gen_data_iter1', split='train') 426 | 427 | # for sample in ds: 428 | # a, b = parse_ground_truth(sample) 429 | # z += b 430 | # print(z) 431 | 432 | ds_new = ds.map(parse_ground_truth, num_proc=32) 433 | ds_new = ds_new.filter(lambda example: example["gt"] != "delete") 434 | 435 | print("#######################################\n", "After delete the prompts that cannot be verified") 436 | print(ds_new[0]) 437 | print(ds, ds_new) 438 | 439 | 440 | ########################################### 441 | 442 | import re 443 | 444 | 445 | def parse_conversation(example): 446 | # Split the data into turns based on the start_of_turn and end_of_turn markers 447 | data = example["my_solu"][0] 448 | turns = re.split(r"|", data) 449 | 450 | # Clean and filter out empty entries 451 | turns = [turn.strip() for turn in turns if turn.strip() and not turn.startswith("")] 452 | 453 | # Create a list to hold the 
parsed conversation in the desired format 454 | conversation = [] 455 | 456 | # Process each turn, assigning the correct role and content 457 | for turn in turns: 458 | if turn.startswith("user\n"): 459 | # Extract content after the role identifier 460 | content = turn[5:].strip() 461 | conversation.append({"role": "user", "content": content}) 462 | elif turn.startswith("model\n"): 463 | content = turn[6:].strip() 464 | conversation.append({"role": "assistant", "content": content}) 465 | 466 | return {"messages": conversation} 467 | 468 | 469 | ds_new = ds_new.map(parse_conversation, num_proc=32) 470 | 471 | 472 | 473 | def is_correct(example): 474 | is_correct_label = False 475 | a = example["pred"][0] 476 | b = example["gt"] 477 | if a == b: 478 | is_correct_label = True 479 | return {"is_correct": is_correct_label} 480 | try: 481 | if abs(float(a) - float(b)) < 0.0001: 482 | is_correct_label = True 483 | return {"is_correct": is_correct_label} 484 | except: 485 | pass 486 | 487 | try: 488 | 489 | if math_equal(a, b): 490 | # if z == True: 491 | is_correct_label = True 492 | 493 | return {"is_correct": is_correct_label} 494 | 495 | # elif z == 'time out': 496 | # print("time out") 497 | # return {"is_correct": 'time out'} 498 | 499 | # signal.alarm(0) 500 | # except TimeoutException: 501 | # pass 502 | # if math_equal(a, b): 503 | # is_correct_label = True 504 | except: 505 | pass 506 | 507 | """ 508 | try: 509 | if check1(a, b): 510 | is_correct_label = True 511 | return {"is_correct": is_correct_label} 512 | except: 513 | pass 514 | """ 515 | # try: 516 | # if math_equal(a, b): 517 | # is_correct_label = True 518 | # except: 519 | # pass 520 | 521 | # print(example['type']) 522 | # raise NotImplementedError() 523 | # gt_ans = strip_string(gt_ans) 524 | return {"is_correct": is_correct_label} 525 | 526 | 527 | def filter_example1(example): 528 | old_messages = example["messages"] 529 | 530 | if len(old_messages) < 4: 531 | return False 532 | 533 | if 
len(old_messages) % 2 != 0: 534 | return False 535 | 536 | all_mes_len = len(old_messages) 537 | if example["is_correct"] and "error" in old_messages[-2]["content"].lower(): 538 | return False 539 | 540 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower(): 541 | return False 542 | 543 | k = 0 544 | 545 | for mes in old_messages: 546 | if k % 2 != 0 and k < all_mes_len - 1: 547 | if "python" not in mes["content"]: 548 | return False 549 | k += 1 550 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower(): 551 | return False 552 | if mes["content"] == "```output\nExecution error: \n```": 553 | return False 554 | if "```output\n[]" in mes["content"]: 555 | # print(mes['content']) 556 | return False 557 | 558 | return True 559 | 560 | 561 | from transformers import AutoTokenizer, HfArgumentParser 562 | 563 | # ds_new = ds_new.filter(filter_example1, num_proc=32) 564 | 565 | print("############### Answer check") 566 | tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it") 567 | 568 | 569 | 570 | def filter_too_long_pred(example): 571 | try: 572 | if len(example["pred"][0]) > 20: 573 | return False 574 | except: 575 | return False 576 | old_messages = example["messages"] 577 | if len(old_messages) < 4: 578 | return False 579 | 580 | if len(old_messages) % 2 != 0: 581 | return False 582 | 583 | all_mes_len = len(old_messages) 584 | 585 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower(): 586 | return False 587 | 588 | k = 0 589 | 590 | for mes in old_messages: 591 | if k % 2 != 0 and k < all_mes_len - 1: 592 | if "python" not in mes["content"]: 593 | return False 594 | k += 1 595 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower(): 596 | return False 597 | if mes["content"] == "```output\nExecution error: \n```": 598 | return False 599 | if "```output\n[]" in mes["content"]: 600 | # print(mes['content']) 601 | 
return False 602 | 603 | z = len(tokenizer.apply_chat_template(old_messages, tokenize=True)) 604 | if z > 2048: 605 | return False 606 | 607 | # b = len(tokenizer.apply_chat_template(ds_test[i]['rejected'], tokenize=True)) 608 | 609 | return True 610 | 611 | 612 | ds_new = ds_new.filter(filter_too_long_pred, num_proc=32) 613 | 614 | ds_new = ds_new.map(is_correct, num_proc=32) 615 | 616 | 617 | ds_new = ds_new.filter(filter_example1, num_proc=32) 618 | 619 | ds_win = ds_new.filter(lambda example: example["is_correct"] == True) 620 | ds_lose = ds_new.filter(lambda example: example["is_correct"] == False) 621 | 622 | print(len(ds_win) + len(ds_lose), len(ds_new)) 623 | 624 | 625 | # print("I have win ", len(ds_win), " and lose ", len(ds_lose)) 626 | def filter_win(example): 627 | old_messages = example["messages"] 628 | 629 | if "error" in old_messages[-2]["content"].lower(): 630 | return False 631 | if "none" in old_messages[-2]["content"].lower(): 632 | return False 633 | 634 | return True 635 | 636 | 637 | print("I have win ", len(ds_win), " and lose ", len(ds_lose)) 638 | ds_win = ds_win.filter(filter_win, num_proc=32) 639 | print("I have win ", len(ds_win), " and lose ", len(ds_lose)) 640 | 641 | 642 | win_ret = defaultdict(list) 643 | lose_ret = defaultdict(list) 644 | 645 | 646 | for sample in ds_win: 647 | idx = sample["idx"] 648 | win_ret[idx].append(sample) 649 | 650 | 651 | for sample in ds_lose: 652 | idx = sample["idx"] 653 | lose_ret[idx].append(sample) 654 | 655 | 656 | new_win_ret = defaultdict(list) 657 | 658 | cnt_win = 0 659 | 660 | for key, value in win_ret.items(): 661 | if len(win_ret[key]) == 0: 662 | continue 663 | j = key 664 | all_samples = win_ret[j] 665 | all_texts = [] 666 | new_samples = [] 667 | for ins in all_samples: 668 | if ins["messages"][1]["content"] in all_texts: 669 | continue 670 | all_texts.append(ins["messages"][1]["content"]) 671 | new_samples.append(ins) 672 | cnt_win += 1 673 | 674 | new_win_ret[j].extend(new_samples) 675 
| 676 | 677 | cnt_lose = 0 678 | new_lose_ret = defaultdict(list) 679 | 680 | 681 | for key, value in lose_ret.items(): 682 | if len(lose_ret[key]) == 0: 683 | continue 684 | j = key 685 | all_samples = lose_ret[j] 686 | all_texts = [] 687 | new_samples = [] 688 | for ins in all_samples: 689 | if ins["messages"][1]["content"] in all_texts: 690 | continue 691 | all_texts.append(ins["messages"][1]["content"]) 692 | new_samples.append(ins) 693 | cnt_lose += 1 694 | 695 | new_lose_ret[j].extend(new_samples) 696 | 697 | print("Before get final pairs, I have win and lose", cnt_win, cnt_lose) 698 | import random 699 | 700 | import numpy as np 701 | 702 | all_comp = [] 703 | all_sft = [] 704 | all_keys = list(new_lose_ret.keys()) + list(new_win_ret.keys()) 705 | all_keys = list(set(all_keys)) 706 | import itertools 707 | 708 | """ 709 | for ins in new_win_ret[0]: 710 | print(is_correct(ins), ins['pred'][0]) 711 | 712 | for ins in new_lose_ret[0]: 713 | print(is_correct(ins), ins['pred'][0], ins) 714 | print(check1(ins['gt'], ins['pred'][0])) 715 | print(check1('\\frac{1}{4}', '0.25')) 716 | """ 717 | cnt_comp = 0 718 | # N = 1 719 | for j in all_keys: 720 | if len(new_lose_ret[j]) > 0 and len(new_win_ret[j]) > 0: 721 | cnt_comp += N_pair 722 | else: 723 | continue 724 | all_pos = new_win_ret[j] 725 | all_neg = new_lose_ret[j] 726 | random.shuffle(all_pos) 727 | random.shuffle(all_neg) 728 | if len(all_pos) > N_pair and len(all_neg) > N_pair: 729 | for k in range(N_pair): 730 | all_comp.append( 731 | { 732 | "gt": all_pos[k]["gt"], 733 | "rej": all_neg[k]["pred"], 734 | "chosen": all_pos[k]["messages"], 735 | "rejected": all_neg[k]["messages"], 736 | } 737 | ) 738 | all_sft.append({"messages": all_pos[k]["messages"]}) 739 | continue 740 | 741 | combinations = list(itertools.product(list(range(len(all_pos))), list(range(len(all_neg))))) 742 | 743 | random.shuffle(combinations) 744 | for k in range(np.min([len(combinations), N_pair])): 745 | all_comp.append( 746 | { 747 | 
"gt": all_pos[combinations[k][0]]["gt"], 748 | "chosen": all_pos[combinations[k][0]]["messages"], 749 | "rejected": all_neg[combinations[k][1]]["messages"], 750 | "rej": all_neg[combinations[k][1]]["pred"], 751 | } 752 | ) 753 | all_sft.append({"messages": all_pos[combinations[k][0]]["messages"]}) 754 | 755 | 756 | # print(all_comp[0]) 757 | output_eval_dataset = {} 758 | output_eval_dataset["type"] = "text_only" 759 | output_eval_dataset["instances"] = all_comp 760 | print("I collect ", len(all_comp), "samples", len(all_sft)) 761 | 762 | import json 763 | 764 | with open("tmp_comp.json", "w", encoding="utf8") as f: 765 | json.dump(output_eval_dataset, f, ensure_ascii=False) 766 | 767 | ds_comp = load_dataset("json", data_files="tmp_comp.json", split="train", field="instances") 768 | ds_comp.push_to_hub(data_comp) 769 | 770 | output_eval_dataset = {} 771 | output_eval_dataset["type"] = "text_only" 772 | output_eval_dataset["instances"] = all_sft 773 | print("I collect ", len(all_comp), "samples", len(all_sft)) 774 | 775 | import json 776 | 777 | with open("tmp_sft.json", "w", encoding="utf8") as f: 778 | json.dump(output_eval_dataset, f, ensure_ascii=False) 779 | 780 | ds_sft = load_dataset("json", data_files="tmp_sft.json", split="train", field="instances") 781 | ds_sft.push_to_hub(data_sft) 782 | 783 | # ds_new = ds_new.remove_columns("idx") 784 | 785 | # def add_index(example, idx): 786 | # # Add the current index to the example under a new field 'index' 787 | # example['idx'] = idx 788 | # return example 789 | 790 | # ds_new = ds_new.map(add_index, with_indices=True) 791 | # ds_new.push_to_hub("1231czx/prompts_80K_with_original_MATH_GSM8K_iter1") 792 | 793 | ################################### 794 | 795 | N_pair = 3 796 | data_comp = "1231czx/7B_iter1_dpo_N3_random_pair" 797 | data_sft = "1231czx/7B_iter1_sft_N3" 798 | 799 | all_comp = [] 800 | all_sft = [] 801 | all_keys = list(new_lose_ret.keys()) + list(new_win_ret.keys()) 802 | all_keys = 
list(set(all_keys)) 803 | import itertools 804 | 805 | """ 806 | for ins in new_win_ret[0]: 807 | print(is_correct(ins), ins['pred'][0]) 808 | 809 | for ins in new_lose_ret[0]: 810 | print(is_correct(ins), ins['pred'][0], ins) 811 | print(check1(ins['gt'], ins['pred'][0])) 812 | print(check1('\\frac{1}{4}', '0.25')) 813 | """ 814 | cnt_comp = 0 815 | # N = 1 816 | for j in all_keys: 817 | if len(new_lose_ret[j]) > 0 and len(new_win_ret[j]) > 0: 818 | cnt_comp += N_pair 819 | else: 820 | continue 821 | all_pos = new_win_ret[j] 822 | all_neg = new_lose_ret[j] 823 | random.shuffle(all_pos) 824 | random.shuffle(all_neg) 825 | if len(all_pos) > N_pair and len(all_neg) > N_pair: 826 | for k in range(N_pair): 827 | all_comp.append( 828 | { 829 | "gt": all_pos[k]["gt"], 830 | "rej": all_neg[k]["pred"], 831 | "chosen": all_pos[k]["messages"], 832 | "rejected": all_neg[k]["messages"], 833 | } 834 | ) 835 | all_sft.append({"messages": all_pos[k]["messages"]}) 836 | continue 837 | 838 | combinations = list(itertools.product(list(range(len(all_pos))), list(range(len(all_neg))))) 839 | 840 | random.shuffle(combinations) 841 | for k in range(np.min([len(combinations), N_pair])): 842 | all_comp.append( 843 | { 844 | "gt": all_pos[combinations[k][0]]["gt"], 845 | "chosen": all_pos[combinations[k][0]]["messages"], 846 | "rejected": all_neg[combinations[k][1]]["messages"], 847 | "rej": all_neg[combinations[k][1]]["pred"], 848 | } 849 | ) 850 | all_sft.append({"messages": all_pos[combinations[k][0]]["messages"]}) 851 | 852 | 853 | # print(all_comp[0]) 854 | output_eval_dataset = {} 855 | output_eval_dataset["type"] = "text_only" 856 | output_eval_dataset["instances"] = all_comp 857 | print("I collect ", len(all_comp), "samples", len(all_sft)) 858 | 859 | import json 860 | 861 | with open("tmp_comp.json", "w", encoding="utf8") as f: 862 | json.dump(output_eval_dataset, f, ensure_ascii=False) 863 | 864 | ds_comp = load_dataset("json", data_files="tmp_comp.json", split="train", 
field="instances") 865 | ds_comp.push_to_hub(data_comp) 866 | 867 | output_eval_dataset = {} 868 | output_eval_dataset["type"] = "text_only" 869 | output_eval_dataset["instances"] = all_sft 870 | print("I collect ", len(all_comp), "samples", len(all_sft)) 871 | 872 | import json 873 | 874 | with open("tmp_sft.json", "w", encoding="utf8") as f: 875 | json.dump(output_eval_dataset, f, ensure_ascii=False) 876 | 877 | ds_sft = load_dataset("json", data_files="tmp_sft.json", split="train", field="instances") 878 | ds_sft.push_to_hub(data_sft) 879 | 880 | 881 | ########################## 882 | 883 | N_pair = 8 884 | data_comp = "1231czx/7B_iter1_dpo_N8_random_pair" 885 | data_sft = "1231czx/7B_iter1_sft_N8" 886 | 887 | all_comp = [] 888 | all_sft = [] 889 | all_keys = list(new_lose_ret.keys()) + list(new_win_ret.keys()) 890 | all_keys = list(set(all_keys)) 891 | import itertools 892 | 893 | """ 894 | for ins in new_win_ret[0]: 895 | print(is_correct(ins), ins['pred'][0]) 896 | 897 | for ins in new_lose_ret[0]: 898 | print(is_correct(ins), ins['pred'][0], ins) 899 | print(check1(ins['gt'], ins['pred'][0])) 900 | print(check1('\\frac{1}{4}', '0.25')) 901 | """ 902 | cnt_comp = 0 903 | # N = 1 904 | for j in all_keys: 905 | if len(new_lose_ret[j]) > 0 and len(new_win_ret[j]) > 0: 906 | cnt_comp += N_pair 907 | else: 908 | continue 909 | all_pos = new_win_ret[j] 910 | all_neg = new_lose_ret[j] 911 | random.shuffle(all_pos) 912 | random.shuffle(all_neg) 913 | if len(all_pos) > N_pair and len(all_neg) > N_pair: 914 | for k in range(N_pair): 915 | all_comp.append( 916 | { 917 | "gt": all_pos[k]["gt"], 918 | "rej": all_neg[k]["pred"], 919 | "chosen": all_pos[k]["messages"], 920 | "rejected": all_neg[k]["messages"], 921 | } 922 | ) 923 | all_sft.append({"messages": all_pos[k]["messages"]}) 924 | continue 925 | 926 | combinations = list(itertools.product(list(range(len(all_pos))), list(range(len(all_neg))))) 927 | 928 | random.shuffle(combinations) 929 | for k in 
range(np.min([len(combinations), N_pair])): 930 | all_comp.append( 931 | { 932 | "gt": all_pos[combinations[k][0]]["gt"], 933 | "chosen": all_pos[combinations[k][0]]["messages"], 934 | "rejected": all_neg[combinations[k][1]]["messages"], 935 | "rej": all_neg[combinations[k][1]]["pred"], 936 | } 937 | ) 938 | all_sft.append({"messages": all_pos[combinations[k][0]]["messages"]}) 939 | 940 | 941 | # print(all_comp[0]) 942 | output_eval_dataset = {} 943 | output_eval_dataset["type"] = "text_only" 944 | output_eval_dataset["instances"] = all_comp 945 | print("I collect ", len(all_comp), "samples", len(all_sft)) 946 | 947 | import json 948 | 949 | with open("tmp_comp.json", "w", encoding="utf8") as f: 950 | json.dump(output_eval_dataset, f, ensure_ascii=False) 951 | 952 | ds_comp = load_dataset("json", data_files="tmp_comp.json", split="train", field="instances") 953 | ds_comp.push_to_hub(data_comp) 954 | 955 | output_eval_dataset = {} 956 | output_eval_dataset["type"] = "text_only" 957 | output_eval_dataset["instances"] = all_sft 958 | print("I collect ", len(all_comp), "samples", len(all_sft)) 959 | 960 | import json 961 | 962 | with open("tmp_sft.json", "w", encoding="utf8") as f: 963 | json.dump(output_eval_dataset, f, ensure_ascii=False) 964 | 965 | ds_sft = load_dataset("json", data_files="tmp_sft.json", split="train", field="instances") 966 | ds_sft.push_to_hub(data_sft) 967 | -------------------------------------------------------------------------------- /useful_codes/interpolate_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | #!/usr/bin/env python 3 | # coding=utf-8 4 | # We provide a script to merge different checkpoints of the model 5 | # See a detailed study in Mitigating the Alignment Tax of RLHF, https://arxiv.org/abs/2309.06256 6 | 7 | import json 8 | import os 9 | import sys 10 | from transformers import HfArgumentParser, AutoModelForCausalLM 11 | import argparse 12 | 13 | parser = 
argparse.ArgumentParser(description="merge model checkpoints") 14 | parser.add_argument("--base_model", type=str, required=True, help="dir of base model") 15 | parser.add_argument("--new_model", type=str, required=True, help="dir of new model") 16 | parser.add_argument("--output_dir", type=str, required=True, help="output dir") 17 | parser.add_argument("--ratio", type=float, required=True, help="the ratio of the new model") 18 | 19 | args = parser.parse_args() 20 | os.makedirs(args.output_dir, exist_ok=True) 21 | print(f"Base model: {args.base_model}") 22 | print(f"New model: {args.new_model}") 23 | print(f"Output directory: {args.output_dir}") 24 | 25 | new_dir = args.new_model 26 | base_dir = args.base_model 27 | weight_ensamble_save_path = args.output_dir 28 | weight_ensamble_ratios = [args.ratio]  # a list, so the base-model ratio can be appended below 29 | 30 | # Get the paths and ratios of the weight-ensemble models. 31 | # args.ratio * new_model + (1 - args.ratio) * base_model 32 | 33 | weight_ensamble_names_paths = [new_dir, base_dir] 34 | weight_ensamble_ratios.append(1 - weight_ensamble_ratios[0]) 35 | assert len(weight_ensamble_ratios) == 2, 'Only merging 2 models is supported.' 
36 | print('Model Paths:', weight_ensamble_names_paths) 37 | print('Model Ratio:', weight_ensamble_ratios) 38 | 39 | base_model = None 40 | backend_models = [] 41 | for model_path in weight_ensamble_names_paths: 42 | #model_args.model_name_or_path = model_path 43 | print('loading:', model_path) 44 | model = AutoModelForCausalLM.from_pretrained(model_path)#, torch_dtype=torch.bfloat16) 45 | backend_models.append(model.to('cpu')) 46 | if base_model is None: 47 | base_model = model 48 | print('Finish load:', model_path) 49 | base_backend_model = backend_models[0] 50 | print('Finish load All:', base_backend_model) 51 | 52 | updated_state_dict = {} 53 | for key in base_backend_model.state_dict(): 54 | ensambled_state_dicts = [ratio * backend_model.state_dict()[key] for backend_model, ratio in zip(backend_models, weight_ensamble_ratios)] 55 | updated_state_dict[key] = sum(ensambled_state_dicts) 56 | 57 | base_backend_model.load_state_dict(updated_state_dict) 58 | base_model.save_pretrained(weight_ensamble_save_path) 59 | -------------------------------------------------------------------------------- /useful_codes/merge.py: -------------------------------------------------------------------------------- 1 | # This script finds all files ending with .jsonl and merges them into one file 2 | import os 3 | from datasets import load_dataset 4 | import json 5 | 6 | # The folders to load data from 7 | all_folder_path = [ 8 | './7b_sft3epoch_gen_data/iter1', 9 | './7b_sft1epoch_gen_data/iter1' 10 | ] 11 | 12 | output_dir='all_math.json' 13 | 14 | all_data = [] 15 | for folder_path in all_folder_path: 16 | jsonl_files = [folder_path + '/' + f for f in os.listdir(folder_path) if f.endswith('.jsonl')] 17 | for dir_ in jsonl_files: 18 | ds_test = load_dataset('json', data_files=dir_, split='train') 19 | for sample in ds_test: 20 | all_data.append(sample) 21 | 22 | output_eval_dataset = {} 23 | output_eval_dataset["type"] = "text_only" 24 | output_eval_dataset["instances"] = all_data 25 
| print("I collect ", len(all_data), "samples") 26 | 27 | with open(output_dir, "w", encoding="utf8") as f: 28 | json.dump(output_eval_dataset, f, ensure_ascii=False) 29 | 30 | # You can also upload the dataset to the Hugging Face Hub, which is particularly useful when collecting data on multiple machines (the snippet below also needs `from datasets import Dataset, DatasetDict`) 31 | """ 32 | output_dir = "xxx" 33 | dict_data = { 34 | "idx": [d['idx'] for d in all_data], 35 | "gt": [d['gt'] for d in all_data], 36 | "level": [d['level'] for d in all_data], 37 | "type": [d['type'] for d in all_data], 38 | "my_solu": [d['my_solu'] for d in all_data], 39 | "pred": [d['pred'] for d in all_data], 40 | } 41 | 42 | dataset = Dataset.from_dict(dict_data) 43 | DatasetDict({'train': dataset}).push_to_hub(output_dir) 44 | """ 45 | -------------------------------------------------------------------------------- /useful_codes/push_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | name = 'model_dir' 5 | output_name = 'output_name' 6 | tokenizer_name = name 7 | 8 | model = AutoModelForCausalLM.from_pretrained( 9 | name, 10 | torch_dtype=torch.bfloat16, 11 | ) 12 | 13 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 14 | 15 | model.push_to_hub(output_name) 16 | tokenizer.push_to_hub(output_name) 17 | -------------------------------------------------------------------------------- /useful_codes/set_padding_token.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | name = 'mistralai/Mistral-7B-v0.3' 5 | tokenizer_name = name 6 | 7 | model = AutoModelForCausalLM.from_pretrained( 8 | name, 9 | torch_dtype=torch.bfloat16, 10 | ) 11 | 12 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 13 | 
tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 14 | model.config.pad_token_id = 
tokenizer.pad_token_id 15 | 16 | model.resize_token_embeddings(len(tokenizer)) 17 | 18 | model.save_pretrained("output_dir") 19 | tokenizer.save_pretrained("output_dir") 20 | --------------------------------------------------------------------------------
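Note on useful_codes/interpolate_model.py: the comment in that script describes the merge as `args.ratio * new_model + (1 - args.ratio) * base_model`, applied to every entry of the state dict. The sketch below illustrates only that arithmetic on plain Python data, so it runs without torch or any checkpoint. The function name `interpolate_state_dicts` and the toy parameter names are hypothetical and do not appear in the repository; real checkpoints hold tensors rather than lists of floats.

```python
def interpolate_state_dicts(new_state, base_state, ratio):
    """Return ratio * new_state + (1 - ratio) * base_state, key by key.

    Both arguments are toy stand-ins for model state dicts:
    a mapping from parameter name to a flat list of floats.
    """
    if new_state.keys() != base_state.keys():
        raise KeyError("both checkpoints must share the same parameter names")
    merged = {}
    for key, new_values in new_state.items():
        base_values = base_state[key]
        if len(new_values) != len(base_values):
            raise ValueError(f"shape mismatch for parameter {key!r}")
        # Element-wise convex combination of the two parameter vectors.
        merged[key] = [
            ratio * n + (1.0 - ratio) * b
            for n, b in zip(new_values, base_values)
        ]
    return merged


# Toy checkpoints (hypothetical values, chosen so the arithmetic is exact).
new_ckpt = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}
base_ckpt = {"layer.weight": [3.0, 6.0], "layer.bias": [4.0]}

merged = interpolate_state_dicts(new_ckpt, base_ckpt, ratio=0.25)
print(merged)
```

With `ratio=1.0` the result equals the new checkpoint and with `ratio=0.0` it equals the base checkpoint, which is a quick way to sanity-check the direction of the ratio before merging real models.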