├── .DS_Store
├── README.md
├── SFT
│   ├── README.md
│   ├── code-gemma-7B-it.yaml
│   ├── gemma-2-9b-it.yaml
│   ├── gemma-7b-it.yaml
│   ├── mistral-v0.3.yaml
│   └── sft_configs
│       ├── sft_ds2.json
│       └── sft_ds3.json
├── alignment_algorithms
│   ├── README.md
│   ├── dpo.py
│   ├── dpo_trainer.py
│   ├── kto_trainer.py
│   ├── run_dpo.py
│   ├── run_dpo.sh
│   ├── run_kto.py
│   ├── run_kto.sh
│   └── training_configs
│       ├── zero2_pf.yaml
│       └── zero3_pf.yaml
├── assets
│   └── main_result.png
├── inference
│   ├── .DS_Store
│   ├── README.md
│   ├── data
│   │   ├── gsm8k
│   │   │   ├── test.jsonl
│   │   │   └── train.jsonl
│   │   └── math
│   │       ├── test.jsonl
│   │       └── train.jsonl
│   ├── eval
│   │   ├── evaluate.py
│   │   └── grader.py
│   ├── infer_data
│   │   ├── annotate_data.py
│   │   ├── get_dpo_dataset.py
│   │   └── infer_eval.py
│   ├── scripts
│   │   ├── eval.sh
│   │   ├── infer.sh
│   │   ├── iter_infer_to_collect_data.sh
│   │   └── register_server.sh
│   └── utils
│       ├── annotate_data.py
│       ├── data_loader.py
│       ├── filter_data.py
│       ├── parser.py
│       ├── python_executor.py
│       └── utils.py
└── useful_codes
    ├── annotate_data.py
    ├── interpolate_model.py
    ├── merge.py
    ├── push_model.py
    └── set_padding_token.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Building Math Agents with Multi-Turn Iterative Preference Learning
4 |
5 |
6 |
7 | TL;DR: this is the repo for "Building Math Agents with Multi-Turn Iterative Preference Learning"
8 |
9 |
10 | We consider math problem solving with a Python interpreter, which means that the model can write Python code and ask the external environment to execute it and return the execution result, before the LLM makes its next decision.
11 |
12 |
13 |
14 |
15 | Figure 1: Main evaluation results on the MATH and GSM8K datasets.
16 |
17 |
18 | ## Structure
19 |
20 | The main pipeline is divided into three steps:
21 |
22 |
23 | - [`SFT`](./SFT/) to train the SFT model.
24 | - [`Inference`](./inference/) to generate new data and evaluate the model.
25 | - [`Multi-turn Alignment Algorithms`](./alignment_algorithms/) to conduct the multi-turn DPO/KTO training.
26 |
27 |
28 | It is recommended to have three separate environments for **sft**, **inference**, and **alignment_train**. Please refer to the corresponding part of this project for the detailed installation instructions.
29 |
30 | ## Collection
31 |
32 | - [SFT Dataset: RLHF4MATH/SFT_510K](https://huggingface.co/datasets/RLHF4MATH/SFT_510K), which is a subset of nvidia/OpenMathInstruct-1
33 | - [Prompt](https://huggingface.co/datasets/RLHF4MATH/prompt_iter1): RLHF4MATH/prompt_iter1, RLHF4MATH/prompt_iter2, RLHF4MATH/prompt_iter3
34 | - [SFT Model: RLHF4MATH/Gemma-7B-it-SFT3epoch](https://huggingface.co/RLHF4MATH/Gemma-7B-it-SFT3epoch)
35 | - [Aligned Model: RLHF4MATH/Gemma-7B-it-M-DPO](https://huggingface.co/RLHF4MATH/Gemma-7B-it-M-DPO)
36 |
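For quick experimentation, these artifacts can be loaded directly from the Hugging Face Hub. A minimal sketch (the repository names come from the list above; the `train` split for the SFT data follows the SFT configs, everything else is standard `datasets`/`transformers` usage):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# SFT data (a subset of nvidia/OpenMathInstruct-1) and the iteration-1 prompt set
sft_data = load_dataset("RLHF4MATH/SFT_510K", split="train")
prompt_iter1 = load_dataset("RLHF4MATH/prompt_iter1")

# SFT checkpoint and the M-DPO aligned checkpoint
tokenizer = AutoTokenizer.from_pretrained("RLHF4MATH/Gemma-7B-it-SFT3epoch")
model = AutoModelForCausalLM.from_pretrained("RLHF4MATH/Gemma-7B-it-M-DPO")
```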
37 | ## Acknowledgement
38 |
39 | The authors would like to thank the great open-source communities, including the Hugging Face TRL team, the Axolotl team, and the ToRA project, for sharing their models and code.
40 |
41 | ## Citation
42 |
43 | If you find the content of this repo useful, please consider citing it as follows:
44 |
45 | ```bibtex
46 | @article{xiong2024building,
47 | title={Building Math Agents with Multi-Turn Iterative Preference Learning},
48 | author={Xiong, Wei and Shi, Chengshuai and Shen, Jiaming and Rosenberg, Aviv and Qin, Zhen and Calandriello, Daniele and Khalman, Misha and Joshi, Rishabh and Piot, Bilal and Saleh, Mohammad and others},
49 | journal={arXiv preprint arXiv:2409.02392},
50 | year={2024}
51 | }
52 | ```
53 |
--------------------------------------------------------------------------------
/SFT/README.md:
--------------------------------------------------------------------------------
1 | # SFT for Multi-round Chat
2 |
3 | We conduct supervised fine-tuning here for multi-turn chat. The data is supposed to be processed into the standard format (see [RLHF4MATH/SFT_510K](https://huggingface.co/datasets/RLHF4MATH/SFT_510K) for an example).
4 |
5 | We formulate the problem in a multi-turn chat format where the "user" prompts the assistant (LLM) to solve a math problem and provides the external feedback. An example is as follows.
6 |
7 | - **User**: Nine adults went to a play with seven children. Adult tickets are $11 each and children's tickets are $7 each. How many dollars more did the adults' tickets cost in total than the children's tickets in total?
8 | - **Assistant**: Let's solve this problem using Python code.\n```python\nadult_price = 11\nchild_price = 7\nadult_count = 9\nchild_count = 7\nadult_ticket_total = adult_price * adult_count\nchild_ticket_total = child_price * child_count\nadult_ticket_total - child_ticket_total\n```
9 | - **User**: ```output\n50\n```
10 | - **Assistant**: Thus the adult's ticket total was \\boxed{50} dollars more than the children's tickets total.
11 |
12 | We will mask all user turns so that they do not contribute to the gradient.
13 |
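Concretely, the label for every token inside a user turn (including the `output` feedback turns) is set to -100 so the loss ignores it. A minimal sketch of the idea, assuming the turn boundaries have already been located (the helper name and span format below are illustrative, not the repo's exact code):

```python
# Sketch: mask user/tool-output turns so only assistant tokens receive gradient.
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

def mask_user_turns(input_ids, user_turn_spans):
    """input_ids: list[int] for the full conversation.
    user_turn_spans: list of (start, end) token-index pairs covering user turns,
    found via the model-specific turn-delimiter tokens."""
    labels = list(input_ids)
    for start, end in user_turn_spans:
        labels[start:end] = [IGNORE_INDEX] * (end - start)
    return labels
```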
14 |
15 | ## Installation instructions
16 |
17 | Before starting, please make sure your linux machine has [nvidia-cuda-toolkit](https://developer.nvidia.com/cuda-toolkit) installed.
18 |
19 | ```shell
20 | sudo apt install nvidia-driver-530 # (530 for 12.1+)
21 | wget https://repo.anaconda.com/archive/Anaconda3-2023.07-0-Linux-x86_64.sh
22 | bash Anaconda3-2023.07-0-Linux-x86_64.sh # install conda
23 | conda install nvidia/label/cuda-12.2.0::cuda-nvcc
24 | ```
25 |
26 | Now we set up the python environment.
27 |
28 | ```shell
29 | conda create -n sft python=3.10.9
30 | conda activate sft
31 |
32 | # The tested CUDA versions are 12.1 and 12.2. You may need to update the torch version based on your CUDA version.
33 | pip3 install torch==2.1.2 torchvision torchaudio
34 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
35 |
36 | ## Get axolotl for general model
37 | git clone https://github.com/OpenAccess-AI-Collective/axolotl
38 | cd axolotl
39 | git checkout 55cc214c767741e83ee7b346e5e13e6c03b7b9fa
40 | pip install -e .
41 |
42 | ## Get FastChat
43 | git clone https://github.com/lm-sys/FastChat.git
44 | cd FastChat
45 | pip install -e .
46 |
47 | pip install deepspeed
48 |
49 | # You also need to install wandb to record the training and log in with your Hugging Face account to access Gemma.
50 |
51 | pip install wandb
52 | wandb login
53 |
54 | huggingface-cli login
55 | ```
56 | ## Running the Code
57 |
58 | Running the code with Gemma.
59 |
60 | ```shell
61 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" torchrun --nproc_per_node 8 --master_port 20001 -m axolotl.cli.train gemma-7b-it.yaml
62 | ```
63 |
64 | You can also modify the learning rate, batch size, output path, etc., either from the command line or by editing the arguments in gemma-7b-it.yaml.
65 |
66 | If you encounter an out-of-memory issue, run the code with Gemma-7b-it using DeepSpeed stage 3 and gradient checkpointing (set in the config).
67 |
68 | ```shell
69 | CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node 4 --master_port 20001 -m axolotl.cli.train gemma-7b-it.yaml --deepspeed ./sft_configs/sft_ds3.json
70 | ```
71 |
72 | **REMARK: with DeepSpeed stage 3, the final model saving does not work normally. We set the save strategy to "epoch" so that a usable checkpoint is stored just before each training epoch finishes. If you modify the save strategy, set save_steps to the total number of training steps minus 1 so that the trainer saves a model just before training finishes.**
73 |
74 |
75 | Finally, for models without an official padding token (like Mistral), you may need to set the padding token with ../useful_codes/set_padding_token.py first, as sketched below.
76 |
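A hedged sketch of the padding-token setup (the model id and output path are assumptions based on mistral-v0.3.yaml; check ../useful_codes/set_padding_token.py for the actual logic):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-Instruct-v0.3"  # assumed base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Add a dedicated pad token if the tokenizer does not define one,
# then resize the embedding matrix so the new token has a row.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

# Save to the path referenced as base_model in mistral-v0.3.yaml
tokenizer.save_pretrained("./models/mistral_with_pad")
model.save_pretrained("./models/mistral_with_pad")
```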
77 |
--------------------------------------------------------------------------------
/SFT/code-gemma-7B-it.yaml:
--------------------------------------------------------------------------------
1 | base_model: google/codegemma-1.1-7b-it
2 | model_type: AutoModelForCausalLM
3 | tokenizer_type: AutoTokenizer
4 |
5 | load_in_8bit: false
6 | load_in_4bit: false
7 | strict: false
8 |
9 | datasets:
10 | - path: RLHF4MATH/SFT_510K
11 | conversation: gemma
12 | type: sharegpt.load_ultrachat
13 | split: "train"
14 | train_on_split: "train"
15 |
16 | warmup_steps: 50
17 | val_set_size: 0.0
18 | output_dir: ./models/gemma-7b-it_bz64_lr5e-6_pack
19 | wandb_project: huggingface
20 | #wandb_entity: domain-generalization
21 | wandb_watch:
22 | wandb_name: "codegemma-7b-it_bs64_lr5e-6_sft_pack4096"
23 | #_response_only
24 | wandb_log_model:
25 |
26 | train_on_inputs: false
27 |
28 | save_safetensors: true
29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type
30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared
31 |
32 |
33 | dataset_processes: 48
34 | #torch_compile: true
35 | sequence_len: 4096
36 | sample_packing: true
37 | pad_to_sequence_len: true
38 |
39 | trust_remote_code: True
40 | adapter:
41 | lora_model_dir:
42 |
43 |
44 |
45 | gradient_checkpointing: true
46 |
47 | #warmup_ratio: 0.1
48 | gradient_accumulation_steps: 2
49 | micro_batch_size: 4
50 | num_epochs: 4
51 | optimizer: paged_adamw_32bit
52 | #adiamw_torch_fused
53 | lr_scheduler: cosine
54 | learning_rate: 5.e-6
55 |
56 |
57 | weight_decay: 0.0
58 | max_grad_norm: 1.0
59 |
60 |
61 | group_by_length: false
62 | bf16: auto
63 | fp16: false
64 | tf32: true
65 |
66 | early_stopping_patience:
67 | local_rank:
68 | logging_steps: 2
69 | xformers_attention:
70 | flash_attention: true
71 |
72 |
73 | eval_steps:
74 | eval_table_size:
75 | eval_table_max_new_tokens:
76 | #save_steps: 100
77 | save_strategy: "epoch"
78 | save_total_limit: 4
79 | debug:
80 |
81 |
82 | ddp: #true
83 | deepspeed: #deepspeed/zero1.json # multi-gpu only
84 |
85 | fsdp:
86 | fsdp_config:
87 | special_tokens:
88 |
--------------------------------------------------------------------------------
/SFT/gemma-2-9b-it.yaml:
--------------------------------------------------------------------------------
1 | base_model: google/gemma-2-9b-it
2 | model_type: AutoModelForCausalLM
3 | tokenizer_type: AutoTokenizer
4 |
5 | load_in_8bit: false
6 | load_in_4bit: false
7 | strict: false
8 |
9 | datasets:
10 | - path: RLHF4MATH/SFT_510K
11 | conversation: gemma
12 | type: sharegpt.load_ultrachat
13 | split: "train"
14 | train_on_split: "train"
15 |
16 | warmup_steps: 100
17 | val_set_size: 0.0
18 | output_dir: ./models/gemma_9b_bz64_lr5e6_pack4096
19 | wandb_project: huggingface
20 | #wandb_entity: sft
21 | wandb_watch:
22 | wandb_name: "gemma_9b_bz64_lr5e6_pack4096"
23 | #_response_only
24 | wandb_log_model:
25 |
26 | train_on_inputs: false
27 |
28 | save_safetensors: true
29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type
30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared
31 |
32 |
33 | dataset_processes: 48
34 | #torch_compile: true
35 | sequence_len: 4096
36 | sample_packing: true
37 | pad_to_sequence_len: true
38 |
39 | trust_remote_code: True
40 | adapter:
41 | lora_model_dir:
42 | gradient_checkpointing: true
43 |
44 | #warmup_ratio: 0.1
45 | gradient_accumulation_steps: 8
46 | micro_batch_size: 1
47 | num_epochs: 4
48 | optimizer: paged_adamw_32bit
49 | lr_scheduler: cosine
50 | learning_rate: 5e-6
51 |
52 | weight_decay: 0.0
53 | max_grad_norm: 1.0
54 |
55 |
56 | group_by_length: false
57 | bf16: auto
58 | fp16: false
59 | tf32: true
60 |
61 | early_stopping_patience:
62 | local_rank:
63 | logging_steps: 2
64 | xformers_attention:
65 | flash_attention: true
66 |
67 |
68 | eval_steps:
69 | eval_table_size:
70 | eval_table_max_new_tokens:
71 | save_steps: 9999
72 | save_strategy: "steps"
73 | save_total_limit: 4
74 | debug:
75 |
76 |
77 | ddp: #true
78 | deepspeed: #deepspeed/zero1.json # multi-gpu only
79 |
80 | fsdp:
81 | fsdp_config:
82 | special_tokens:
83 |
--------------------------------------------------------------------------------
/SFT/gemma-7b-it.yaml:
--------------------------------------------------------------------------------
1 | base_model: google/gemma-1.1-7b-it
2 | model_type: AutoModelForCausalLM
3 | tokenizer_type: AutoTokenizer
4 |
5 | load_in_8bit: false
6 | load_in_4bit: false
7 | strict: false
8 |
9 | datasets:
10 | - path: RLHF4MATH/SFT_510K
11 | conversation: gemma
12 | type: sharegpt.load_ultrachat
13 | split: "train"
14 | train_on_split: "train"
15 |
16 | warmup_steps: 50
17 | val_set_size: 0.0
18 | output_dir: ./pm_models/code_gemma-7b-it_bz64_lr5e-6_non_pack
19 | wandb_project: huggingface
20 | #wandb_entity: domain-generalization
21 | wandb_watch:
22 | wandb_name: "codegemma-7b-it_bs64_lr5e-6_sft_pack4096"
23 | #_response_only
24 | wandb_log_model:
25 |
26 | train_on_inputs: false
27 |
28 | save_safetensors: true
29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type
30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared
31 |
32 |
33 | dataset_processes: 48
34 | #torch_compile: true
35 | sequence_len: 4096
36 | sample_packing: true
37 | pad_to_sequence_len: true
38 |
39 | trust_remote_code: True
40 | adapter:
41 | lora_model_dir:
42 |
43 |
44 |
45 | gradient_checkpointing: true
46 |
47 | #warmup_ratio: 0.1
48 | gradient_accumulation_steps: 2
49 | micro_batch_size: 4
50 | num_epochs: 4
51 | optimizer: paged_adamw_32bit
52 | #adiamw_torch_fused
53 | lr_scheduler: cosine
54 | learning_rate: 5.e-6
55 |
56 |
57 | weight_decay: 0.0
58 | max_grad_norm: 1.0
59 |
60 |
61 | group_by_length: false
62 | bf16: auto
63 | fp16: false
64 | tf32: true
65 |
66 | early_stopping_patience:
67 | local_rank:
68 | logging_steps: 2
69 | xformers_attention:
70 | flash_attention: true
71 |
72 |
73 | eval_steps:
74 | eval_table_size:
75 | eval_table_max_new_tokens:
76 | #save_steps: 100
77 | save_strategy: "epoch"
78 | save_total_limit: 4
79 | debug:
80 |
81 |
82 | ddp: #true
83 | deepspeed: #deepspeed/zero1.json # multi-gpu only
84 |
85 | fsdp:
86 | fsdp_config:
87 | special_tokens:
88 |
--------------------------------------------------------------------------------
/SFT/mistral-v0.3.yaml:
--------------------------------------------------------------------------------
1 | base_model: ./models/mistral_with_pad
2 | model_type: AutoModelForCausalLM
3 | tokenizer_type: AutoTokenizer
4 |
5 | load_in_8bit: false
6 | load_in_4bit: false
7 | strict: false
8 |
9 | datasets:
10 | - path: RLHF4MATH/SFT_510K
11 | conversation: mistral
12 | type: sharegpt.load_ultrachat
13 | split: "train"
14 | train_on_split: "train"
15 |
16 | warmup_steps: 50
17 | val_set_size: 0.0
18 | output_dir: ./models/mistral_bz64_1e5_pack4096_1200k
19 | wandb_project: huggingface
20 | #wandb_entity: sft
21 | wandb_watch:
22 | wandb_name: "mistral_bz64_1e5_pack4096_510k"
23 | #_response_only
24 | wandb_log_model:
25 |
26 | train_on_inputs: false
27 |
28 | save_safetensors: true
29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type
30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared
31 |
32 |
33 | dataset_processes: 48
34 | #torch_compile: true
35 | sequence_len: 4096
36 | sample_packing: true
37 | pad_to_sequence_len: true
38 |
39 | trust_remote_code: True
40 | adapter:
41 | lora_model_dir:
42 |
43 |
44 |
45 |
46 | gradient_checkpointing: true
47 |
48 | #warmup_ratio: 0.1
49 | gradient_accumulation_steps: 2
50 | micro_batch_size: 4
51 | num_epochs: 3
52 | optimizer: paged_adamw_32bit
53 | lr_scheduler: cosine
54 | learning_rate: 1.e-5
55 |
56 | weight_decay: 0.0
57 | max_grad_norm: 1.0
58 |
59 |
60 | group_by_length: false
61 | bf16: auto
62 | fp16: false
63 | tf32: true
64 |
65 | early_stopping_patience:
66 | local_rank:
67 | logging_steps: 2
68 | xformers_attention:
69 | flash_attention: true
70 | eval_steps:
71 | eval_table_size:
72 | eval_table_max_new_tokens:
73 | #save_steps: 99999
74 | save_strategy: "epoch"
75 | save_total_limit: 3
76 | debug:
77 |
78 |
79 | ddp: #true
80 | deepspeed:
81 |
82 | fsdp:
83 | fsdp_config:
84 | special_tokens:
85 |
86 |
87 |
--------------------------------------------------------------------------------
/SFT/sft_configs/sft_ds2.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "offload_optimizer": {
5 | "device": "cpu"
6 | },
7 | "contiguous_gradients": true,
8 | "overlap_comm": true
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "auto_cast": false,
16 | "loss_scale": 0,
17 | "initial_scale_power": 32,
18 | "loss_scale_window": 1000,
19 | "hysteresis": 2,
20 | "min_loss_scale": 1
21 | },
22 | "gradient_accumulation_steps": "auto",
23 | "gradient_clipping": "auto",
24 | "train_batch_size": "auto",
25 | "train_micro_batch_size_per_gpu": "auto",
26 | "wall_clock_breakdown": false
27 | }
28 |
--------------------------------------------------------------------------------
/SFT/sft_configs/sft_ds3.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 0,
7 | "reduce_bucket_size": "auto",
8 | "stage3_prefetch_bucket_size": "auto",
9 | "stage3_param_persistence_threshold": "auto",
10 | "stage3_max_live_parameters": 0,
11 | "stage3_max_reuse_distance": 0,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "bf16": {
15 | "enabled": true
16 | },
17 | "fp16": {
18 | "enabled": "auto",
19 | "auto_cast": false,
20 | "loss_scale": 0,
21 | "initial_scale_power": 32,
22 | "loss_scale_window": 1000,
23 | "hysteresis": 2,
24 | "min_loss_scale": 1
25 | },
26 | "gradient_accumulation_steps": "auto",
27 | "gradient_clipping": "auto",
28 | "train_batch_size": "auto",
29 | "train_micro_batch_size_per_gpu": "auto",
30 | "wall_clock_breakdown": false
31 | }
32 |
--------------------------------------------------------------------------------
/alignment_algorithms/README.md:
--------------------------------------------------------------------------------
1 | # DPO/KTO Training with Multi-turn Data
2 |
3 | The implementations of DPO and KTO are adapted from the open-source packages [TRL](https://github.com/huggingface/trl) and [RLHFlow](https://github.com/RLHFlow/Online-RLHF). Compared to the original DPO/KTO, we only need to modify the sample masks so that all external (tool-output) tokens are masked out. The current implementation supports the Gemma and Mistral models. You can read the "get_new_mask" function in dpo_trainer or kto_trainer to get the idea and easily implement it for other LLMs.
4 |
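To extend `get_new_mask` to another model, you mainly need the token ids of its turn delimiters. A small sketch for inspecting them (the Gemma checkpoint name is the one used in the SFT configs; the printed ids should match the constants hard-coded in `get_new_mask`, but verify against your own tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it")

# Turn-delimiter special tokens used by the Gemma chat template
print(tok.convert_tokens_to_ids(["<start_of_turn>", "<end_of_turn>"]))
# Tokens that open a user turn, e.g. "<start_of_turn>user\n"
print(tok("user\n", add_special_tokens=False).input_ids)
```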
5 |
6 | ## 1 Installation instructions
7 |
8 | **Note that the numpy version should be `numpy<2.0`; `numpy>=2.0` will cause unexpected issues!**
9 |
10 |
11 | Before starting, please make sure your Linux machine has nvidia-cuda-toolkit installed. See the SFT part for guidance.
12 |
13 |
14 | **Training Environment**
15 |
16 | ```sh
17 | conda create -n alignment_train python=3.10.9
18 | conda activate alignment_train
19 |
20 | git clone https://github.com/huggingface/alignment-handbook.git
21 | cd ./alignment-handbook/
22 | git checkout d17fd7cd3b71c6a7bf7af34d8dc73135bb7ea8e9
23 | pip3 install torch==2.1.2 torchvision torchaudio
24 | python -m pip install .
25 | pip install flash-attn==2.6.3
26 | pip install accelerate==0.33.0
27 |
28 | pip install huggingface-hub==0.24.7
29 | pip install wandb
30 | wandb login
31 | huggingface-cli login
32 | ```
33 |
34 | ## 2 Hacking the DPO Trainer and KTO Trainer
35 |
36 | ### 2.1 Hack DPO Trainer
37 |
38 | The code is based on RLHFlow/Online-RLHF, but we need to hack the trainer to implement some additional functions. We highlight the modified parts with ############## MODIFICATION.
39 |
40 | ```sh
41 | # Step 1: find the original DPO trainer
42 | cd anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/
43 |
44 | # Step 2: delete the old one
45 | rm dpo_trainer.py
46 |
47 | # Step 3: use the modified one in this repo. The following command needs to be modified to use the correct paths
48 | mv dpo_train/dpo_trainer.py anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py
49 | ```
50 |
51 | ### 2.2 Hack KTO Trainer
52 |
53 | The code is based on RLHFlow/Online-RLHF, but we need to hack the KTO trainer to implement some additional functions. We highlight the modified parts with ############## MODIFICATION.
54 |
55 | ```sh
56 | # Step 1: find the original KTO trainer
57 | cd anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/
58 |
59 | # Step 2: delete the old one
60 | rm kto_trainer.py
61 |
62 | # Step 3: use the modified one in this repo. The following command needs to be modified to use the correct paths
63 | mv kto_train/kto_trainer.py anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/trainer/kto_trainer.py
64 |
65 | # Step 4: modify the KTO config according to your GPU resource.
66 | vim ./trl/trainer/kto_config.py
67 | max_length: Optional[int] = 2048
68 | max_prompt_length: Optional[int] = 1024
69 | max_completion_length: Optional[int] = 2048
70 | ```
71 |
72 | ### 2.3 Fix Import Error
73 |
74 | For transformers > 4.38.2, you will encounter an import issue related to the following function in anaconda3/envs/alignment_train/lib/python3.10/site-packages/trl/core.py. You can comment out the import from transformers and paste the following hard-coded version into core.py.
75 |
76 | ```python
77 | def top_k_top_p_filtering(
78 | logits: torch.FloatTensor,
79 | top_k: int = 0,
80 | top_p: float = 1.0,
81 | filter_value: float = -float("Inf"),
82 | min_tokens_to_keep: int = 1,
83 | ) -> torch.FloatTensor:
84 | """
85 | Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
86 |
87 | Args:
88 | logits: logits distribution shape (batch size, vocabulary size)
89 | top_k (`int`, *optional*, defaults to 0):
90 | If > 0, only keep the top k tokens with highest probability (top-k filtering)
91 | top_p (`float`, *optional*, defaults to 1.0):
92 | If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
93 | filtering is described in Holtzman et al. (https://huggingface.co/papers/1904.09751)
94 | min_tokens_to_keep (`int`, *optional*, defaults to 1):
95 | Minimum number of tokens we keep per batch example in the output.
96 |
97 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
98 | """
99 |
100 | if top_k > 0:
101 | logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
102 | None, logits
103 | )
104 |
105 | if 0 <= top_p <= 1.0:
106 | logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
107 | None, logits
108 | )
109 |
110 | return logits
111 | ```
112 |
113 |
114 | ## 3 Running the Code
115 |
116 | ### 3.1 DPO
117 | Before running the code, modify num_processes: 8 in ./training_configs/zero2_pf.yaml; the number 8 means that you will use 8 GPUs. Also modify the parameters, models, and datasets provided in run_dpo.py.
118 |
119 | ```shell
120 | accelerate launch --config_file ./training_configs/zero2_pf.yaml run_dpo.py ./training_configs/training.yaml
121 |
122 | ```
123 |
124 | ### 3.2 KTO
125 |
126 | ```shell
127 | bash run_kto.sh
128 | ```
129 |
130 | If you encounter an out-of-memory issue, run the code with Gemma-7b-it using zero3_pf.yaml. You can also reduce the max length of the data.
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/alignment_algorithms/dpo.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
3 | import copy
4 | import torch
5 | import torch.nn.functional as F
6 | from datasets import Dataset
7 | # from peft import AutoPeftModelForCausalLM, LoraConfig
8 | from torch import nn
9 | from torch.nn.utils.rnn import pad_sequence
10 | from transformers import (
11 | DataCollator,
12 | PreTrainedModel,
13 | PreTrainedTokenizerBase,
14 | TrainerCallback,
15 | TrainingArguments,
16 | )
17 | from transformers.trainer_callback import TrainerCallback
18 | from transformers.trainer_utils import EvalLoopOutput
19 | from trl import DPOTrainer
20 |
21 | # Define and parse arguments.
22 |
23 | ############## MODIFICATION
24 | def get_new_mask(input_ids, old_labels, model='gemma'):
25 | # We mask the user turns to create new labels (Gemma and Mistral chat templates are supported)
26 | labels = copy.deepcopy(old_labels)
27 | start = False
28 | if 'gemma' in model.lower():
29 | for j in range(len(input_ids)):
30 | if input_ids[j:j+3] == [106, 1645, 108]:
31 | start = True
32 | labels[j:j+3] = [-100] * 3  # mask the <start_of_turn>user\n tokens
33 | if input_ids[j:j+2] == [107, 108] and start:
34 | labels[j] = -100
35 | labels[j+1] = -100
36 | start = False
37 | if start:
38 | labels[j] = -100
39 | elif 'mistral' in model.lower():
40 | for j in range(len(input_ids)):
41 | if input_ids[j] == 3:
42 | start = True
43 | labels[j] = -100  # mask the user-turn start token
44 | if input_ids[j] == 4 and start:
45 | labels[j] = -100
46 | start = False
47 | if start:
48 | labels[j] = -100
49 | else:
50 | raise NotImplementedError(model)
51 | return labels
52 | ############## MODIFICATION END
53 |
54 | @dataclass
55 | class PreferenceDataCollatorWithPadding:
56 | tokenizer: PreTrainedTokenizerBase
57 | model: Optional[PreTrainedModel] = None
58 | padding: Union[bool, str] = True
59 | max_length: Optional[int] = None
60 | max_prompt_length: Optional[int] = None
61 | label_pad_token_id: int = -100
62 | padding_value: int = 0
63 | truncation_mode: str = "keep_end"
64 | is_encoder_decoder: Optional[bool] = False
65 | max_target_length: Optional[int] = None
66 | mask_prompt: Optional[bool] = False
67 | mask_user_turn: Optional[bool] = False
68 | model_name: Optional[str] = 'gemma'
69 |
70 | def tokenize_batch_element(
71 | self,
72 | prompt: str,
73 | chosen: str,
74 | rejected: str,
75 | ) -> Dict:
76 | """Tokenize a single batch element.
77 |
78 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
79 | in case the prompt + chosen or prompt + rejected responses is/are too long. First
80 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
81 |
82 | We also create the labels for the chosen/rejected responses, which are of length equal to
83 | the sum of the length of the prompt and the chosen/rejected response, with
84 | label_pad_token_id for the prompt tokens.
85 | """
86 | batch = {}
87 |
88 | if not self.is_encoder_decoder:
89 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
90 | rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
91 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
92 |
93 | eos_token_id = self.tokenizer.eos_token_id
94 | # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0)
95 | eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id]
96 | # attention mask these indices to eos_token_id
97 | if self.mask_prompt:
98 | new_attention_mask = [0 for i, p in enumerate(prompt_tokens["attention_mask"])]
99 | print("I mask the prompt")
100 | else:
101 | new_attention_mask = [
102 | 0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"])
103 | ]
104 | prompt_tokens["attention_mask"] = new_attention_mask
105 |
106 | # do the same for chosen and rejected
107 | eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id]
108 | new_attention_mask_c = [
109 | 0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"])
110 | ]
111 | chosen_tokens["attention_mask"] = new_attention_mask_c
112 |
113 | eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id]
114 | new_attention_mask_r = [
115 | 0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"])
116 | ]
117 | rejected_tokens["attention_mask"] = new_attention_mask_r
118 |
119 | # add EOS token to the end of the chosen/rejected responses
120 |
121 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
122 | chosen_tokens["attention_mask"].append(1)
123 |
124 | rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
125 | rejected_tokens["attention_mask"].append(1)
126 |
127 | longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
128 |
129 | # if combined sequence is too long, truncate the prompt
130 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
131 | if self.truncation_mode == "keep_start":
132 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
133 | elif self.truncation_mode == "keep_end":
134 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
135 | else:
136 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
137 |
138 | # if that's still too long, truncate the response
139 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
140 | chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
141 | rejected_tokens = {
142 | k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()
143 | }
144 |
145 | # Create labels
146 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
147 | rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
148 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
149 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
150 | prompt_tokens["input_ids"]
151 | )
152 | rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
153 | rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
154 | prompt_tokens["input_ids"]
155 | )
156 | ############ MODIFICATION
157 | if self.mask_user_turn:
158 | new_chosen_sequence_labels= get_new_mask(chosen_sequence_tokens['input_ids'], chosen_sequence_tokens['labels'], model=self.model_name)
159 | new_rej_sequence_labels = get_new_mask(rejected_sequence_tokens['input_ids'], rejected_sequence_tokens['labels'], model=self.model_name)
160 | chosen_sequence_tokens["labels"] = new_chosen_sequence_labels
161 | rejected_sequence_tokens["labels"] = new_rej_sequence_labels
162 | ############
163 | for k, toks in {
164 | "chosen": chosen_sequence_tokens,
165 | "rejected": rejected_sequence_tokens,
166 | "prompt": prompt_tokens,
167 | }.items():
168 | for type_key, tokens in toks.items():
169 | if type_key == "token_type_ids":
170 | continue
171 | batch[f"{k}_{type_key}"] = tokens
172 |
173 | else:
174 | raise NotImplementedError
175 |
176 | batch["prompt"] = prompt
177 | batch["chosen"] = prompt + chosen
178 | batch["rejected"] = prompt + rejected
179 | batch["chosen_response_only"] = chosen
180 | batch["rejected_response_only"] = rejected
181 | return batch
182 |
183 | def collate(self, batch):
184 | # first, pad everything to the same length
185 | padded_batch = {}
186 | for k in batch[0].keys():
187 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
188 | if self.is_encoder_decoder:
189 | to_pad = [torch.LongTensor(ex[k]) for ex in batch]
190 |
191 | if (k.startswith("prompt")) and (k.endswith("input_ids")):
192 | padding_value = self.tokenizer.pad_token_id
193 | elif k.endswith("_attention_mask"):
194 | padding_value = 0
195 | elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
196 | padding_value = self.label_pad_token_id
197 | else:
198 | raise ValueError(f"Unexpected key in batch '{k}'")
199 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
200 | else:
201 | # adapted from https://stackoverflow.com/questions/73256206
202 | if "prompt" in k:
203 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
204 | else:
205 | to_pad = [torch.LongTensor(ex[k]) for ex in batch]
206 | if k.endswith("_input_ids"):
207 | padding_value = self.tokenizer.pad_token_id
208 | elif k.endswith("_labels"):
209 | padding_value = self.label_pad_token_id
210 | elif k.endswith("_attention_mask"):
211 | padding_value = self.padding_value
212 | else:
213 | raise ValueError(f"Unexpected key in batch '{k}'")
214 |
215 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
216 | # for the prompt, flip back so padding is on left side
217 | if "prompt" in k:
218 | padded_batch[k] = padded_batch[k].flip(dims=[1])
219 | else:
220 | padded_batch[k] = [ex[k] for ex in batch]
221 |
222 | return padded_batch
223 |
224 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
225 | tokenized_batch = []
226 |
227 | for feature in features:
228 | prompt = feature["prompt"]
229 | chosen = feature["chosen"]
230 | rejected = feature["rejected"]
231 |
232 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
233 | batch_element["margin"] = feature["margin"]
234 | tokenized_batch.append(batch_element)
235 |
236 | # return collated batch
237 | return self.collate(tokenized_batch)
238 |
239 |
240 | class PreferenceTrainer(DPOTrainer):
241 | def __init__(
242 | self,
243 | model: Union[PreTrainedModel, nn.Module] = None,
244 | ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
245 | beta: float = 0.1,
246 | loss_type: Literal["sigmoid", "hinge", "cross_entropy", "kl", "rev_kl", "raft"] = "rev_kl",
247 | args: TrainingArguments = None,
248 | data_collator: Optional[DataCollator] = None,
249 | label_pad_token_id: int = -100,
250 | padding_value: int = 0,
251 | truncation_mode: str = "keep_end",
252 | train_dataset: Optional[Dataset] = None,
253 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
254 | tokenizer: Optional[PreTrainedTokenizerBase] = None,
255 | model_init: Optional[Callable[[], PreTrainedModel]] = None,
256 | callbacks: Optional[List[TrainerCallback]] = None,
257 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
258 | None,
259 | None,
260 | ),
261 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
262 | max_length: Optional[int] = None,
263 | max_prompt_length: Optional[int] = None,
264 | max_target_length: Optional[int] = None,
265 | peft_config: Optional[Dict] = None,
266 | is_encoder_decoder: Optional[bool] = None,
267 | disable_dropout: bool = True,
268 | generate_during_eval: bool = False,
269 | compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
270 | mask_prompt: Optional[bool] = False,
271 | len_penalty: float = 0,
272 | nll_coefficient: float = 0,
273 | masking_user_turn: bool = True,
274 | ):
275 | ################# MODIFICATION
276 | # if nll_coefficient > 0, then dpo_loss + self.nll_coefficient * nll_loss on preferred response
277 | self.nll_coefficient = nll_coefficient
278 | # whether to mask the user turns (to implement M-DPO); if not, it is a regular single-turn DPO
279 | self.masking_user_turn = masking_user_turn
280 | ############## MODIFICATION END
281 |
282 |
283 | if data_collator is None:
284 | data_collator = PreferenceDataCollatorWithPadding(
285 | tokenizer,
286 | max_length=max_length,
287 | max_prompt_length=max_prompt_length,
288 | label_pad_token_id=label_pad_token_id,
289 | padding_value=padding_value,
290 | truncation_mode=truncation_mode,
291 | is_encoder_decoder=False,
292 | max_target_length=max_target_length,
293 | mask_prompt=mask_prompt,
294 | mask_user_turn = self.masking_user_turn,
295 | )
296 | super().__init__(
297 | model=model,
298 | ref_model=ref_model,
299 | beta=beta,
300 | loss_type=loss_type,
301 | args=args,
302 | data_collator=data_collator,
303 | label_pad_token_id=label_pad_token_id,
304 | padding_value=padding_value,
305 | truncation_mode=truncation_mode,
306 | train_dataset=train_dataset,
307 | eval_dataset=eval_dataset,
308 | tokenizer=tokenizer,
309 | model_init=model_init,
310 | callbacks=callbacks,
311 | optimizers=optimizers,
312 | preprocess_logits_for_metrics=preprocess_logits_for_metrics,
313 | max_length=max_length,
314 | max_prompt_length=max_prompt_length,
315 | max_target_length=max_target_length,
316 | peft_config=peft_config,
317 | is_encoder_decoder=is_encoder_decoder,
318 | disable_dropout=disable_dropout,
319 | generate_during_eval=generate_during_eval,
320 | compute_metrics=compute_metrics,
321 | )
322 | self.use_dpo_data_collator = True
323 | self.len_penalty = len_penalty
324 |
325 | def dpo_loss(
326 | self,
327 | policy_chosen_logps: torch.FloatTensor,
328 | policy_rejected_logps: torch.FloatTensor,
329 | reference_chosen_logps: torch.FloatTensor,
330 | reference_rejected_logps: torch.FloatTensor,
331 | reference_free: bool = False,
332 | margin: Optional[torch.FloatTensor] = None,
333 | len_penalty: float = 0,
334 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
335 | """Compute the DPO loss for a batch of policy and reference model log probabilities.
336 |
337 | Args:
338 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
339 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
340 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
341 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
342 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
343 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
344 |
345 | Returns:
346 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
347 | The losses tensor contains the DPO loss for each example in the batch.
348 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
349 | """
350 | pi_logratios = policy_chosen_logps - policy_rejected_logps
351 | ref_logratios = reference_chosen_logps - reference_rejected_logps + len_penalty
352 |
353 | if reference_free:
354 | ref_logratios = 0
355 |
356 | if self.loss_type == "sigmoid":
357 | logits = pi_logratios - ref_logratios
358 | losses = -F.logsigmoid(self.beta * logits)
359 | elif self.loss_type == "hinge":
360 | logits = pi_logratios - ref_logratios
361 | losses = torch.relu(1 - self.beta * logits)
362 | elif self.loss_type == "cross_entropy":
363 | logits = policy_chosen_logps - reference_chosen_logps
364 | losses = -F.logsigmoid(self.beta * logits)
365 | elif self.loss_type == "raft":
366 | losses = -policy_chosen_logps # F.logsigmoid(self.beta * logits)
367 | elif self.loss_type == "ipo":
368 | logits = pi_logratios - ref_logratios
369 | # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
370 | losses = (logits - 1 / (2 * self.beta)) ** 2
371 | elif self.loss_type == "kl":
372 | logits = pi_logratios - ref_logratios
373 | p = F.sigmoid(self.beta * logits)
374 | p = torch.minimum(p, torch.ones_like(p) * 0.999)
375 | p_gt = torch.exp(margin) / (1 + torch.exp(margin) + 1e-3)
376 | losses = p * (torch.log(p) - torch.log(p_gt)) + (1 - p) * (torch.log(1 - p) - torch.log(1 - p_gt))
377 | elif self.loss_type == "tv":
378 | logits = pi_logratios - ref_logratios
379 | p = F.sigmoid(self.beta * logits)
380 | p_gt = torch.exp(margin) / (1 + torch.exp(margin))
381 | losses = torch.abs(p - p_gt)
382 | elif self.loss_type == "hellinger":
383 | logits = pi_logratios - ref_logratios
384 | p = F.sigmoid(self.beta * logits)
385 | p = torch.minimum(p, torch.ones_like(p) * 0.999)
386 | p_gt = torch.exp(margin) / (1 + torch.exp(margin))
387 | losses = 0.5 * ((p**0.5 - p_gt**0.5) ** 2 + ((1 - p) ** 0.5 - (1 - p_gt) ** 0.5) ** 2)
388 | elif self.loss_type == "rev_kl":
389 | logits = pi_logratios - ref_logratios
390 | logp = F.logsigmoid(self.beta * logits)
391 | logp_neg = F.logsigmoid(-self.beta * logits)
392 | p_gt = F.sigmoid(margin)
393 | losses = -p_gt * (logp) - (1 - p_gt) * logp_neg
394 | else:
395 | raise ValueError(f"Unknown loss type: {self.loss_type}.")
396 |
397 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
398 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
399 |
400 | return losses, chosen_rewards, rejected_rewards
401 |
402 | def get_batch_loss_metrics(
403 | self,
404 | model,
405 | batch: Dict[str, Union[List, torch.LongTensor]],
406 | train_eval: Literal["train", "eval"] = "train",
407 | ):
408 | return self.get_batch_metrics(model, batch, train_eval)
409 |
410 | def get_batch_metrics(
411 | self,
412 | model,
413 | batch: Dict[str, Union[List, torch.LongTensor]],
414 | train_eval: Literal["train", "eval"] = "train",
415 | ):
416 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
417 | metrics = {}
418 | (
419 | policy_chosen_logps,
420 | policy_rejected_logps,
421 | policy_chosen_logits,
422 | policy_rejected_logits,
423 | policy_nll_loss,
424 | ) = self.concatenated_forward(model, batch)
425 | with torch.no_grad():
426 | if self.ref_model is None:
427 | with self.accelerator.unwrap_model(self.model).disable_adapter():
428 | (
429 | reference_chosen_logps,
430 | reference_rejected_logps,
431 | _,
432 | _,
433 | _,
434 | ) = self.concatenated_forward(self.model, batch)
435 | else:
436 | (
437 | reference_chosen_logps,
438 | reference_rejected_logps,
439 | _,
440 | _,
441 | _,
442 | ) = self.concatenated_forward(self.ref_model, batch)
443 | if self.len_penalty > 0:
444 | chosen_len = batch["chosen_input_ids"].shape[1] * self.len_penalty
445 | rejected_len = batch["rejected_input_ids"].shape[1] * self.len_penalty
446 | len_penalty = chosen_len - rejected_len
447 | else:
448 | chosen_len = 1
449 | rejected_len = 1
450 | len_penalty = 0
451 |
452 | margin = torch.tensor(batch["margin"], dtype=policy_chosen_logps.dtype).to(self.accelerator.device)
453 | losses, chosen_rewards, rejected_rewards = self.dpo_loss(
454 | policy_chosen_logps,
455 | policy_rejected_logps,
456 | reference_chosen_logps,
457 | reference_rejected_logps,
458 | margin=margin,
459 | len_penalty=len_penalty,
460 | )
461 | ############## MODIFICATION
462 | if self.nll_coefficient > 0:
463 | losses = losses + self.nll_coefficient * policy_nll_loss
464 | else:
465 | pass
466 | ############## MODIFICATION END
467 | reward_accuracies = (chosen_rewards > rejected_rewards).float()
468 |
469 | prefix = "eval_" if train_eval == "eval" else ""
470 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
471 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
472 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
473 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
474 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
475 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
476 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
477 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
478 | ########
479 | metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().cpu().mean()
480 | ########
481 | return losses.mean(), metrics
482 |
--------------------------------------------------------------------------------
/alignment_algorithms/run_dpo.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 |
5 | import torch
6 | from datasets import Dataset, load_dataset
7 | from dpo import PreferenceTrainer
8 | from transformers import (
9 | AutoModelForCausalLM,
10 | AutoTokenizer,
11 | HfArgumentParser,
12 | TrainingArguments,
13 | )
14 |
15 |
16 | @dataclass
17 | class ScriptArguments:
18 | """
19 | The arguments for the DPO training script.
20 | """
21 |
22 | # data parameters, i.e., the KL penalty in the paper
23 | beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
24 |
25 | # training parameters
26 | model_name_or_path: Optional[str] = field(
27 | default="RLHF4MATH/Gemma-7B-it-SFT3epoch",
28 | metadata={"help": "the location of the model name or path"},
29 | )
30 | ref_model: Optional[str] = field(
31 | default="RLHF4MATH/Gemma-7B-it-SFT3epoch",
32 | metadata={"help": "the location of the SFT model name or path"},
33 | )
34 | train_dir: Optional[str] = field(
35 | default="RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs",
36 | metadata={"help": "the location of the dataset name or path"},
37 | )
38 | eval_dir: Optional[str] = field(
39 | default="RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs",
40 | metadata={"help": "the location of the evalset name or path"},
41 | )
42 | learning_rate: Optional[float] = field(default=4e-7, metadata={"help": "optimizer learning rate"})
43 | lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
44 | warmup_steps: Optional[int] = field(default=50, metadata={"help": "the number of warmup steps"})
45 | weight_decay: Optional[float] = field(default=0.01, metadata={"help": "the weight decay"})
46 | optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
47 |
48 | per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
49 | per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
50 | gradient_accumulation_steps: Optional[int] = field(
51 | default=4, metadata={"help": "the number of gradient accumulation steps"}
52 | )
53 | gradient_checkpointing: Optional[bool] = field(
54 | default=True, metadata={"help": "whether to use gradient checkpointing"}
55 | )
56 |
57 | eos_padding: Optional[bool] = field(default=True, metadata={"help": "whether to pad with eos token"})
58 | lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
59 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
60 | lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
61 |
62 | margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"})
63 |
64 | max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"})
65 | max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length"})
66 | max_steps: Optional[int] = field(default=4000, metadata={"help": "max number of training steps"})
67 | num_train_epochs: Optional[int] = field(default=2, metadata={"help": "max number of training epochs"})
68 | logging_steps: Optional[int] = field(default=2, metadata={"help": "the logging frequency"})
69 | save_strategy: Optional[str] = field(default="steps", metadata={"help": "the saving strategy"})
70 | save_steps: Optional[int] = field(default=25, metadata={"help": "the saving frequency"})
71 | eval_steps: Optional[int] = field(default=300, metadata={"help": "the evaluation frequency"})
72 | run_name: Optional[str] = field(default="mdpo_iter1_gemma7b_lr4e7_bz32", metadata={"help": "the run name"})
73 | loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"})
74 | output_dir: Optional[str] = field(
75 | default="./mdpo_iter1_gemma7b_lr4e7_bz32", metadata={"help": "the output directory"}
76 | )
77 | log_freq: Optional[int] = field(default=2, metadata={"help": "the logging frequency"})
78 |
79 | # instrumentation
80 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
81 |
82 | max_training_samples: Optional[int] = field(default=-1, metadata={"help": "the maximum sample size"})
83 |
84 | choose_type: Optional[str] = field(default="max_random", metadata={"help": "the choose type"})
85 |
86 | report_to: Optional[str] = field(
87 | default="wandb",
88 | metadata={
89 | "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
90 | '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
91 | 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
92 | },
93 | )
94 | # debug argument for distributed training
95 | ignore_bias_buffers: Optional[bool] = field(
96 | default=False,
97 | metadata={
98 | "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
99 | "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
100 | },
101 | )
102 | eot_token: Optional[str] = field(default="", metadata={"help": "the end of text token"})
103 | mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"})
104 | len_penalty: Optional[float] = field(default=0, metadata={"help": "the length penalty"})
105 |
106 | masking_user_turn: Optional[bool] = field(default=True, metadata={"help": "mask user turn"})
107 | nll_coefficient: Optional[float] = field(default=0, metadata={"help": "the coefficient of the NLL loss"})
108 |
109 |
110 | def prepare_data(
111 | tokenizer,
112 | data_dir: str = "xxx",
113 | sanity_check: bool = False,
114 | cache_dir: str = None,
115 | num_proc=24,
116 | margin_scale=1,
117 | choose_type="random",
118 | eot_token="",
119 | length_penalty=0,
120 | ) -> Dataset:
121 | """Prepare the dataset for DPO training. The input datasets are supposed to be in the standard format with keys chosen and rejected.
122 | The margin is not used currently and may be activated later for future research.
123 |
124 | [ { "content": "If a 40-foot tree is casting a 10-foot shadow, and Andrea is casting a 15-inch shadow at the same time, how tall is Andrea in inches?", "role": "user" },
125 | { "content": "The shadow of the tree is 10 feet which is 120 inches.\nSo let's set *tree height* = 40 feet = 40 * 12 inches\n*tree shadow* = 10 feet = 120 inches\n*Andrea's shadow* = 15 inches\nFrom the similar triangles, we can find Andrea's height.\n```python\ntree_height = 40 * 12 # tree is 40 feet which is 40 * 12 inches\ntree_shadow = 10 * 12 # tree shadow is 10 feet = 120 inches\nandrea_shadow = 15 # Andrea's shadow is 15 inches\n\n# Find Andrea's height using similar triangles\nandrea_height = andrea_shadow * (tree_height / tree_shadow)\nandrea_height\n```", "role": "assistant" },
126 | { "content": "```output\n60.0\n```", "role": "user" },
127 | { "content": "So Andrea is $\\boxed{60}$ inches tall.", "role": "assistant" } ]
128 | """
129 | ds = load_dataset(data_dir, split="train")
130 | ds = ds.shuffle(seed=42)
131 | print(ds)
132 |
133 | pos = []
134 | neg = []
135 | prompts = []
136 | margin = []
137 | for sample in ds:
138 | chosen = sample["chosen"]
139 | rejected = sample["rejected"]
140 | prompt = tokenizer.apply_chat_template([chosen[0]], tokenize=False, add_generation_prompt=True)
141 | prompt2 = tokenizer.apply_chat_template([rejected[0]], tokenize=False, add_generation_prompt=True)
142 | if prompt != prompt2:
143 | continue
144 |
145 | # assert prompt == prompt2
146 | chosen_str = tokenizer.apply_chat_template(chosen, tokenize=False).replace(prompt, "")
147 | rejected_str = tokenizer.apply_chat_template(rejected, tokenize=False).replace(prompt, "")
148 | prompts.append(prompt)
149 | pos.append(chosen_str)
150 | neg.append(rejected_str)
151 | margin.append(0.5) # not used so far
152 | dataset = Dataset.from_dict({"prompt": prompts, "chosen": pos, "rejected": neg, "margin": margin})
153 | if sanity_check:
154 | dataset = dataset.select(range(min(len(dataset), 100)))
155 |
156 | return dataset
157 |
158 |
159 | if __name__ == "__main__":
160 | parser = HfArgumentParser(ScriptArguments)
161 | script_args = parser.parse_args_into_dataclasses()[0]
162 |
163 | # 1. load a pretrained model
164 | model = AutoModelForCausalLM.from_pretrained(
165 | script_args.model_name_or_path,
166 | use_flash_attention_2=True,
167 | torch_dtype=torch.float16,
168 | )
169 | model.config.use_cache = False
170 |
171 | if script_args.ignore_bias_buffers:
172 | # torch distributed hack
173 | model._ddp_params_and_buffers_to_ignore = [
174 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
175 | ]
176 |
177 | if script_args.ref_model:
178 | ref_name = script_args.ref_model
179 | else:
180 | ref_name = script_args.model_name_or_path
181 |
182 | model_ref = AutoModelForCausalLM.from_pretrained(
183 | ref_name,
184 | torch_dtype=torch.bfloat16,
185 | use_flash_attention_2=True,
186 | )
187 | tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
188 |
189 | # 2. Load the paired dataset
190 | train_dataset = prepare_data(
191 | tokenizer,
192 | data_dir=script_args.train_dir,
193 | margin_scale=script_args.margin_scale,
194 | sanity_check=script_args.sanity_check,
195 | choose_type=script_args.choose_type,
196 | eot_token=script_args.eot_token,
197 | length_penalty=script_args.len_penalty,
198 | )
199 | print(train_dataset)
200 | print(train_dataset[0])
201 | if script_args.max_training_samples > 0:
202 | train_dataset = train_dataset.select(range(script_args.max_training_samples))
203 |
204 | # 3. Load evaluation dataset
205 | eval_dataset = prepare_data(
206 | tokenizer,
207 | data_dir=script_args.eval_dir,
208 | sanity_check=True,
209 | margin_scale=script_args.margin_scale,
210 | eot_token=script_args.eot_token,
211 | )
212 |
213 | # 4. initialize training arguments:
214 |
215 | training_args = TrainingArguments(
216 | per_device_train_batch_size=script_args.per_device_train_batch_size,
217 | per_device_eval_batch_size=script_args.per_device_eval_batch_size,
218 | # max_steps=script_args.max_steps,
219 | num_train_epochs=script_args.num_train_epochs,
220 | save_strategy=script_args.save_strategy,
221 | logging_steps=script_args.logging_steps,
222 | save_steps=script_args.save_steps,
223 | gradient_accumulation_steps=script_args.gradient_accumulation_steps,
224 | gradient_checkpointing=script_args.gradient_checkpointing,
225 | learning_rate=script_args.learning_rate,
226 | evaluation_strategy="steps",
227 | eval_steps=script_args.eval_steps,
228 | output_dir=script_args.output_dir,
229 | # report_to=script_args.report_to,
230 | lr_scheduler_type=script_args.lr_scheduler_type,
231 | warmup_steps=script_args.warmup_steps,
232 | # optim=script_args.optimizer_type,
233 | bf16=True,
234 | remove_unused_columns=False,
235 | run_name=script_args.run_name,
236 | save_only_model=True,
237 | )
238 | print(training_args)
239 |
240 | # 5. initialize the DPO trainer
241 |
242 | dpo_trainer = PreferenceTrainer(
243 | model,
244 | model_ref,
245 | args=training_args,
246 | beta=script_args.beta,
247 | train_dataset=train_dataset,
248 | eval_dataset=eval_dataset,
249 | tokenizer=tokenizer,
250 | loss_type=script_args.loss_type,
251 | max_prompt_length=script_args.max_prompt_length,
252 | max_length=script_args.max_length,
253 | mask_prompt=script_args.mask_prompt,
254 | len_penalty=script_args.len_penalty,
255 | nll_coefficient=script_args.nll_coefficient,
256 | masking_user_turn=script_args.masking_user_turn,
257 | )
258 | print("begin to train")
259 |
260 | # 6. train
261 | dpo_trainer.train()
262 | dpo_trainer.save_model(script_args.output_dir)
263 |
264 | # 7. save
265 | output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
266 | dpo_trainer.model.save_pretrained(output_dir)
267 |
--------------------------------------------------------------------------------
/alignment_algorithms/run_dpo.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --config_file ./training_configs/zero2_pf.yaml run_dpo.py \
2 | --model_name_or_path RLHF4MATH/Gemma-7B-it-SFT3epoch \
3 | --ref_model RLHF4MATH/Gemma-7B-it-SFT3epoch \
4 | --per_device_train_batch_size 1 \
5 | --num_train_epochs 1 \
6 | --train_dir RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs \
7 | --eval_dir RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs \
8 | --learning_rate 2e-7 \
9 | --lr_scheduler_type=cosine \
10 | --gradient_accumulation_steps 4 \
11 | --logging_steps 2 \
12 | --eval_steps 10000 \
13 | --output_dir=./mdpo_iter1_gemma7b_lr2e7_bz32 \
14 | --warmup_ratio 0.1 \
15 | --report_to wandb \
16 | --bf16 \
17 | --save_strategy=steps \
18 | --save_steps=50 \
19 |
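# Note: with per_device_train_batch_size=1, gradient_accumulation_steps=4, and the
# 8-process zero2_pf.yaml config above, the effective batch size is 1 x 4 x 8 = 32,
# matching the bz32 tag in the output directory name.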
--------------------------------------------------------------------------------
/alignment_algorithms/run_kto.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
17 |
18 | # Full training:
19 | python examples/scripts/kto.py \
20 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
21 | --per_device_train_batch_size 16 \
22 | --num_train_epochs 1 \
23 | --learning_rate 1e-5 \
24 | --lr_scheduler_type=cosine \
25 | --gradient_accumulation_steps 1 \
26 | --logging_steps 10 \
27 | --eval_steps 500 \
28 | --output_dir=kto-aligned-model \
29 | --warmup_ratio 0.1 \
30 | --report_to wandb \
31 | --bf16 \
32 | --logging_first_step
33 |
34 | # QLoRA:
35 | python examples/scripts/kto.py \
36 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
37 | --per_device_train_batch_size 8 \
38 | --num_train_epochs 1 \
39 | --learning_rate 1e-4 \
40 | --lr_scheduler_type=cosine \
41 | --gradient_accumulation_steps 1 \
42 | --logging_steps 10 \
43 | --eval_steps 500 \
44 | --output_dir=kto-aligned-model-lora \
45 | --warmup_ratio 0.1 \
46 | --report_to wandb \
47 | --bf16 \
48 | --logging_first_step \
49 | --use_peft \
50 | --load_in_4bit \
51 | --lora_target_modules=all-linear \
52 | --lora_r=16 \
53 | --lora_alpha=16
54 | """
55 |
56 | from dataclasses import dataclass
57 |
58 | from datasets import load_dataset
59 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
60 |
61 | from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
62 |
63 |
64 | # Define and parse arguments.
65 | @dataclass
66 | class ScriptArguments:
67 | """
68 | The arguments for the KTO training script.
69 | """
70 |
71 | dataset_name: str = "trl-lib/kto-mix-14k"
72 |
73 |
74 | if __name__ == "__main__":
75 | parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
76 | script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
77 |
78 | # Load a pretrained model
79 | model = AutoModelForCausalLM.from_pretrained(
80 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
81 | )
82 | ref_model = AutoModelForCausalLM.from_pretrained(
83 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
84 | )
85 |
86 | tokenizer = AutoTokenizer.from_pretrained(
87 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
88 | )
89 | if tokenizer.pad_token is None:
90 | tokenizer.pad_token = tokenizer.eos_token
91 |
92 | # If the tokenizer has no chat template, we ask the user to set one up first (setup_chat_format is disabled here)
93 | if tokenizer.chat_template is None:
94 | #model, tokenizer = setup_chat_format(model, tokenizer)
95 | print("No chat template is provided. Please set up the model appropriately first.")
96 |
97 | # Load the dataset
98 | from datasets import Dataset, concatenate_datasets
99 |
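# KTOTrainer consumes unpaired examples in (prompt, completion, label) format,
# so each preference pair is unrolled below into two KTO examples: the chosen
# trajectory with label=True and the rejected trajectory with label=False.
# The prompt is the first user turn rendered with the chat template, and the
# completion is the remainder of the rendered conversation.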
100 | def get_chosen(example):
101 | chosen = example['chosen']
102 | example['prompt'] = tokenizer.apply_chat_template([chosen[0]], tokenize=False, add_generation_prompt=True)
103 | example['completion'] = tokenizer.apply_chat_template(chosen, tokenize=False).replace(example['prompt'], "")
104 | example['label'] = True
105 | return example
106 |
107 | def get_rej(example):
108 | chosen = example['rejected']
109 | example['prompt'] = tokenizer.apply_chat_template([chosen[0]], tokenize=False, add_generation_prompt=True)
110 | example['completion'] = tokenizer.apply_chat_template(chosen, tokenize=False).replace(example['prompt'], "")
111 | example['label'] = False
112 | return example
113 |
114 | ds = load_dataset(script_args.dataset_name, split='train')
115 | ds1 = ds.map(get_chosen, num_proc=32)
116 | ds2 = ds.map(get_rej, num_proc =32)
117 |
118 | dataset = concatenate_datasets([ds1, ds2])
119 | dataset = dataset.shuffle(seed=42)
120 |
121 | # Initialize the KTO trainer
122 | kto_trainer = KTOTrainer(
123 | model,
124 | ref_model,
125 | args=kto_args,
126 | train_dataset=dataset,
127 | eval_dataset=dataset,
128 | tokenizer=tokenizer,
129 | peft_config=get_peft_config(model_args),
130 | )
131 |
132 | # Train and push the model to the Hub
133 | kto_trainer.train()
134 | kto_trainer.save_model(kto_args.output_dir)
135 | kto_trainer.push_to_hub()
136 |
--------------------------------------------------------------------------------
/alignment_algorithms/run_kto.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --config_file ./training_configs/zero2_pf.yaml run_kto.py \
2 | --model_name_or_path RLHF4MATH/Gemma-7B-it-SFT3epoch \
3 | --per_device_train_batch_size 1 \
4 | --num_train_epochs 1 \
5 | --learning_rate 2e-7 \
6 | --lr_scheduler_type=cosine \
7 | --gradient_accumulation_steps 4 \
8 | --logging_steps 2 \
9 | --eval_steps 1000 \
10 | --output_dir=./mkto_iter1_gemma7b_lr2e7_bz32 \
11 | --warmup_ratio 0.1 \
12 | --report_to wandb \
13 | --bf16 \
14 | --logging_first_step \
15 | --dataset_name=RLHF4MATH/Gemma-7B-1.1-it-iter1-random-pairs \
16 | --save_strategy=steps \
17 | --save_steps=50 \
18 | --save_only_model=True \
19 |
--------------------------------------------------------------------------------
/alignment_algorithms/training_configs/zero2_pf.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | offload_optimizer_device: none
5 | offload_param_device: none
6 | zero3_init_flag: false
7 | zero_stage: 2
8 | distributed_type: DEEPSPEED
9 | downcast_bf16: 'no'
10 | machine_rank: 0
11 | main_training_function: main
12 | mixed_precision: bf16
13 | num_machines: 1
14 | num_processes: 8
15 | rdzv_backend: static
16 | same_network: true
17 | tpu_env: []
18 | tpu_use_cluster: false
19 | tpu_use_sudo: false
20 | use_cpu: false
21 |
--------------------------------------------------------------------------------
/alignment_algorithms/training_configs/zero3_pf.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/assets/main_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WeiXiongUST/Building-Math-Agents-with-Multi-Turn-Iterative-Preference-Learning/ae1e6422eac4ad04ef5ab1c2bbffb4b4497c70ab/assets/main_result.png
--------------------------------------------------------------------------------
/inference/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WeiXiongUST/Building-Math-Agents-with-Multi-Turn-Iterative-Preference-Learning/ae1e6422eac4ad04ef5ab1c2bbffb4b4497c70ab/inference/.DS_Store
--------------------------------------------------------------------------------
/inference/README.md:
--------------------------------------------------------------------------------
1 | ## Inference for MATH Problems with an External Python Interpreter
2 |
3 | In this repo, we implement math problem solving with an external Python tool, using vLLM to accelerate inference. The codebase is largely based on the ToRA project.
4 |
5 |
6 | ## 1 Installation instructions
7 |
8 | **Note that the numpy version should be `numpy<2.0`; `numpy 2.0` leads to unexpected issues!**
9 |
10 |
11 | Before starting, please make sure your Linux machine has the nvidia-cuda-toolkit installed. See the SFT part of this project for guidance.
12 |
13 |
14 | ```sh
15 | conda create -n infer python=3.10.9
16 | conda activate infer
17 | # You may check nvcc -V; if nvcc does not exist, you may run the following command
18 | # conda install nvidia/label/cuda-12.2.0::cuda-nvcc
19 |
20 | pip install datasets
21 |
22 | # The following has been tested with CUDA 12.0-12.2 and CUDA 12.6
23 | # To work with llama-3, mistral, gemma-1/1.1/2, and deepseek, you can use the following vllm version
24 | pip install vllm==0.5.4
25 |
26 | pip install accelerate==0.33.0
27 | pip install deepspeed==0.14.5
28 | pip install numpy==1.26.4 #Note that the numpy version should be `numpy<2.0`. `Numpy 2.0` will encounter unexpected issues!!!
29 |
30 |
31 | huggingface-cli login
32 | pip install pebble
33 | pip install timeout_decorator
34 | pip install ipython
35 | pip install sympy==1.12
36 | pip install antlr4-python3-runtime==4.11 # The versions of sympy and antlr4 cannot be modified!!!!!
37 | ```
38 |
39 | ## 2 The General Process of Inference
40 |
41 | The current code is implemented specifically for Gemma, Mistral, DeepSeek, and Llama3 (mainly in terms of the prompt format). You should include gemma, mistral, deepseek, or llama in the model name so that the code can determine the prompt format.
42 |
43 |
44 | **Step 1** To start with, we wrap a prompt (problem) in the following Gemma format:
45 |
46 | ```python
47 | prompt = "<start_of_turn>user\nEvaluate $\\left\\lceil3\\left(6-\\frac12\\right)\\right\\rceil$.<end_of_turn>\n<start_of_turn>model"
48 | ```
49 |
50 | **Step 2** The model receives the prompt and generates some reasoning steps and/or writes some code.
51 |
52 | ```python
53 | response = "To evaluate the expression, we will use Python's sympy library.\n```python\\nfrom sympy import ceiling, Rational\\n\\n# Evaluate the expression\\nexpression_result = ceiling(3 * (6 - Rational(1, 2)))\\n\\nexpression_result\\n```"
54 | ```
55 |
56 | After getting the response, we first detect whether the model outputs the final result, marked by \\boxed{FINAL RESULT}. If not, we continue to detect whether the model outputs a python code block that needs to be executed, i.e., "\`\`\`python SOME CODE \`\`\`". If yes, we run the code and get the result.
57 |
58 | In this example, the execution result is "17". Therefore, we obtain the new prompt by adding the observation into the history:
59 |
60 |
61 | ```python
62 | prompt2 = "<start_of_turn>user\nEvaluate $\\left\\lceil3\\left(6-\\frac12\\right)\\right\\rceil$.<end_of_turn>\n<start_of_turn>model\nTo evaluate the expression, we will use Python's sympy library.\n```python\\nfrom sympy import ceiling, Rational\\n\\n# Evaluate the expression\\nexpression_result = ceiling(3 * (6 - Rational(1, 2)))\\n\\nexpression_result\\n```<end_of_turn>\n<start_of_turn>user\n```output\\n17\\n```<end_of_turn>\n<start_of_turn>model\n"
63 | ```
64 |
65 | A new step then begins, and we stop either when the model outputs the final answer or when the maximal number of tool calls is reached.
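
Below is a minimal sketch of this per-turn control flow. It is for illustration only: the actual implementation lives in `infer_data/infer_eval.py` (using `extract_program` and `PythonExecutor` from `utils/`), and the helper `run_code` here is a placeholder for the code executor.

```python
def one_turn(history: str, response: str, run_code) -> tuple[str, bool]:
    """Process one model response; return (new_history, stop)."""
    history += response
    if "\\boxed" in response:
        # the model has produced the final answer
        return history, True
    if "```python" in response:
        # extract the last code block and execute it (simplified extraction)
        code = response.split("```python")[-1].split("```")[0]
        observation = run_code(code)
        # append the execution result as a new user turn (Gemma-style markers)
        history += (
            "<end_of_turn>\n<start_of_turn>user\n"
            f"```output\n{observation}\n```"
            "<end_of_turn>\n<start_of_turn>model\n"
        )
        return history, False
    # neither a final answer nor runnable code: stop this trajectory
    return history, True
```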
66 |
67 | ## 3 Running the Generation Code
68 |
69 | To run the generation, the first step is to register the model as a server so that we can query it. You should first modify the scripts/register_server.sh according to your GPU setup. Then, run the following command with your model.
70 |
71 | ```sh
72 | bash scripts/register_server.sh google/gemma-1.1-7b-it
73 | ```
74 |
75 | Then, you can run the scripts/infer_eval.sh, which will generate trajectories with the model and evaluate the output.
76 |
77 | ```sh
78 | bash scripts/infer_eval.sh gemma_7b
79 | ```
80 |
81 | The iter_infer_to_collect_data.sh script additionally wraps the inference in a for loop to iteratively generate trajectories, as shown in the example below. The model name should contain mistral, gemma, llama, or deepseek so that the code can determine the prompt format.
82 |
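For example, reusing the Gemma model registered in the server example above:

```sh
bash scripts/iter_infer_to_collect_data.sh google/gemma-1.1-7b-it
```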
83 |
84 | ## 4 Annotate Data
85 |
86 | ```sh
87 | python -um infer_data.annotate_data --data_name gsm8k --prompt_type tora --file_path ./collect_data/gemma_7b/gsm8k/train_tora_7473_seed1_t0.0_s0_e7473_09-22_16-22.jsonl --output_dir test_output.jsonl
88 | ```
89 |
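The annotated output is a jsonl file in which each line stores a sampled trajectory together with a `score` list (one boolean per sampled solution). A minimal sketch for inspecting the pass rate, assuming the `test_output.jsonl` produced by the command above:

```python
import json

n_total, n_correct = 0, 0
with open("test_output.jsonl", encoding="utf8") as f:
    for line in f:
        sample = json.loads(line)
        n_total += 1
        # 'score' holds one boolean per sampled solution; look at the first one
        n_correct += bool(sample["score"][0])

print(f"{n_correct}/{n_total} annotated trajectories are correct")
```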
--------------------------------------------------------------------------------
/inference/eval/evaluate.py:
--------------------------------------------------------------------------------
1 | # The evaluator is adapted from the ToRA project
2 | # https://github.com/microsoft/ToRA
3 | # ToRA authors: Zhibin Gou and Zhihong Shao and Yeyun Gong and yelong shen and Yujiu Yang and Minlie Huang and Nan Duan and Weizhu Chen
4 |
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 | from pebble import ProcessPool
9 | from concurrent.futures import TimeoutError
10 |
11 | from eval.grader import *
12 | from utils.parser import *
13 | from utils.utils import load_jsonl
14 | from utils.python_executor import PythonExecutor
15 |
16 |
17 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False):
18 | assert samples or file_path, "samples or file_path must be provided"
19 | if not samples:
20 | samples = list(load_jsonl(file_path))
21 | # dedup by idx
22 | if 'idx' in samples[0]:
23 | samples = {sample['idx']: sample for sample in samples}.values()
24 | samples = sorted(samples, key=lambda x: x['idx'])
25 | else:
26 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]
27 |
28 | if max_num_samples:
29 | print(f"max_num_samples: {max_num_samples} / {len(samples)}")
30 | samples = samples[:max_num_samples]
31 |
32 | # parse gt
33 | for sample in samples:
34 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name)
35 |
36 | # execute
37 | if ('pred' not in samples[0]) or execute:
38 | if "pal" in prompt_type:
39 | executor = PythonExecutor(get_answer_expr="solution()")
40 | else:
41 | executor = PythonExecutor(get_answer_from_stdout=True)
42 |
43 | for sample in tqdm(samples, desc="Execute"):
44 | sample['pred'] = []
45 | sample['report'] = []
46 | for code in sample['code']:
47 | pred, report = run_execute(executor, code, prompt_type, execute=True)
48 | sample['pred'].append(pred)
49 | sample['report'].append(report)
50 |
51 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']]
52 |
53 | scores = []
54 | timeout_cnt = 0
55 |
56 | with ProcessPool() as pool:
57 | future = pool.map(math_equal_process, params, timeout=10)
58 | iterator = future.result()
59 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
60 | while True:
61 | try:
62 | result = next(iterator)
63 | scores.append(result)
64 | except StopIteration:
65 | break
66 | except TimeoutError as error:
67 | print(error)
68 | scores.append(False)
69 | timeout_cnt += 1
70 | except Exception as error:
71 | print(error.traceback)
72 | exit()
73 | progress_bar.update(1)
74 |
75 | idx = 0
76 | score_mat = []
77 | for sample in samples:
78 | sample['score'] = scores[idx: idx+len(sample['pred'])]
79 | assert len(sample['score']) == len(sample['pred'])
80 | score_mat.append(sample['score'])
81 | idx += len(sample['pred'])
82 |
83 | max_len = max([len(s) for s in score_mat])
84 |
85 | for i, s in enumerate(score_mat):
86 | if len(s) < max_len:
87 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad
88 |
89 | # output mean of each column of scores
90 | col_means= np.array(score_mat).mean(axis=0)
91 | mean_score = list(np.round(col_means * 100, decimals=1))
92 |
93 | result_str = f"Num samples: {len(samples)}\n" \
94 | f"Num scores: {len(scores)}\n" \
95 | f"Timeout samples: {timeout_cnt}\n" \
96 | f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
97 | f"Mean score: {mean_score}\n"
98 |
99 | # each type score
100 | if "type" in samples[0]:
101 | type_scores = {}
102 | for sample in samples:
103 | if sample['type'] not in type_scores:
104 | type_scores[sample['type']] = []
105 | type_scores[sample['type']].append(sample['score'][-1])
106 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()}
107 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])}
108 | result_str += f"Type scores: {type_scores}\n"
109 |
110 | print(result_str)
111 | return result_str
112 |
113 |
114 | def parse_args():
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument("--data_name", type=str, default="math")
117 | parser.add_argument("--prompt_type", type=str, default="tora")
118 | parser.add_argument("--file_path", type=str, default=None, required=True)
119 | parser.add_argument("--max_num_samples", type=int, default=None)
120 | parser.add_argument("--execute", action="store_true")
121 | args = parser.parse_args()
122 | return args
123 |
124 | if __name__ == "__main__":
125 | args = parse_args()
126 | evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path,
127 | max_num_samples=args.max_num_samples, execute=args.execute)
128 |
--------------------------------------------------------------------------------
/inference/eval/grader.py:
--------------------------------------------------------------------------------
1 | """
2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
4 | - https://github.com/openai/prm800k
5 | """
6 | import multiprocessing
7 | from math import isclose
8 | from typing import Union
9 |
10 | from sympy import simplify, N
11 | from sympy.parsing.sympy_parser import parse_expr
12 | from sympy.parsing.latex import parse_latex
13 |
14 |
15 | def is_digit(s):
16 | try:
17 | float(str(s).replace(",", ""))
18 | return True
19 | except ValueError:
20 | return False
21 |
22 | def math_equal(prediction: Union[bool, float, str],
23 | reference: Union[float, str],
24 | include_percentage: bool = True,
25 | is_close: bool = True,
26 | timeout: bool = False,
27 | ) -> bool:
28 | """
29 | Exact match of math if and only if:
30 | 1. numerical equal: both can convert to float and are equal
31 | 2. symbolic equal: both can convert to sympy expression and are equal
32 | """
33 | try: # 1. numerical equal
34 | if is_digit(prediction) and is_digit(reference):
35 | prediction = float(str(prediction).replace(",", ""))
36 | reference = float(str(reference).replace(",", ""))
37 | # number questions
38 | if include_percentage:
39 | gt_result = [reference / 100, reference, reference * 100]
40 | else:
41 | gt_result = [reference]
42 | for item in gt_result:
43 | try:
44 | if is_close:
45 | if isclose(item, prediction, rel_tol=1e-4):
46 | return True
47 | else:
48 | if item == prediction:
49 | return True
50 | except Exception:
51 | continue
52 | return False
53 | except:
54 | pass
55 |
56 | if not prediction and prediction not in [0, False]:
57 | return False
58 |
59 | # 2. symbolic equal
60 | reference = str(reference).strip()
61 | prediction = str(prediction).strip()
62 |
63 | ## deal with [], (), {}
64 | pred_str, ref_str = prediction, reference
65 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
66 | (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
67 | pred_str = pred_str.strip("[]()")
68 | ref_str = ref_str.strip("[]()")
69 | for s in ['{', "}", "(", ")"]:
70 | ref_str = ref_str.replace(s, "")
71 | pred_str = pred_str.replace(s, "")
72 | if pred_str == ref_str:
73 | return True
74 |
75 | ## [a, b] vs. [c, d], return a==c and b==d
76 | if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \
77 | (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")):
78 | pred_parts = prediction[1:-1].split(",")
79 | ref_parts = reference[1:-1].split(",")
80 | if len(pred_parts) == len(ref_parts):
81 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
82 | return True
83 |
84 | # symbolic equal with sympy
85 | if timeout:
86 | if call_with_timeout(symbolic_equal_process, prediction, reference):
87 | return True
88 | else:
89 | if symbolic_equal(prediction, reference):
90 | return True
91 |
92 | return False
93 |
94 |
95 | def math_equal_process(param):
96 | return math_equal(param[-2], param[-1])
97 |
98 |
99 | def symbolic_equal(a, b):
100 | def _parse(s):
101 | for f in [parse_latex, parse_expr]:
102 | try:
103 | return f(s)
104 | except:
105 | pass
106 | return s
107 | a = _parse(a)
108 | b = _parse(b)
109 |
110 | try:
111 | if simplify(a-b) == 0:
112 | return True
113 | except:
114 | pass
115 |
116 | try:
117 | if isclose(N(a), N(b), rel_tol=1e-3):
118 | return True
119 | except:
120 | pass
121 | return False
122 |
123 |
124 | def symbolic_equal_process(a, b, output_queue):
125 | result = symbolic_equal(a, b)
126 | output_queue.put(result)
127 |
128 |
129 | def call_with_timeout(func, *args, timeout=1, **kwargs):
130 | output_queue = multiprocessing.Queue()
131 | process_args = args + (output_queue,)
132 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
133 | process.start()
134 | process.join(timeout)
135 |
136 | if process.is_alive():
137 | process.terminate()
138 | process.join()
139 | return False
140 |
141 | return output_queue.get()
142 |
143 |
144 | def _test_math_equal():
145 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
146 | # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
147 | print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
148 |
149 | if __name__ == "__main__":
150 | _test_math_equal()
151 |
--------------------------------------------------------------------------------
/inference/infer_data/annotate_data.py:
--------------------------------------------------------------------------------
1 | """
2 | This script evaluates the math data to check its correctness.
3 | """
4 | import argparse
5 | import numpy as np
6 | from tqdm import tqdm
7 | from pebble import ProcessPool
8 | from concurrent.futures import TimeoutError
9 |
10 | from eval.grader import *
11 | from utils.parser import *
12 | from utils.utils import load_jsonl
13 | from utils.python_executor import PythonExecutor
14 |
15 |
16 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, execute=False):
17 | assert samples or file_path, "samples or file_path must be provided"
18 | if not samples:
19 | samples = list(load_jsonl(file_path))
20 | # dedup by idx
21 | if 'idx' in samples[0]:
22 | samples = {sample['idx']: sample for sample in samples}.values()
23 | samples = sorted(samples, key=lambda x: x['idx'])
24 | else:
25 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]
26 |
27 | # parse gt if not in the dataset
28 | if 'gt' in samples[0]:
29 | pass
30 | else:
31 | for sample in samples:
32 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name)
33 |
34 | # execute
35 | if ('pred' not in samples[0]) or execute:
36 | if "pal" in prompt_type:
37 | executor = PythonExecutor(get_answer_expr="solution()")
38 | else:
39 | executor = PythonExecutor(get_answer_from_stdout=True)
40 |
41 | for sample in tqdm(samples, desc="Execute"):
42 | sample['pred'] = []
43 | sample['report'] = []
44 | for code in sample['code']:
45 | pred, report = run_execute(executor, code, prompt_type, execute=True)
46 | sample['pred'].append(pred)
47 | sample['report'].append(report)
48 |
49 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']]
50 |
51 | scores = []
52 | timeout_cnt = 0
53 |
54 | with ProcessPool() as pool:
55 | future = pool.map(math_equal_process, params, timeout=10)
56 | iterator = future.result()
57 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
58 | while True:
59 | try:
60 | result = next(iterator)
61 | scores.append(result)
62 | except StopIteration:
63 | break
64 | except TimeoutError as error:
65 | print(error)
66 | scores.append(False)
67 | timeout_cnt += 1
68 | except Exception as error:
69 | print(error.traceback)
70 | exit()
71 | progress_bar.update(1)
72 |
73 | idx = 0
74 | score_mat = []
75 | for sample in samples:
76 | sample['score'] = scores[idx: idx+len(sample['pred'])]
77 | assert len(sample['score']) == len(sample['pred'])
78 | score_mat.append(sample['score'])
79 | idx += len(sample['pred'])
80 |
81 | max_len = max([len(s) for s in score_mat])
82 |
83 | for i, s in enumerate(score_mat):
84 | if len(s) < max_len:
85 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad
86 |
87 | # output mean of each column of scores
88 | col_means= np.array(score_mat).mean(axis=0)
89 | mean_score = list(np.round(col_means * 100, decimals=1))
90 |
91 | result_str = f"Num samples: {len(samples)}\n" \
92 | f"Num scores: {len(scores)}\n" \
93 | f"Timeout samples: {timeout_cnt}\n" \
94 | f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
95 | f"Mean score: {mean_score}\n"
96 |
97 | # each type score
98 | if "type" in samples[0]:
99 | type_scores = {}
100 | for sample in samples:
101 | if sample['type'] not in type_scores:
102 | type_scores[sample['type']] = []
103 | type_scores[sample['type']].append(sample['score'][-1])
104 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()}
105 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])}
106 | result_str += f"Type scores: {type_scores}\n"
107 |
108 | print(result_str)
109 | return result_str, samples
110 |
111 |
112 | def parse_args():
113 | parser = argparse.ArgumentParser()
114 | parser.add_argument("--data_name", type=str, default="math")
115 | parser.add_argument("--prompt_type", type=str, default="tora")
116 | parser.add_argument("--file_path", type=str, default=None, required=True)
117 | parser.add_argument("--max_num_samples", type=int, default=None)
118 | parser.add_argument("--execute", action="store_true")
119 | parser.add_argument("--output_dir", type=str, default=None, required=True)
120 | args = parser.parse_args()
121 | return args
122 |
123 | # data_name='gsm8k'
124 | # prompt_type = 'cot' / 'tora'
125 | #
126 | args = parse_args()
127 | eval_result, all_samples = evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path, execute=args.execute)
128 |
129 | with open(args.output_dir, "w", encoding="utf8") as f:
130 | for i in range(len(all_samples)):
131 | json.dump(all_samples[i], f, ensure_ascii=False)
132 | f.write('\n')
133 |
134 |
--------------------------------------------------------------------------------
/inference/infer_data/get_dpo_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, Dataset
2 | import re
3 | import random
4 | import json
5 |
6 | # Define the file path and output dir
7 | data_files = 'here.jsonl'
8 | output_dir = "huggingface/dataset_name"
9 |
10 | # Load the dataset
11 | ds0 = load_dataset('json', data_files=data_files, split='train')
12 |
13 | def parse_conversation(example):
14 | # Split the data into turns based on the start_of_turn and end_of_turn markers
15 | data = example["my_solu"][0]
16 | turns = re.split(r"<start_of_turn>|<end_of_turn>", data)
17 |
18 | # Clean and filter out empty entries
19 | turns = [turn.strip() for turn in turns if turn.strip() and not turn.startswith("<bos>")]
20 |
21 | # Create a list to hold the parsed conversation in the desired format
22 | conversation = []
23 |
24 | # Process each turn, assigning the correct role and content
25 | for turn in turns:
26 | if turn.startswith("user\n"):
27 | # Extract content after the role identifier
28 | content = turn[5:].strip()
29 | conversation.append({"role": "user", "content": content})
30 | elif turn.startswith("model\n"):
31 | content = turn[6:].strip()
32 | conversation.append({"role": "assistant", "content": content})
33 |
34 | return {"messages": conversation}
35 |
36 | # we first transform the data into standard format
37 | ds1 = ds0.map(parse_conversation, num_proc=32)
38 |
39 |
40 | def filter_example(example):
41 | old_messages = example["messages"]
42 |
43 | if len(old_messages) < 4:
44 | return False
45 |
46 | if len(old_messages) % 2 != 0:
47 | return False
48 |
49 | all_mes_len = len(old_messages)
50 | # drop trajectories where the execution errors but the model still predicts an answer
51 | if "error" in old_messages[-2]["content"].lower():
52 | return False
53 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower():
54 | return False
55 |
56 | k = 0
57 |
58 | for mes in old_messages:
59 | if k % 2 != 0 and k < all_mes_len - 1:
60 | if "python" not in mes["content"]:
61 | return False
62 | k += 1
63 | # env error
64 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower():
65 | return False
66 |
67 | return True
68 |
69 | ds2 = ds1.filter(filter_example, num_proc=32)
70 |
71 |
72 | # Function to de-duplicate and group entries
73 | def deduplicate_and_group(dataset):
74 | unique_entries = {}
75 | for entry in dataset:
76 | idx = entry['idx']
77 | solu = entry['my_solu'][0] # the solution string is hashable and can be used as a dictionary key
78 | if idx not in unique_entries:
79 | unique_entries[idx] = {}
80 | if solu not in unique_entries[idx]:
81 | unique_entries[idx][solu] = entry
82 | return unique_entries
83 |
84 | # Group by 'idx' and de-duplicate by 'my_solu'
85 | grouped_data = deduplicate_and_group(ds2)
86 |
87 | # Select one sample with scores [True] and one with scores [False]
88 | def select_samples(groups):
89 | selected_pairs = []
90 | for idx, solutions in groups.items():
91 | true_samples = [sol for sol in solutions.values() if sol['score'][0] == True]
92 | false_samples = [sol for sol in solutions.values() if sol['score'][0] == False]
93 | if true_samples and false_samples:
94 | selected_true = random.choice(true_samples)
95 | selected_false = random.choice(false_samples)
96 | selected_pairs.append((selected_true, selected_false))
97 | return selected_pairs
98 |
99 | # Apply the selection function
100 | selected_samples = select_samples(grouped_data)
101 |
102 |
103 | # get the dpo dataset
104 | all_samples = []
105 | for pair in selected_samples:
106 | all_samples.append(
107 | {
108 | "gt": pair[0]["gt"],
109 | "chosen": pair[0]["messages"],
110 | "rejected": pair[1]["messages"],}
111 | )
112 |
113 | dict_data = {
114 | "rejected": [d['rejected'] for d in all_samples],
115 | "chosen": [d['chosen'] for d in all_samples],
116 | "gt": [d['gt'] for d in all_samples],
117 | }
118 |
119 | final_ds = Dataset.from_dict(dict_data)
120 | final_ds.push_to_hub(output_dir)
121 |
--------------------------------------------------------------------------------
/inference/infer_data/infer_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is adapted from the ToRA project and supports multi-round vLLM inference.
3 | The inference is formulated as a multi-turn chat, and the model should first be registered as a server via scripts/register_server.sh.
4 | """
5 |
6 | import argparse
7 | import os
8 | import random
9 | import time
10 | from concurrent.futures import ThreadPoolExecutor, as_completed
11 | from datetime import datetime
12 |
13 | import requests
14 | from eval.evaluate import evaluate
15 | from tqdm import tqdm
16 | from utils.data_loader import load_data
17 | from utils.parser import *
18 | from utils.python_executor import PythonExecutor
19 | from utils.utils import construct_prompt, load_jsonl, save_jsonl, set_seed
20 | from vllm import LLM, SamplingParams
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--data_name", default="gsm8k", type=str)
26 | parser.add_argument("--data_dir", default="./data", type=str)
27 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
28 | parser.add_argument("--output_dir", default="./output", type=str)
29 | parser.add_argument("--prompt_type", default="tora", type=str)
30 | parser.add_argument("--split", default="test", type=str)
31 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
32 | parser.add_argument("--seed", default=0, type=int)
33 | parser.add_argument("--start", default=0, type=int)
34 | parser.add_argument("--end", default=-1, type=int)
35 | parser.add_argument("--temperature", default=0, type=float)
36 | parser.add_argument("--n_sampling", default=1, type=int)
37 | parser.add_argument("--top_p", default=1, type=float)
38 | parser.add_argument("--max_tokens_per_call", default=1024, type=int)
39 | parser.add_argument("--shuffle", action="store_true")
40 | parser.add_argument("--ports", action='append', default=[])
41 | parser.add_argument("--horizon", default=6, type=int) # the maximal number of tool calls
42 | parser.add_argument("--eval", default=False, type=bool)
43 |
44 | args = parser.parse_args()
45 | args.top_p = 1 if args.temperature == 0 else args.top_p # top_p must be 1 when using greedy sampling (vllm)
46 | return args
47 |
48 |
49 | def prepare_data(args):
50 | examples = load_data(args.data_name, args.split, args.data_dir)
51 |
52 | # sample `num_test_sample` from dataset
53 | if args.num_test_sample > 0:
54 | examples = random.sample(examples, args.num_test_sample)
55 | elif args.num_test_sample == -1:
56 | args.num_test_sample = len(examples)
57 |
58 | # shuffle
59 | if args.shuffle:
60 | random.seed(datetime.now().timestamp())
61 | random.shuffle(examples)
62 |
63 | # select start and end
64 | if args.end == -1:
65 | args.end = len(examples)
66 | examples = examples[args.start : args.end]
67 |
68 | # get out_file name
69 | dt_string = datetime.now().strftime("%m-%d_%H-%M")
70 | model_name = "/".join(args.model_name_or_path.split("/")[-2:])
71 | out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}"
72 | out_file = f"{args.output_dir}/{model_name}/{args.data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl"
73 | os.makedirs(f"{args.output_dir}/{model_name}/{args.data_name}", exist_ok=True)
74 |
75 | # load all processed samples
76 | # find the files in e.g. ./output/gemma2/math/
77 | processed_files = [
78 | f
79 | for f in os.listdir(f"{args.output_dir}/{model_name}/{args.data_name}/")
80 | if f.endswith(".jsonl") and f.startswith(out_file_prefix)
81 | ]
82 | processed_samples = []
83 | for f in processed_files:
84 | processed_samples.extend(list(load_jsonl(f"{args.output_dir}/{model_name}/{args.data_name}/{f}")))
85 |
86 | # deduplicate
87 | processed_samples = {sample["idx"]: sample for sample in processed_samples}
88 | processed_idxs = list(processed_samples.keys())
89 | processed_samples = list(processed_samples.values())
90 | total_examples = len(examples)
91 | # if example has been inferenced with the same seed, temperature, and model, we skip them
92 | examples = [example for example in examples if example["idx"] not in processed_idxs]
93 | print(f"Idx {args.start} - {args.end}: Remain {len(examples)}/{total_examples} samples.")
94 | if len(examples) == 0:
95 | pass
96 | else:
97 | print(examples[0])
98 | return examples, processed_samples, out_file
99 |
100 |
101 | def main(args):
102 | ports = args.ports
103 | examples, processed_samples, out_file = prepare_data(args)
104 | # init python executor
105 | executor = PythonExecutor(get_answer_from_stdout=True)
106 | print(args.prompt_type)
107 |
108 | SamplingParams.seed = args.seed
109 | # load model and determine the number of gpus used
110 | if "gemma" in args.model_name_or_path:
111 | stop_tokens = ["<start_of_turn>", "<end_of_turn>", "```output", "<eos>"]
112 | elif "mistral" in args.model_name_or_path:
113 | stop_tokens = ["<s>", "</s>", "[INST]", "```output"]
114 | elif "deepseek" in args.model_name_or_path:
115 | stop_tokens = ["<|end▁of▁sentence|>", "User", "```output"]
116 | elif "llama3" in args.model_name_or_path:
117 | stop_tokens = ["<|eot_id|>", "<|start_header_id|>user", "```output"]
118 | else:
119 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path)
120 | default_args = {
121 | "use_beam_search": False,
122 | "n": 1,
123 | "temperature": args.temperature,
124 | "max_tokens": 1024,
125 | "seed": args.seed,
126 | "top_p": 1.0,
127 | "top_k": -1,
128 | "stop": stop_tokens,
129 | }
130 |
131 | def query_model(prompt, args, port):
132 | json = {
133 | **args,
134 | "prompt": prompt,
135 | }
136 | response = requests.post(url="http://localhost" + ":" + str(port) + "/generate", json=json)
137 | response_json = response.json()
138 | return [response_json["text"][i][len(prompt) :] for i in range(len(response_json["text"]))]
139 |
140 | samples = []
141 |
142 | for example in tqdm(examples, total=len(examples)):
143 | idx = example["idx"]
144 |
145 | # parse question and answer
146 | example["question"] = parse_question(example, args.data_name)
147 | gt_cot, gt_ans = parse_ground_truth(example, args.data_name)
148 |
149 | full_prompt = construct_prompt(args, example)
150 |
151 | sample = {"idx": idx, "question": example["question"], "gt_cot": gt_cot, "gt": gt_ans, "prompt": full_prompt}
152 | # add remain fields
153 | for key in [
154 | "level",
155 | "type",
156 | "unit",
157 | "solution_type",
158 | "choices",
159 | "solution",
160 | "ques_type",
161 | "ans_type",
162 | "answer_type",
163 | "dataset",
164 | "subfield",
165 | "filed",
166 | "theorem",
167 | "answer",
168 | ]:
169 | if key in example:
170 | sample[key] = example[key]
171 | samples.append(sample)
172 |
173 | print("dataset:", args.data_name, "samples:", len(samples))
174 | if len(samples) > 0:
175 | print("-" * 50)
176 | print("sample:", samples[0]["prompt"])
177 | print("-" * 50)
178 |
179 | # repeat H times
180 | remain_prompts = [sample["prompt"] for sample in samples for _ in range(args.n_sampling)]
181 | remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)]
182 | all_gts = [sample["gt"] for sample in samples for _ in range(args.n_sampling)]
183 |
184 | tmp_idx = list(range(len(all_gts)))
185 | all_gts = dict(zip(tmp_idx, all_gts))
186 |
187 | end_prompts = []
188 |
189 | max_func_call = 1 if args.prompt_type == "cot" else args.horizon
190 |
191 | # start inference, measure time use
192 | start_time = time.time()
193 | print("The maxmial function call is ", max_func_call)
194 | for epoch in range(max_func_call):
195 | print("=" * 50, "Epoch", epoch)
196 | current_prompts = remain_prompts
197 | # if all the queries meet the stop criteria, break
198 | if len(current_prompts) == 0:
199 | break
200 |
201 | # get all outputs, each prompt is (idx, prompt_content)
202 | prompts = [item[1] for item in current_prompts]
203 | with ThreadPoolExecutor(512) as executor2:
204 | result = [
205 | executor2.submit(query_model, prompts[i], default_args, ports[i % len(ports)])
206 | for i in range(len(prompts))
207 | ]
208 | # use tqdm to show progress
209 | for _ in tqdm(as_completed(result), total=len(result)):
210 | pass
211 |
212 | outputs = [r.result()[0] for r in result]
213 |
214 | # print(len(outputs), len(current_prompts))
215 |
216 | if len(outputs) != len(current_prompts):
217 | raise ValueError("VLLM has some problem, the generated responsess are less than the queries.")
218 |
219 | # process all outputs
220 | remain_prompts = []
221 | remain_codes = []
222 |
223 | for (i, query), output in zip(current_prompts, outputs):
224 | output = output.rstrip()
225 | # append the y_s to the current state (history)
226 | query += output
227 | if args.prompt_type == "cot":
228 | # for cot, the prompt ends for one round
229 | end_prompts.append((i, query))
230 | elif "boxed" not in output and "```python" in output: #output.endswith("```"):
231 | # the model does not output the final answer, meanwhile, a code needs to be executed
232 | program = extract_program(query)
233 | remain_prompts.append((i, query))
234 | remain_codes.append(program)
235 | else:
236 | end_prompts.append((i, query))
237 |
238 | # execute the codes and get the results
239 | # note that the order of remain_codes is the same as remain_prompts
240 | remain_results = executor.batch_apply(remain_codes)
241 | for k in range(len(remain_prompts)):
242 | i, query = remain_prompts[k]
243 | res, report = remain_results[k]
244 | exec_result = res if res else report
245 | # we add the observation to the history
246 | if "gemma" in args.model_name_or_path:
247 | exec_result = f"<end_of_turn>\n<start_of_turn>user\n```output\n{exec_result}\n```<end_of_turn>\n<start_of_turn>model\n"
248 | elif "mistral" in args.model_name_or_path:
249 | exec_result = f" [INST] ```output\n{exec_result}\n``` [/INST]"
250 | elif "deepseek" in args.model_name_or_path:
251 | #exec_result = f"<|end▁of▁sentence|>User: ```output\n{exec_result}\n```\n\nAssistant:"
252 | # for deepseek, we directly append the observation, following how deepseek was trained
253 | exec_result = f"\n```output\n{exec_result}\n```\n"
254 | elif "llama3" in args.model_name_or_path:
255 | exec_result = f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n```output\n{exec_result}\n```<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
256 | else:
257 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path)
258 |
259 | query += exec_result
260 |
261 | if epoch == max_func_call - 1:
262 | query += "\nReach max function call limit."
263 | remain_prompts[k] = (i, query)
264 |
265 | # unsolved samples
266 | print("Unsolved samples:", len(remain_prompts))
267 | end_prompts.extend(remain_prompts)
268 | # sort by idx
269 | end_prompts = sorted(end_prompts, key=lambda x: x[0])
270 |
271 | if "gemma" in args.model_name_or_path:
272 | ans_split = "model\n"
273 | elif "mistral" in args.model_name_or_path:
274 | ans_split = "[/INST]"
275 | elif "deepseek" in args.model_name_or_path:
276 | ans_split = "\n\nAssistant:"
277 | elif "llama3" in args.model_name_or_path:
278 | ans_split = "<|start_header_id|>user<|end_header_id|>\n\n"
279 | else:
280 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path)
281 |
282 | codes = [prompt.split(ans_split)[-1].strip() for _, prompt in end_prompts]
283 |
284 | # extract preds, run_execute will extract the code needed to run...
285 | # for tora, we only extract the final answer but do not run the code
286 | results = [run_execute(executor, code, args.prompt_type) for code in codes]
287 |
288 | time_use = time.time() - start_time
289 | tmp_to_store = [z.split("---")[-1].strip() for _, z in end_prompts]
290 | # put results back to examples
291 | all_samples = []
292 | for i, sample in enumerate(samples):
293 | code = codes[i * args.n_sampling : (i + 1) * args.n_sampling]
294 | result = results[i * args.n_sampling : (i + 1) * args.n_sampling]
295 | preds = [item[0] for item in result]
296 | reports = [item[1] for item in result]
297 | response_tmp = tmp_to_store[i * args.n_sampling : (i + 1) * args.n_sampling]
298 | sample.pop("prompt")
299 | sample.update({"my_solu": response_tmp, "code": code, "pred": preds, "report": reports})
300 | all_samples.append(sample)
301 |
302 | # add processed samples
303 | all_samples.extend(processed_samples)
304 | save_jsonl(all_samples, out_file)
305 |
306 | # Evaluate the result
307 | if args.eval:
308 | result_str = evaluate(samples=all_samples, data_name=args.data_name, prompt_type=args.prompt_type, execute=True)
309 | result_str += f"\nTime use: {time_use:.2f}s"
310 | time_str = f"{int(time_use // 60)}:{int(time_use % 60):02d}"
311 | result_str += f"\nTime use: {time_str}"
312 |
313 | with open(out_file.replace(".jsonl", f"_{args.prompt_type}.metrics"), "w") as f:
314 | f.write(result_str)
315 |
316 |
317 | if __name__ == "__main__":
318 | args = parse_args()
319 | set_seed(args.seed)
320 | main(args)
321 |
--------------------------------------------------------------------------------
/inference/scripts/eval.sh:
--------------------------------------------------------------------------------
1 |
2 | if [ $# -eq 0 ]; then
3 | echo "Usage: $0 "
4 | exit 1
5 | fi
6 | MODEL_NAME_OR_PATH=$1
7 |
8 | # DATA_LIST = ['math', 'gsm8k', 'gsm-hard', 'svamp', 'tabmwp', 'asdiv', 'mawps']
9 |
10 | DATA_NAME="gsm8k"
11 |
12 | OUTPUT_DIR="./output1"
13 |
14 | SPLIT="test"
15 | PROMPT_TYPE="tora"
16 | NUM_TEST_SAMPLE=-1
17 |
18 |
19 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
20 | python -um infer_data.infer_eval \
21 | --model_name_or_path ${MODEL_NAME_OR_PATH} \
22 | --data_name ${DATA_NAME} \
23 | --output_dir ${OUTPUT_DIR} \
24 | --split ${SPLIT} \
25 | --prompt_type ${PROMPT_TYPE} \
26 | --num_test_sample ${NUM_TEST_SAMPLE} \
27 | --seed 1 \
28 | --temperature 0 \
29 | --n_sampling 1 \
30 | --top_p 1 \
31 | --start 0 \
32 | --end -1 \
33 | --horizon 4 \
34 | --ports "8000" \
35 | --ports "8001" \
36 | --ports "8002" \
37 | --ports "8003" \
38 | --ports "8004" \
39 | --ports "8005" \
40 | --ports "8006" \
41 | --ports "8007" \
42 | --eval eval \
43 |
44 |
45 |
46 | DATA_NAME="math"
47 |
48 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
49 | python -um infer_data.infer_eval \
50 | --model_name_or_path ${MODEL_NAME_OR_PATH} \
51 | --data_name ${DATA_NAME} \
52 | --output_dir ${OUTPUT_DIR} \
53 | --split ${SPLIT} \
54 | --prompt_type ${PROMPT_TYPE} \
55 | --num_test_sample ${NUM_TEST_SAMPLE} \
56 | --seed 1 \
57 | --temperature 0 \
58 | --n_sampling 1 \
59 | --top_p 1 \
60 | --start 0 \
61 | --end -1 \
62 | --horizon 4 \
63 | --ports "8000" \
64 | --ports "8001" \
65 | --ports "8002" \
66 | --ports "8003" \
67 | --ports "8004" \
68 | --ports "8005" \
69 | --ports "8006" \
70 | --ports "8007" \
71 | --eval eval \
72 |
--------------------------------------------------------------------------------
/inference/scripts/infer.sh:
--------------------------------------------------------------------------------
1 | if [ $# -eq 0 ]; then
2 | echo "Usage: $0 "
3 | exit 1
4 | fi
5 | MODEL_NAME_OR_PATH=$1
6 |
7 | #DATA_LIST = ['math', 'gsm8k']
8 |
9 | DATA_NAME="gsm8k"
10 |
11 | OUTPUT_DIR="./collect_data"
12 |
13 | SPLIT="train"
14 | PROMPT_TYPE="tora"
15 | NUM_TEST_SAMPLE=-1
16 |
17 |
18 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
19 | python -um infer_data.infer_eval \
20 | --model_name_or_path ${MODEL_NAME_OR_PATH} \
21 | --data_name ${DATA_NAME} \
22 | --output_dir ${OUTPUT_DIR} \
23 | --split ${SPLIT} \
24 | --prompt_type ${PROMPT_TYPE} \
25 | --num_test_sample ${NUM_TEST_SAMPLE} \
26 | --seed 1 \
27 | --temperature 0 \
28 | --n_sampling 1 \
29 | --top_p 1 \
30 | --start 0 \
31 | --end -1 \
32 | --horizon 6 \
33 | --ports "8000" \
34 | --ports "8001" \
35 | --ports "8002" \
36 | --ports "8003" \
37 | --ports "8004" \
38 | --ports "8005" \
39 | --ports "8006" \
40 | --ports "8007" \
41 |
42 |
43 |
--------------------------------------------------------------------------------
/inference/scripts/iter_infer_to_collect_data.sh:
--------------------------------------------------------------------------------
1 | if [ $# -eq 0 ]; then
2 | echo "Usage: $0 "
3 | exit 1
4 | fi
5 | MODEL_NAME_OR_PATH=$1
6 |
7 | #DATA_NAME="math"
8 | DATA_NAME="gsm8k"
9 | OUTPUT_DIR="./iter_collect_data"
10 |
11 | SPLIT="train"
12 | PROMPT_TYPE="tora"
13 | NUM_TEST_SAMPLE=-1
14 |
15 | for ((i=0; i<=8000; i+=1))
16 | do
17 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
18 | python -um infer_data.infer_eval \
19 | --model_name_or_path ${MODEL_NAME_OR_PATH} \
20 | --data_name ${DATA_NAME} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --split ${SPLIT} \
23 | --prompt_type ${PROMPT_TYPE} \
24 | --num_test_sample ${NUM_TEST_SAMPLE} \
25 | --seed 1 \
26 | --temperature 0 \
27 | --n_sampling 1 \
28 | --top_p 1 \
29 | --start 0 \
30 | --end -1 \
31 | --horizon 6 \
32 | --ports "8000" \
33 | --ports "8001" \
34 | --ports "8002" \
35 | --ports "8003" \
36 | --ports "8004" \
37 | --ports "8005" \
38 | --ports "8006" \
39 | --ports "8007"
40 | done
41 |
--------------------------------------------------------------------------------
/inference/scripts/register_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # check whether model path is provided
4 | if [ $# -eq 0 ]; then
5 | echo "Usage: $0 "
6 | exit 1
7 | fi
8 |
9 |
10 | MODEL_PATH=$1
11 |
12 | # generate 8 server instances
13 | for i in {0..7}
14 | do
15 | CUDA_VISIBLE_DEVICES=$i python -m vllm.entrypoints.api_server \
16 | --model $MODEL_PATH \
17 | --gpu-memory-utilization=0.9 \
18 | --max-num-seqs=200 \
19 | --host 127.0.0.1 --tensor-parallel-size 1 \
20 | --port $((8000+i)) \
21 | &
22 | done
23 |
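# Each of the 8 servers listens on one of the ports 8000-8007; these are the
# --ports values passed to infer_data.infer_eval in scripts/infer.sh,
# scripts/eval.sh, and scripts/iter_infer_to_collect_data.sh.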
--------------------------------------------------------------------------------
/inference/utils/annotate_data.py:
--------------------------------------------------------------------------------
1 | # The evaluator is adapted from the ToRA project
2 | # https://github.com/microsoft/ToRA
3 | # ToRA authors: Zhibin Gou and Zhihong Shao and Yeyun Gong and yelong shen and Yujiu Yang and Minlie Huang and Nan Duan and Weizhu Chen
4 |
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 | from pebble import ProcessPool
9 | from concurrent.futures import TimeoutError
10 |
11 | from eval.grader import *
12 | from utils.parser import *
13 | from utils.utils import load_jsonl
14 | from utils.python_executor import PythonExecutor
15 | from datasets import load_dataset, Dataset, DatasetDict
16 |
17 | def evaluate(data_name, output_dir=None):
18 | if ".json" in data_name:
19 | ds = load_dataset("json", data_files=data_name, split='train', field='instances').shuffle(seed=42)
20 | else:
21 | ds = load_dataset(data_name, split="train").shuffle(seed=42)
22 |
23 | samples = [sample for sample in ds]
24 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']]
25 |
26 | scores = []
27 | timeout_cnt = 0
28 |
29 | with ProcessPool() as pool:
30 | future = pool.map(math_equal_process, params, timeout=10)
31 | iterator = future.result()
32 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
33 | while True:
34 | try:
35 | result = next(iterator)
36 | scores.append(result)
37 | except StopIteration:
38 | break
39 | except TimeoutError as error:
40 | print(error)
41 | scores.append(False)
42 | timeout_cnt += 1
43 | except Exception as error:
44 | print(error.traceback)
45 | exit()
46 | progress_bar.update(1)
47 |
48 | idx = 0
49 | score_mat = []
50 | for sample in samples:
51 | sample['score'] = scores[idx: idx+len(sample['pred'])]
52 | assert len(sample['score']) == len(sample['pred'])
53 | idx += len(sample['pred'])
54 | if ".json" in args.output_dir:
55 | all_data = [sample for sample in samples]
56 | output_eval_dataset = {}
57 | output_eval_dataset["type"] = "text_only"
58 | output_eval_dataset["instances"] = all_data
59 | print("I collect ", len(all_data), "samples")
60 | with open(output_dir, "w", encoding="utf8") as f:
61 | json.dump(output_eval_dataset, f, ensure_ascii=False)
62 | else:
63 | all_data = [sample for sample in samples]
64 | dict_data = {
65 | "idx": [d['idx'] for d in all_data],
66 | "gt": [d['gt'] for d in all_data],
67 | "level": [d['level'] for d in all_data],
68 | "type": [d['type'] for d in all_data],
69 | "messages": [d['messages'] for d in all_data],
70 | "pred": [d['pred'] for d in all_data],
71 | "score": [d['score'] for d in all_data],
72 | }
73 |
74 | dataset = Dataset.from_dict(dict_data)
75 | DatasetDict({'train': dataset}).push_to_hub(output_dir)
76 |
77 |
78 | def parse_args():
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument("--data_name", type=str, default="math")
81 | parser.add_argument("--output_dir", type=str, default=None, required=True)
82 | args = parser.parse_args()
83 | return args
84 |
85 | if __name__ == "__main__":
86 | args = parse_args()
87 | evaluate(data_name=args.data_name, output_dir=args.output_dir)
88 |
--------------------------------------------------------------------------------
/inference/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from datasets import load_dataset, Dataset, concatenate_datasets
5 | from utils.utils import load_jsonl, lower_keys
6 |
7 | def load_data(data_name, split, data_dir='./data'):
8 | data_file = f"{data_dir}/{data_name}/{split}.jsonl"
9 | if os.path.exists(data_file):
10 | examples = list(load_jsonl(data_file))
11 | else:
12 | if data_name == "math":
13 | dataset = load_dataset("competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp")
14 | elif "RLHF4MATH/prompt_iter" in data_name:
15 | # if we use the pre-processed prompts
16 | dataset = load_dataset(data_name, split=split)
17 | elif data_name == "theorem-qa":
18 | dataset = load_dataset("wenhu/TheoremQA", split=split)
19 | elif data_name == "gsm8k":
20 | dataset = load_dataset(data_name, split=split)
21 | elif data_name == "gsm-hard":
22 | dataset = load_dataset("reasoning-machines/gsm-hard", split="train")
23 | elif data_name == "svamp":
24 | # evaluate on training set + test set
25 | dataset = load_dataset("ChilleD/SVAMP", split="train")
26 | dataset = concatenate_datasets([dataset, load_dataset("ChilleD/SVAMP", split="test")])
27 | elif data_name == "asdiv":
28 | dataset = load_dataset("EleutherAI/asdiv", split="validation")
29 | dataset = dataset.filter(lambda x: ";" not in x['answer']) # remove multi-answer examples
30 | elif data_name == "mawps":
31 | examples = []
32 | # four sub-tasks
33 | for data_name in ["singleeq", "singleop", "addsub", "multiarith"]:
34 | sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl"))
35 | for example in sub_examples:
36 | example['type'] = data_name
37 | examples.extend(sub_examples)
38 | dataset = Dataset.from_list(examples)
39 | elif data_name == "finqa":
40 | dataset = load_dataset("dreamerdeo/finqa", split=split, name="main")
41 | dataset = dataset.select(random.sample(range(len(dataset)), 1000))
42 | elif data_name == "tabmwp":
43 | examples = []
44 | with open(f"{data_dir}/tabmwp/tabmwp_{split}.json", "r") as f:
45 | data_dict = json.load(f)
46 | examples.extend(data_dict.values())
47 | dataset = Dataset.from_list(examples)
48 | dataset = dataset.select(random.sample(range(len(dataset)), 1000))
49 | elif data_name == "bbh":
50 | examples = []
51 | for data_name in ["reasoning_about_colored_objects", "penguins_in_a_table",\
52 | "date_understanding", "repeat_copy_logic", "object_counting"]:
53 | with open(f"{data_dir}/bbh/bbh/{data_name}.json", "r") as f:
54 | sub_examples = json.load(f)["examples"]
55 | for example in sub_examples:
56 | example['type'] = data_name
57 | examples.extend(sub_examples)
58 | dataset = Dataset.from_list(examples)
59 | else:
60 | raise NotImplementedError(data_name)
61 |
62 | examples = list(dataset)
63 | examples = [lower_keys(example) for example in examples]
64 | dataset = Dataset.from_list(examples)
65 | os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
66 | dataset.to_json(data_file)
67 |
68 | # add 'idx' in the first column
69 | if 'idx' not in examples[0]:
70 | examples = [{'idx': i, **example} for i, example in enumerate(examples)]
71 |
72 | # deduplicate & sort
73 | examples = sorted(examples, key=lambda x: x['idx'])
74 | return examples
75 |
--------------------------------------------------------------------------------
/inference/utils/filter_data.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from datasets import Dataset, concatenate_datasets, load_dataset
3 | from transformers import AutoTokenizer, HfArgumentParser
4 | import argparse
5 | import re
6 | # Create an ArgumentParser object
7 | parser = argparse.ArgumentParser(description="Process some strings.")
8 |
9 | parser.add_argument("--data_dir", type=str, help="The dataset address", default=None)
10 | parser.add_argument("--output_dir", type=str, help="The output address", default=None)
11 | parser.add_argument("--model_name", type=str, help="The model used to collect the data", default='gemma')
12 |
13 | # Parse command line arguments
14 | args = parser.parse_args()
15 |
16 |
17 | # Step 1: load dataset
18 | if ".json" in args.data_dir:
19 | ds = load_dataset("json", data_files=args.data_dir, split='train', field='instances').shuffle(seed=42)
20 | else:
21 | ds = load_dataset(args.data_dir, split="train").shuffle(seed=42)
22 |
23 | # You may want to merge different datasets...
24 | #ds = concatenate_datasets([ds1, ds2])
25 |
26 | # Step 2: we split the trajectory into the standard multi-turn format
27 |
28 | def parse_conversation(example, model_name='gemma'):
29 | # Split the data into turns based on the start_of_turn and end_of_turn markers
30 |
31 | if 'gemma' in model_name:
32 | data = example["my_solu"][0]
33 |         turns = re.split(r"<start_of_turn>|<end_of_turn>", data)
34 | elif 'mistral' in model_name:
35 |         data = example["my_solu"][0].replace("<s>", "").replace("</s>", "")
36 | turns = re.split(r"\[INST\]|\[/INST\]", data)
37 | else:
38 | raise NotImplementedError(model_name)
39 |
40 | # Clean and filter out empty entries
41 |     turns = [turn.strip() for turn in turns if turn.strip() and not turn.startswith("<eos>")]
42 |
43 | # Create a list to hold the parsed conversation in the desired format
44 | conversation = []
45 |
46 | if 'gemma' in model_name:
47 | for turn in turns:
48 | if turn.startswith("user\n"):
49 | # Extract content after the role identifier
50 | content = turn[5:].strip()
51 | conversation.append({"role": "user", "content": content})
52 | elif turn.startswith("model\n"):
53 | content = turn[6:].strip()
54 |                 conversation.append({"role": "assistant", "content": content})
55 |
56 | elif 'mistral' in model_name:
57 | j = 0
58 | for turn in turns:
59 | if j % 2 == 0:
60 | content = turn.strip()
61 | conversation.append({"role": "user", "content": content})
62 | j += 1
63 | else:
64 | content = turn.strip()
65 |                 conversation.append({"role": "assistant", "content": content})
66 | j += 1
67 | else:
68 | raise NotImplementedError(model_name)
69 |
70 |
71 | return {"messages": conversation}
72 |
73 |
74 | ds_new = ds.map(parse_conversation, fn_kwargs={"model_name": args.model_name}, num_proc=32)
75 |
76 | # Step 3: filter out examples with an abnormal number of rounds (too few or odd), or that make an error in the second-to-last round but still guess a final answer
77 |
78 |
79 | def filter_example1(example):
80 | old_messages = example["messages"]
81 |
82 | if len(old_messages) < 4:
83 | return False
84 |
85 | if len(old_messages) % 2 != 0:
86 | return False
87 |
88 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower():
89 | return False
90 |
91 | for mes in old_messages:
92 |         # the code interpreter occasionally breaks the conda environment; drop samples collected while the environment was broken
93 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower():
94 | return False
95 | if "```output\n[]" in mes["content"]:
96 | return False
97 |
98 | if "traitlets" in mes['content'] and 'error' in mes['content'].lower():
99 | return False
100 | if "sympy.core.numbers" in mes['content'] and 'error' in mes['content'].lower():
101 | return False
102 | if 'sympy.tensor.tensor' in mes['content'] and 'error' in mes['content']:
103 | return False
104 | if 'minlex() got an' in mes['content']:
105 | return False
106 | if 'No module named' in mes['content'] and 'sympy.' in mes['content']:
107 | return False
108 | if 'object is not subscriptable' in mes['content'].lower():
109 | return False
110 |
111 |         # We drop the samples that hit the maximum number of function calls;
112 |         # this does not affect the final result but significantly speeds up training
113 | if 'Reach max function call' in mes['content']:
114 | return False
115 |
116 | return True
117 |
118 | ds_new = ds_new.filter(filter_example1, num_proc=32)
119 |
120 |
121 | # Step 4: we delete the samples that are too long
122 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
123 |
124 | def filter_too_long_pred(example):
125 | z = len(tokenizer.apply_chat_template(example["messages"], tokenize=True))
126 | if z > 2048:
127 | return False
128 | return True
129 |
130 | ds_new = ds_new.filter(filter_too_long_pred, num_proc=32)
131 |
132 |
133 | # Step 5: output the filtered dataset
134 |
135 | # we delete the columns that are unnecessary
136 | columns_to_keep = ["idx", "gt", "level", "type", "messages", "pred"]
137 | ds_new = ds_new.remove_columns([col for col in ds_new.column_names if col not in columns_to_keep])
138 |
139 |
140 | if ".json" in args.output_dir:
141 | all_data = [sample for sample in ds_new]
142 | output_eval_dataset = {}
143 | output_eval_dataset["type"] = "text_only"
144 | output_eval_dataset["instances"] = all_data
145 | print("I collect ", len(all_data), "samples")
146 | with open(args.output_dir, "w", encoding="utf8") as f:
147 | json.dump(output_eval_dataset, f, ensure_ascii=False)
148 | else:
149 | ds_new.push_to_hub(args.output_dir)
150 |
151 |
152 |
--------------------------------------------------------------------------------
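
`filter_data.py` first reconstructs multi-turn conversations from the stored generations before filtering them. A self-contained sketch of the Gemma-format turn splitting used by `parse_conversation`; the toy trajectory is made up, and the `<start_of_turn>`/`<end_of_turn>` markers are assumed from the standard Gemma chat template (matching the restored lines above):

```python
import re

# Made-up Gemma-format trajectory: question, code turn, execution output, final answer.
data = (
    "<start_of_turn>user\nWhat is 2 + 3?<end_of_turn>\n"
    "<start_of_turn>model\n```python\nprint(2 + 3)\n```<end_of_turn>\n"
    "<start_of_turn>user\n```output\n5\n```<end_of_turn>\n"
    "<start_of_turn>model\nThe answer is \\boxed{5}.<end_of_turn>"
)

# split on the turn markers and drop empty fragments
turns = [t.strip() for t in re.split(r"<start_of_turn>|<end_of_turn>", data) if t.strip()]

conversation = []
for turn in turns:
    if turn.startswith("user\n"):
        conversation.append({"role": "user", "content": turn[5:].strip()})
    elif turn.startswith("model\n"):
        conversation.append({"role": "assistant", "content": turn[6:].strip()})

print(len(conversation))        # 4
print(conversation[1]["role"])  # assistant
```

For Mistral-style trajectories the script instead strips `<s>`/`</s>` and splits on `[INST]`/`[/INST]`, alternating user and assistant roles by position.
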
/inference/utils/parser.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Dict
3 |
4 |
5 | def _fix_fracs(string):
6 | substrs = string.split("\\frac")
7 | new_str = substrs[0]
8 | if len(substrs) > 1:
9 | substrs = substrs[1:]
10 | for substr in substrs:
11 | new_str += "\\frac"
12 | if len(substr) > 0 and substr[0] == "{":
13 | new_str += substr
14 | else:
15 | try:
16 | assert len(substr) >= 2
17 | except:
18 | return string
19 | a = substr[0]
20 | b = substr[1]
21 | if b != "{":
22 | if len(substr) > 2:
23 | post_substr = substr[2:]
24 | new_str += "{" + a + "}{" + b + "}" + post_substr
25 | else:
26 | new_str += "{" + a + "}{" + b + "}"
27 | else:
28 | if len(substr) > 2:
29 | post_substr = substr[2:]
30 | new_str += "{" + a + "}" + b + post_substr
31 | else:
32 | new_str += "{" + a + "}" + b
33 | string = new_str
34 | return string
35 |
36 |
37 | def _fix_a_slash_b(string):
38 | if len(string.split("/")) != 2:
39 | return string
40 | a = string.split("/")[0]
41 | b = string.split("/")[1]
42 | try:
43 | if "sqrt" not in a:
44 | a = int(a)
45 | if "sqrt" not in b:
46 | b = int(b)
47 | assert string == "{}/{}".format(a, b)
48 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
49 | return new_string
50 | except:
51 | return string
52 |
53 |
54 | def _fix_sqrt(string):
55 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
56 | return _string
57 |
58 |
59 | def strip_string(string):
60 | string = str(string).strip()
61 | # linebreaks
62 | string = string.replace("\n", "")
63 |
64 | # right "."
65 | string = string.rstrip(".")
66 |
67 | # remove inverse spaces
68 | string = string.replace("\\!", "")
69 | string = string.replace("\\ ", "")
70 |
71 | # replace \\ with \
72 | string = string.replace("\\\\", "\\")
73 | string = string.replace("\\\\", "\\")
74 |
75 | # replace tfrac and dfrac with frac
76 | string = string.replace("tfrac", "frac")
77 | string = string.replace("dfrac", "frac")
78 |
79 | # remove \left and \right
80 | string = string.replace("\\left", "")
81 | string = string.replace("\\right", "")
82 |
83 | # Remove unit: miles, dollars if after is not none
84 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
85 | if _string != "" and _string != string:
86 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
87 | string = _string
88 |
89 | # Remove circ (degrees)
90 | string = string.replace("^{\\circ}", "")
91 | string = string.replace("^\\circ", "")
92 |
93 | # remove dollar signs
94 | string = string.replace("\\$", "")
95 | string = string.replace("$", "")
96 |
97 | string = string.replace("\\text", "")
98 | string = string.replace("x\\in", "")
99 |
100 | # remove percentage
101 | string = string.replace("\\%", "")
102 | string = string.replace("\%", "")
103 | string = string.replace("%", "")
104 |
105 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
106 | string = string.replace(" .", " 0.")
107 | string = string.replace("{.", "{0.")
108 |
109 | # cdot
110 | string = string.replace("\\cdot", "")
111 |
112 | # inf
113 | string = string.replace("infinity", "\\infty")
114 | if "\\infty" not in string:
115 | string = string.replace("inf", "\\infty")
116 | string = string.replace("+\\inity", "\\infty")
117 |
118 | # and
119 | string = string.replace("and", "")
120 | string = string.replace("\\mathbf", "")
121 |
122 | # use regex to remove \mbox{...}
123 | string = re.sub(r"\\mbox{.*?}", "", string)
124 |
125 | # quote
126 |     string = string.replace("'", "")
127 |     string = string.replace('"', "")
128 |
129 | # i, j
130 | if "j" in string and "i" not in string:
131 | string = string.replace("j", "i")
132 |
133 | # replace a.000b where b is not number or b is end, with ab, use regex
134 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
135 | string = re.sub(r"(\d+)\.0+$", r"\1", string)
136 |
137 | # if empty, return empty string
138 | if len(string) == 0:
139 | return string
140 | if string[0] == ".":
141 | string = "0" + string
142 |
143 | # to consider: get rid of e.g. "k = " or "q = " at beginning
144 | if len(string.split("=")) == 2:
145 | if len(string.split("=")[0]) <= 2:
146 | string = string.split("=")[1]
147 |
148 | string = _fix_sqrt(string)
149 | string = string.replace(" ", "")
150 |
151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
152 | string = _fix_fracs(string)
153 |
154 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
155 | string = _fix_a_slash_b(string)
156 |
157 | return string
158 |
159 |
160 | def extract_answer(pred_str):
161 | if "boxed" in pred_str:
162 | ans = pred_str.split("boxed")[-1]
163 | if len(ans) == 0:
164 | return ""
165 | elif ans[0] == "{":
166 | stack = 1
167 | a = ""
168 | for c in ans[1:]:
169 | if c == "{":
170 | stack += 1
171 | a += c
172 | elif c == "}":
173 | stack -= 1
174 | if stack == 0:
175 | break
176 | a += c
177 | else:
178 | a += c
179 | else:
180 | a = ans.split("$")[0].strip()
181 | pred = a
182 | elif "he answer is" in pred_str:
183 | pred = pred_str.split("he answer is")[-1].strip()
184 | elif extract_program_output(pred_str) != "":
185 | # fall back to program
186 | pred = extract_program_output(pred_str)
187 | else: # use the last number
188 |         pattern = r"-?\d*\.?\d+"
189 | pred = re.findall(pattern, pred_str.replace(",", ""))
190 | if len(pred) >= 1:
191 | pred = pred[-1]
192 | else:
193 | pred = ""
194 |
195 | # multiple line
196 | pred = pred.split("\n")[0]
197 | if pred != "" and pred[0] == ":":
198 | pred = pred[1:]
199 | if pred != "" and pred[-1] == ".":
200 | pred = pred[:-1]
201 | if pred != "" and pred[-1] == "/":
202 | pred = pred[:-1]
203 | pred = strip_string(pred)
204 | return pred
205 |
206 |
207 | '''
208 | def extract_program(result: str, last_only=True):
209 | """
210 | extract the program after "```python", and before "```"
211 | """
212 | program = ""
213 | start = False
214 | for line in result.split("\n"):
215 | if line.startswith("```python"):
216 | if last_only:
217 | program = "" # only extract the last program
218 | else:
219 | program += "\n# ========\n"
220 | start = True
221 | elif line.startswith("```"):
222 | start = False
223 | elif start:
224 | program += line + "\n"
225 | return program
226 | '''
227 |
228 |
229 | def extract_program(result: str, last_only=False):
230 | """
231 | extract the program after "```python", and before "```"
232 | """
233 | all_program = []
234 | start = False
235 | program = ""
236 | for line in result.split("\n"):
237 | # if line.startswith("```python"):
238 | if "```python" in line:
239 | program = ""
240 | # if last_only:
241 | # program = "" # only extract the last program
242 | # else:
243 | # program += "\n# ========\n"
244 |
245 | start = True
246 | elif line.startswith("```"):
247 | start = False
248 | all_program.append(program)
249 | program = ""
250 | elif start:
251 | program += line + "\n"
252 | return all_program
253 |
254 |
255 | def extract_program_output(pred_str):
256 | """
257 | extract output between the last ```output\n...\n```
258 | """
259 | if "```output" not in pred_str:
260 | return ""
261 | if "```output" in pred_str:
262 | pred_str = pred_str.split("```output")[-1]
263 | if "```" in pred_str:
264 | pred_str = pred_str.split("```")[0]
265 | output = pred_str.strip()
266 | return output
267 |
268 |
269 | def parse_ground_truth(example: Dict[str, Any], data_name):
270 | if "gt_cot" in example:
271 | return example["gt_cot"], strip_string(example["gt"])
272 |
273 | # parse ground truth
274 | if data_name in ["math", "ocw"]:
275 | gt_cot = example["solution"]
276 | gt_ans = extract_answer(gt_cot)
277 | elif data_name == "gsm8k":
278 | gt_cot, gt_ans = example["answer"].split("####")
279 | elif data_name == "gsm-hard":
280 | gt_cot, gt_ans = example["code"], example["target"]
281 | elif data_name == "svamp":
282 | gt_cot, gt_ans = example["Equation"], example["Answer"]
283 | elif data_name == "asdiv":
284 | gt_cot = example["formula"]
285 | gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
286 | elif data_name == "mawps":
287 | gt_cot, gt_ans = None, example["target"]
288 | elif data_name == "tabmwp":
289 | gt_cot = example["solution"]
290 | gt_ans = example["answer"]
291 | if example["ans_type"] in ["integer_number", "decimal_number"]:
292 | if "/" in gt_ans:
293 | gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
294 | elif "," in gt_ans:
295 | gt_ans = float(gt_ans.replace(",", ""))
296 | elif "%" in gt_ans:
297 | gt_ans = float(gt_ans.split("%")[0]) / 100
298 | else:
299 | gt_ans = float(gt_ans)
300 | elif data_name == "bbh":
301 | gt_cot, gt_ans = None, example["target"]
302 | else:
303 | raise NotImplementedError(data_name)
304 | # post process
305 | gt_cot = str(gt_cot).strip()
306 | gt_ans = strip_string(gt_ans)
307 | return gt_cot, gt_ans
308 |
309 |
310 | def parse_question(example, data_name):
311 | question = ""
312 | if data_name == "asdiv":
313 | question = f"{example['body'].strip()} {example['question'].strip()}"
314 | elif data_name == "svamp":
315 | body = example["Body"].strip()
316 | if not body.endswith("."):
317 | body = body + "."
318 | question = f'{body} {example["Question"].strip()}'
319 | elif data_name == "tabmwp":
320 | title_str = f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
321 | question = f"Read the following table {title_str}and answer a question:\n"
322 | question += f'{example["table"]}\n{example["question"]}'
323 | if example["choices"]:
324 | question += f' Please select from the following options: {example["choices"]}'
325 | else:
326 | for key in ["question", "problem", "Question", "input"]:
327 | if key in example:
328 | question = example[key]
329 | break
330 | assert question != ""
331 | return question.strip()
332 |
333 |
334 | def run_execute(executor, result, prompt_type, execute=False):
335 | if not result or result == "error":
336 | return None, None
337 | report = None
338 |
339 | if "program_only" in prompt_type:
340 | prediction = extract_program_output(result)
341 | elif prompt_type in ["pot", "pal"] and execute:
342 | code = extract_program(result)
343 | prediction, report = executor.apply(code)
344 | else:
345 | prediction = extract_answer(result)
346 |
347 | prediction = strip_string(prediction)
348 | return prediction, report
349 |
--------------------------------------------------------------------------------
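
The core of `extract_answer` in `parser.py` is a small brace-matching walk that pulls out the content of the last `\boxed{...}` while keeping nested braces intact. A minimal illustration of that logic with made-up strings (the helper name `extract_boxed` is only for this sketch):

```python
def extract_boxed(pred_str):
    # take the text after the last "boxed" and walk the braces, as parser.extract_answer does
    ans = pred_str.split("boxed")[-1]
    if not ans:
        return ""
    if ans[0] == "{":
        depth, out = 1, ""
        for c in ans[1:]:
            if c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
                if depth == 0:
                    break  # matching brace of the outer \boxed{...}
            out += c
        return out
    return ans.split("$")[0].strip()

print(extract_boxed("so the answer is $\\boxed{\\frac{1}{4}}$"))  # \frac{1}{4}
print(extract_boxed("The final value is \\boxed{42}."))           # 42
```
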
/inference/utils/python_executor.py:
--------------------------------------------------------------------------------
1 | ##############
2 | # Modified from ToRA and Math Instruct project.
3 | ###############
4 | import copy
5 | import datetime
6 | import io
7 | import json
8 | import os
9 | import pickle
10 | import re
11 | import traceback
12 | from concurrent.futures import TimeoutError
13 | from contextlib import redirect_stdout
14 | from functools import partial
15 | from typing import Any, Dict, Optional
16 |
17 | import dateutil.relativedelta
18 | import multiprocess
19 | import regex
20 | from multiprocess import Pool
21 | from pebble import ProcessPool
22 | from timeout_decorator import timeout
23 | from tqdm import tqdm
24 |
25 | NOT_EXECUTED = "<not_executed>"
26 | EXECUTION_ERROR = "Execution error:"
27 | SYNTAX_ERROR = "Syntax error:"
28 | RESULT_NOT_DEFINED_ERROR = "Result is not defined"
29 | TIMEOUT_ERROR = "timeout"
30 | UNDEFINED_ERROR = "Undefined error:"
31 | ERROR_PREFIXES = (EXECUTION_ERROR, SYNTAX_ERROR, RESULT_NOT_DEFINED_ERROR, TIMEOUT_ERROR, UNDEFINED_ERROR)
32 |
33 |
34 | def remove_ansi_escape_codes(text):
35 | """
36 | Remove ANSI escape codes from the given text.
37 |
38 | Args:
39 | text (str): The text from which to remove ANSI escape codes.
40 |
41 | Returns:
42 | str: The text with ANSI escape codes removed.
43 | """
44 | # ANSI escape codes start with the escape character followed by '['
45 | # and end with a lowercase or uppercase letter.
46 | if not text:
47 | return text
48 | ansi_escape = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
49 | return ansi_escape.sub("", text)
50 |
51 |
52 | def extract_after_traceback(text):
53 | """
54 | Extract and return the content of the text after the first occurrence of "Traceback".
55 |
56 | Args:
57 | text (str): The text from which to extract content after "Traceback".
58 |
59 | Returns:
60 | str: The content after "Traceback", or the whole text if "Traceback" is not found.
61 | """
62 | # Split the text at the first occurrence of "Traceback"
63 | parts = text.split("Traceback", 1)
64 |
65 | # Check if the split actually found "Traceback" and split the text
66 | if len(parts) > 1:
67 | # Return the part after "Traceback"
68 | return "Traceback" + parts[1]
69 | else:
70 | # Return the original text if "Traceback" is not found
71 | return text
72 |
73 |
74 | def get_error(txt):
75 | tmp = extract_after_traceback(remove_ansi_escape_codes(txt))
76 | return tmp.split("\n\n")[-1]
77 |
78 |
79 | class GenericRuntime:
80 | GLOBAL_DICT = {}
81 | LOCAL_DICT = None
82 | HEADERS = []
83 |
84 | def __init__(self):
85 | self._global_vars = copy.copy(self.GLOBAL_DICT)
86 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
87 |
88 | for c in self.HEADERS:
89 | self.exec_code(c)
90 |
91 | def exec_code(self, code_piece: str) -> None:
92 |         # if the code contains input() or os.system(), raise an error
93 | if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(r"(\s|^)?os.system\(", code_piece):
94 | raise RuntimeError()
95 | # exec is a built-in python function to execute python code
96 | # _global_vars is a dict containing the global variables that can be used and modified by the code_piece
97 | exec(code_piece, self._global_vars)
98 |
99 | def eval_code(self, expr: str) -> Any:
100 | """
101 | # Evaluate a simple expression
102 | result = evaluator.eval_code("3 + 4")
103 | print(result) # Output: 7
104 |
105 | # Define a variable in the global context and use it in an expression
106 | evaluator._global_vars['x'] = 10
107 | result = evaluator.eval_code("x * 2")
108 | print(result) # Output: 20
109 |
110 | # Modify a variable in the global context through evaluation
111 | evaluator.eval_code("x = x + 5")
112 | print(evaluator._global_vars['x']) # Output: 15
113 | """
114 | return eval(expr, self._global_vars)
115 |
116 | def inject(self, var_dict: Dict[str, Any]) -> None:
117 | for k, v in var_dict.items():
118 | self._global_vars[k] = v
119 |
120 | @property
121 | def answer(self):
122 | return self._global_vars["answer"]
123 |
124 |
125 | class DateRuntime(GenericRuntime):
126 | GLOBAL_DICT = {
127 | "datetime": datetime.datetime,
128 | "timedelta": dateutil.relativedelta.relativedelta,
129 | "relativedelta": dateutil.relativedelta.relativedelta,
130 | }
131 |
132 |
133 | class CustomDict(dict):
134 | def __iter__(self):
135 | return list(super().__iter__()).__iter__()
136 |
137 |
138 | class ColorObjectRuntime(GenericRuntime):
139 | GLOBAL_DICT = {"dict": CustomDict}
140 |
141 |
142 | class PythonExecutor:
143 | def __init__(
144 | self,
145 | runtime: Optional[Any] = None,
146 | get_answer_symbol: Optional[str] = None,
147 | get_answer_expr: Optional[str] = None,
148 | get_answer_from_stdout: bool = False,
149 | timeout_length: int = 20,
150 | ) -> None:
151 | self.runtime = runtime if runtime else GenericRuntime()
152 | self.answer_symbol = get_answer_symbol
153 | self.answer_expr = get_answer_expr
154 | self.get_answer_from_stdout = get_answer_from_stdout
155 | self.pool = Pool(multiprocess.cpu_count())
156 | self.timeout_length = timeout_length
157 |
158 | def process_generation_to_code(self, gens: str):
159 | return [g.split("\n") for g in gens]
160 |
161 | @staticmethod
162 | def execute(
163 | code,
164 | get_answer_from_stdout=None,
165 | runtime=None,
166 | answer_symbol=None,
167 | answer_expr=None,
168 | timeout_length=10,
169 | ):
170 | try:
171 | if get_answer_from_stdout:
172 |                 # in-memory buffer to capture the program's stdout
173 | program_io = io.StringIO()
174 | # redirect_stdout: move all the standard output to the program_io
175 | with redirect_stdout(program_io):
176 | # run the code for at most timeout_length seconds and get all the output to program_io
177 | timeout(timeout_length)(runtime.exec_code)("\n".join(code))
178 |                 # move to the beginning of the captured output
179 | program_io.seek(0)
180 | result = program_io.read()
181 | elif answer_symbol:
182 | timeout(timeout_length)(runtime.exec_code)("\n".join(code))
183 | result = runtime._global_vars[answer_symbol]
184 | elif answer_expr:
185 | timeout(timeout_length)(runtime.exec_code)("\n".join(code))
186 |                 # evaluate answer_expr in the same globals that exec_code above has populated
187 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
188 | else:
189 | timeout(timeout_length)(runtime.exec_code)("\n".join(code[:-1]))
190 | result = timeout(timeout_length)(runtime.eval_code)(code[-1])
191 | report = "Done"
192 | # str(result)
193 | pickle.dumps(result) # serialization check
194 | except:
195 | report = traceback.format_exc().split("\n")[-2]
196 | result = json.dumps({"result": "", "error_message": report})
197 | return result, report
198 |
199 | def apply(self, code):
200 | return self.batch_apply([code])[0]
201 |
202 | def batch_apply(self, batch_code_seq):
203 |         # We format the code snippets to be executed into a Jupyter-style harness and then run them.
204 |         # The observation is captured as the standard output of the newest code.
205 |         # In other words, the model can write multiple rounds of code; all of them are executed, but only the output of the last round is captured and returned.
206 | all_processed_codes = []
207 | for code_seq in batch_code_seq:
208 |
209 | z = """
210 | import traceback
211 | import json
212 | import os
213 | import warnings
214 | warnings.filterwarnings('ignore')
215 | os.environ['OPENBLAS_NUM_THREADS'] = '16'
216 |
217 | from IPython.core.interactiveshell import InteractiveShell
218 | from IPython.utils import io
219 | code_snippets = []
220 | """
221 | for code_snippet in code_seq:
222 | # z += f'\ncode_snippets.append("""{code_snippet}""")\n'
223 | escaped_code_snippet = code_snippet.replace('"""', '\\"\\"\\"')
224 | z += f'\ncode_snippets.append("""{escaped_code_snippet}""")\n'
225 | z += f"""
226 | try:
227 | shell = InteractiveShell()
228 | for tmp_code in code_snippets:
229 | with io.capture_output() as captured:
230 | exec_result = shell.run_cell(tmp_code)
231 | output = f"{{captured.stdout}}{{captured.stderr}}".strip().replace("Out[1]: ", "")
232 | error_message = ''
233 | if exec_result.error_in_exec is not None:
234 | error_message = f"{EXECUTION_ERROR} {{str(exec_result.error_in_exec)}}"
235 | elif exec_result.error_before_exec is not None:
236 | # full traceback will be part of output
237 | error_message = f"{SYNTAX_ERROR} {{str(exec_result.error_before_exec)}}"
238 | elif output == "":
239 | error_message = "{RESULT_NOT_DEFINED_ERROR}"
240 | to_return = {{"result": output, "error_message": error_message}}
241 | except Exception:
242 | # removing useless prefix from traceback
243 | to_return = {{
244 | "result": None,
245 | "error_message": "{UNDEFINED_ERROR}" + "\\n".join(traceback.format_exc().split("\\n")[3:]),
246 | }}
247 | print(json.dumps(to_return))
248 | """
249 | all_processed_codes.append(z)
250 | my_results = self.old_batch_apply(all_processed_codes)
251 | # Extract the old result
252 | batch_results = []
253 |
254 | for prediction in my_results:
255 | if prediction[0]:
256 | if "Timeout Error" in prediction[0]:
257 | batch_results.append(("Timeout Error", ""))
258 | continue
259 | try:
260 | dict_data = json.loads(prediction[0])
261 | except:
262 | match = re.search(r'"error_message":\s*"([^"]*)"', prediction[0])
263 | if match:
264 | batch_results.append(("", match.group(1)))
265 | else:
266 | batch_results.append(
267 | (
268 | "There exists some error in your code. Please rewrite the code and solve the problem.",
269 | "There exists some error in your code. Please rewrite the code and solve the problem.",
270 | )
271 | )
272 | continue
273 | else:
274 | batch_results.append(
275 | (
276 | "There exists some error in your code. Please rewrite the code and solve the problem.",
277 | "There exists some error in your code. Please rewrite the code and solve the problem.",
278 | )
279 | )
280 | continue
281 |
282 | try:
283 | dict_data["error_message"]
284 | except:
285 | batch_results.append(
286 | (
287 | "There exists some error in your code. Please rewrite the code and solve the problem.",
288 | "There exists some error in your code. Please rewrite the code and solve the problem.",
289 | )
290 | )
291 | continue
292 | if dict_data["error_message"]:
293 | if dict_data["result"]:
294 | batch_results.append((get_error(dict_data["result"]), ""))
295 | else:
296 | batch_results.append((dict_data["error_message"], ""))
297 | else:
298 | batch_results.append((dict_data["result"], ""))
299 | return batch_results
300 |
301 | @staticmethod
302 | def truncate(s, max_length=100):
303 | half = max_length // 2
304 | if len(s) > max_length:
305 | s = s[:half] + "..." + s[-half:]
306 | return s
307 |
308 | def old_batch_apply(self, batch_code):
309 |
310 | all_code_snippets = self.process_generation_to_code(batch_code)
311 |
312 | timeout_cnt = 0
313 | all_exec_results = []
314 | with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool:
315 | executor = partial(
316 | self.execute,
317 | get_answer_from_stdout=self.get_answer_from_stdout,
318 | runtime=self.runtime,
319 | answer_symbol=self.answer_symbol,
320 | answer_expr=self.answer_expr,
321 |                 timeout_length=self.timeout_length,  # note: this per-call timeout is unreliable; the pool-level timeout below enforces the limit
322 | )
323 | future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
324 | iterator = future.result()
325 |
326 | if len(all_code_snippets) > 100:
327 | progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
328 | else:
329 | progress_bar = None
330 |
331 | while True:
332 | try:
333 | result = next(iterator)
334 | all_exec_results.append(result)
335 | except StopIteration:
336 | break
337 | except TimeoutError as error:
338 | all_exec_results.append(("Timeout Error", "Timeout Error"))
339 | timeout_cnt += 1
340 | except Exception as error:
341 | # print(error)
342 | exit()
343 | if progress_bar is not None:
344 | progress_bar.update(1)
345 |
346 | if progress_bar is not None:
347 | progress_bar.close()
348 |
349 | batch_results = []
350 | for code, (res, report) in zip(all_code_snippets, all_exec_results):
351 | # post processing
352 | res, report = str(res).strip(), str(report).strip()
353 | # res, report = self.truncate(res), self.truncate(report)
354 | batch_results.append((res.strip().replace("Out[1]: ", ""), report))
355 | return batch_results
356 |
--------------------------------------------------------------------------------
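
`PythonExecutor.batch_apply` wraps every code snippet extracted from a turn into an IPython harness, executes them in one `InteractiveShell`, and keeps only the captured output of the last cell. A minimal, self-contained sketch of that capture pattern; the snippets below are made up:

```python
from IPython.core.interactiveshell import InteractiveShell
from IPython.utils import io

# made-up snippets standing in for the code blocks extracted from one trajectory
code_snippets = [
    "x = 2 + 3",
    "print(x ** 2)",
]

shell = InteractiveShell()
for snippet in code_snippets:
    # run every snippet in the same shell so earlier definitions stay available
    with io.capture_output() as captured:
        exec_result = shell.run_cell(snippet)

# only the last cell's capture is kept, mirroring batch_apply
output = f"{captured.stdout}{captured.stderr}".strip()
error_message = "" if exec_result.error_in_exec is None else str(exec_result.error_in_exec)
print(output)         # 25
print(error_message)  # empty on success
```

Running the earlier snippets only to rebuild the interpreter state, and returning just the newest cell's stdout, is exactly what lets the model iterate over several code rounds while the environment reports a single observation per turn.
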
/inference/utils/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | from pathlib import Path
5 | from typing import Any, Iterable, Union
6 |
7 | import numpy as np
8 |
9 |
10 | def set_seed(seed: int = 42) -> None:
11 | np.random.seed(seed)
12 | random.seed(seed)
13 | os.environ["PYTHONHASHSEED"] = str(seed)
14 | print(f"Random seed set as {seed}")
15 |
16 |
17 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
18 | with open(file, "r", encoding="utf-8") as f:
19 | for line in f:
20 | try:
21 | yield json.loads(line)
22 | except:
23 | print("Error in loading:", line)
24 | exit()
25 |
26 |
27 | def save_jsonl(samples, save_path):
28 | # ensure path
29 | folder = os.path.dirname(save_path)
30 | os.makedirs(folder, exist_ok=True)
31 |
32 | with open(save_path, "w", encoding="utf-8") as f:
33 | for sample in samples:
34 | f.write(json.dumps(sample) + "\n")
35 | print("Saved to", save_path)
36 |
37 |
38 | def lower_keys(example):
39 | new_example = {}
40 | for key, value in example.items():
41 | if key != key.lower():
42 | new_key = key.lower()
43 | new_example[new_key] = value
44 | else:
45 | new_example[key] = value
46 | return new_example
47 |
48 |
49 | def construct_prompt(args, example):
50 | if args.prompt_type == "tora":
51 | if "gemma" in args.model_name_or_path:
52 |             full_prompt = f"<start_of_turn>user\n{example['question']}<end_of_turn>\n<start_of_turn>model\n"
53 | elif "mistral" in args.model_name_or_path:
54 | full_prompt = f" [INST] {example['question']} [/INST]"
55 | elif "deepseek" in args.model_name_or_path:
56 | full_prompt = f"User: {example['question']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{{}}.\n\nAssistant: "
57 | elif "llama3" in args.model_name_or_path:
58 |             full_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{example['question']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{{}}.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
59 | else:
60 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path)
61 | elif args.prompt_type == "cot":
62 | if "gemma" in args.model_name_or_path:
63 |             full_prompt = f"<start_of_turn>user\n{example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<end_of_turn>\n<start_of_turn>model\n"
64 | elif "mistral" in args.model_name_or_path:
65 | full_prompt = f" [INST] {example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}. [/INST]"
66 | elif "deepseek" in args.model_name_or_path:
67 | full_prompt = f"User: {example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant: "
68 | elif "llama3" in args.model_name_or_path:
69 | full_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{example['question']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
70 | else:
71 | raise NotImplementedError(args.prompt_type + "and " + args.model_name_or_path)
72 |
73 | return full_prompt
74 |
75 |
76 | key_map = {
77 | "gt": "Ground Truth",
78 | "pred": "Prediction",
79 | "gt_cot": "Reference CoT",
80 | "score": "Score",
81 | }
82 |
83 |
84 | def show_sample(sample, print_all_preds=False):
85 | print("==" * 20)
86 | for key in ["idx", "type", "level", "dataset"]:
87 | if key in sample:
88 | # capitalize
89 | print("{}: {}".format(key[0].upper() + key[1:], sample[key]))
90 | print("Question:", repr(sample["question"]))
91 | if "code" in sample:
92 | if print_all_preds:
93 | for code in sample["code"]:
94 | print("-" * 20)
95 | print("code:", code)
96 | print("Execution:", sample["report"])
97 | else:
98 | print("Solution:\n", sample["code"][0])
99 | print("Execution:", sample["report"][0])
100 | if "pred" in sample:
101 | print("Prediction:", repr(sample["pred"][0]))
102 | for key in ["gt", "score", "unit", "gt_cot"]:
103 | if key in sample:
104 | _key = key_map.get(key, key)
105 | print("{}: {}".format(_key, repr(sample[key])))
106 | print()
107 |
--------------------------------------------------------------------------------
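
`construct_prompt` hard-codes the chat format per model family rather than going through a tokenizer chat template. A quick illustration of the Gemma "tora" prompt it builds; the question is made up, and the turn markers are assumed from the Gemma chat format (matching the restored lines above):

```python
# Illustration of the Gemma "tora" prompt built by construct_prompt in utils.py.
question = "What is the remainder when 2^10 is divided by 7?"
full_prompt = f"<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\n"
print(repr(full_prompt))
# '<start_of_turn>user\nWhat is the remainder when 2^10 is divided by 7?<end_of_turn>\n<start_of_turn>model\n'
```
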
/useful_codes/annotate_data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from collections import defaultdict
3 | import multiprocessing
4 | from math import isclose
5 | from typing import Union
6 | 
7 | from datasets import Dataset, concatenate_datasets
8 | 
9 | import json
10 | import re
11 | from sympy import N, simplify
12 | from sympy.parsing.latex import parse_latex
13 | from sympy.parsing.sympy_parser import parse_expr
14 |
15 | ds1 = load_dataset("1231czx/7b_sft_510k_3epoch_gen_data_iter1", split="train").shuffle(seed=42)
16 | ds2 = load_dataset("1231czx/7b_sft_510k_1epoch_gen_data_iter1", split="train").shuffle(seed=42)
17 |
18 | ds = concatenate_datasets([ds1, ds2]) # .select(range(100000))
19 | # ds = ds.select(range(5000))
20 | N_pair = 1
21 | data_comp = "1231czx/7B_iter1_dpo_N1_random_pair"
22 | data_sft = "1231czx/7B_iter1_sft_N1"
23 |
24 |
25 |
26 | def symbolic_equal(a, b):
27 | def _parse(s):
28 | for f in [parse_latex, parse_expr]:
29 | try:
30 | return f(s)
31 | except:
32 | pass
33 | return s
34 |
35 | a = _parse(a)
36 | b = _parse(b)
37 |
38 | try:
39 | if simplify(a - b) == 0:
40 | return True
41 | except:
42 | pass
43 |
44 | try:
45 | if isclose(N(a), N(b), rel_tol=1e-3):
46 | return True
47 | except:
48 | pass
49 | return False
50 |
51 |
52 | def symbolic_equal_process(a, b, output_queue):
53 | result = symbolic_equal(a, b)
54 | output_queue.put(result)
55 |
56 |
57 | def call_with_timeout(func, *args, timeout=10, **kwargs):
58 | output_queue = multiprocessing.Queue()
59 | process_args = args + (output_queue,)
60 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
61 | process.start()
62 | process.join(timeout)
63 |
64 | if process.is_alive():
65 | process.terminate()
66 | process.join()
67 | # return "time out"
68 | return False
69 | return output_queue.get()
70 |
71 |
72 | def math_equal(
73 | prediction: Union[bool, float, str],
74 | reference: Union[float, str],
75 | include_percentage: bool = True,
76 | is_close: bool = True,
77 | timeout: bool = True,
78 | ) -> bool:
79 |     """Exact match iff (1) numerically equal: both parse as floats and match, or (2) symbolically equal via sympy."""
80 |     def is_digit(s):  # helper missing from the original script; minimal reconstruction so the numeric branch runs
81 |         try: float(str(s).replace(",", "")); return True
82 |         except ValueError: return False
83 | 
84 |     try:  # 1. numerical equal
85 |         if is_digit(prediction) and is_digit(reference):
86 | prediction = float(str(prediction).replace(",", ""))
87 | reference = float(str(reference).replace(",", ""))
88 | # number questions
89 | if include_percentage:
90 | gt_result = [reference / 100, reference, reference * 100]
91 | else:
92 | gt_result = [reference]
93 | for item in gt_result:
94 | try:
95 | if is_close:
96 | if isclose(item, prediction, rel_tol=1e-4):
97 | return True
98 | else:
99 | if item == prediction:
100 | return True
101 | except Exception:
102 | continue
103 | return False
104 | except:
105 | pass
106 | if not prediction and prediction not in [0, False]:
107 | return False
108 |
109 | # 2. symbolic equal
110 | reference = str(reference).strip()
111 | prediction = str(prediction).strip()
112 |
113 | ## deal with [], (), {}
114 | pred_str, ref_str = prediction, reference
115 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
116 | prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
117 | ):
118 | pred_str = pred_str.strip("[]()")
119 | ref_str = ref_str.strip("[]()")
120 | for s in ["{", "}", "(", ")"]:
121 | ref_str = ref_str.replace(s, "")
122 | pred_str = pred_str.replace(s, "")
123 | if pred_str == ref_str:
124 | return True
125 |
126 | ## [a, b] vs. [c, d], return a==c and b==d
127 | if (
128 | (prediction.startswith("[") and prediction.endswith("]"))
129 | and (reference.startswith("[") and reference.endswith("]"))
130 | or (prediction.startswith("(") and prediction.endswith(")"))
131 | and (reference.startswith("(") and reference.endswith(")"))
132 | ):
133 | pred_parts = prediction[1:-1].split(",")
134 | ref_parts = reference[1:-1].split(",")
135 | if len(pred_parts) == len(ref_parts):
136 | if all(
137 | [math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]
138 | ):
139 | return True
140 |
141 | # symbolic equal with sympy
142 | if timeout:
143 | if call_with_timeout(symbolic_equal_process, prediction, reference):
144 | return True
145 | # elif tmp == 'time out':
146 | # return "time out"
147 | else:
148 | if symbolic_equal(prediction, reference):
149 | return True
150 |
151 | return False
152 |
153 |
154 | def _fix_a_slash_b(string):
155 | if len(string.split("/")) != 2:
156 | return string
157 | a = string.split("/")[0]
158 | b = string.split("/")[1]
159 | try:
160 | if "sqrt" not in a:
161 | a = int(a)
162 | if "sqrt" not in b:
163 | b = int(b)
164 | assert string == "{}/{}".format(a, b)
165 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
166 | return new_string
167 | except:
168 | return string
169 |
170 |
171 | def _fix_fracs(string):
172 | substrs = string.split("\\frac")
173 | new_str = substrs[0]
174 | if len(substrs) > 1:
175 | substrs = substrs[1:]
176 | for substr in substrs:
177 | new_str += "\\frac"
178 | if len(substr) > 0 and substr[0] == "{":
179 | new_str += substr
180 | else:
181 | try:
182 | assert len(substr) >= 2
183 | except:
184 | return string
185 | a = substr[0]
186 | b = substr[1]
187 | if b != "{":
188 | if len(substr) > 2:
189 | post_substr = substr[2:]
190 | new_str += "{" + a + "}{" + b + "}" + post_substr
191 | else:
192 | new_str += "{" + a + "}{" + b + "}"
193 | else:
194 | if len(substr) > 2:
195 | post_substr = substr[2:]
196 | new_str += "{" + a + "}" + b + post_substr
197 | else:
198 | new_str += "{" + a + "}" + b
199 | string = new_str
200 | return string
201 |
202 |
203 | def _fix_sqrt(string):
204 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
205 | return _string
206 |
207 |
208 | def strip_string(string):
209 | string = str(string).strip()
210 | # linebreaks
211 | string = string.replace("\n", "")
212 |
213 | # right "."
214 | string = string.rstrip(".")
215 |
216 | # remove inverse spaces
217 | string = string.replace("\\!", "")
218 | string = string.replace("\\ ", "")
219 |
220 | # replace \\ with \
221 | string = string.replace("\\\\", "\\")
222 | string = string.replace("\\\\", "\\")
223 |
224 | # replace tfrac and dfrac with frac
225 | string = string.replace("tfrac", "frac")
226 | string = string.replace("dfrac", "frac")
227 |
228 | # remove \left and \right
229 | string = string.replace("\\left", "")
230 | string = string.replace("\\right", "")
231 |
232 | # Remove unit: miles, dollars if after is not none
233 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
234 | if _string != "" and _string != string:
235 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
236 | string = _string
237 |
238 | # Remove circ (degrees)
239 | string = string.replace("^{\\circ}", "")
240 | string = string.replace("^\\circ", "")
241 |
242 | # remove dollar signs
243 | string = string.replace("\\$", "")
244 | string = string.replace("$", "")
245 |
246 | string = string.replace("\\text", "")
247 | string = string.replace("x\\in", "")
248 |
249 | # remove percentage
250 | string = string.replace("\\%", "")
251 | string = string.replace("\%", "")
252 | string = string.replace("%", "")
253 |
254 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
255 | string = string.replace(" .", " 0.")
256 | string = string.replace("{.", "{0.")
257 |
258 | # cdot
259 | string = string.replace("\\cdot", "")
260 |
261 | # inf
262 | string = string.replace("infinity", "\\infty")
263 | if "\\infty" not in string:
264 | string = string.replace("inf", "\\infty")
265 | string = string.replace("+\\inity", "\\infty")
266 |
267 | # and
268 | string = string.replace("and", "")
269 | string = string.replace("\\mathbf", "")
270 |
271 | # use regex to remove \mbox{...}
272 | string = re.sub(r"\\mbox{.*?}", "", string)
273 |
274 | # quote
275 | string.replace("'", "")
276 | string.replace('"', "")
277 |
278 | # i, j
279 | if "j" in string and "i" not in string:
280 | string = string.replace("j", "i")
281 |
282 | # replace a.000b where b is not number or b is end, with ab, use regex
283 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
284 | string = re.sub(r"(\d+)\.0+$", r"\1", string)
285 |
286 | # if empty, return empty string
287 | if len(string) == 0:
288 | return string
289 | if string[0] == ".":
290 | string = "0" + string
291 |
292 | # to consider: get rid of e.g. "k = " or "q = " at beginning
293 | if len(string.split("=")) == 2:
294 | if len(string.split("=")[0]) <= 2:
295 | string = string.split("=")[1]
296 |
297 | string = _fix_sqrt(string)
298 | string = string.replace(" ", "")
299 |
300 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
301 | string = _fix_fracs(string)
302 |
303 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
304 | string = _fix_a_slash_b(string)
305 |
306 | return string
307 |
308 |
309 | import re
310 |
311 |
312 | def check1(a, b):
313 |
314 | try:
315 | a = parse_latex(a)
316 | b = parse_latex(b)
317 | if abs(float(a) - float(b)) < 0.01:
318 | return True
319 | except:
320 | pass
321 |
322 | return False
323 |
324 |
325 | print(check1("0.25", "\\frac{1}{4}"))
326 |
327 |
328 | def parse_ground_truth(example):
329 | cnt = 0
330 | if example["type"] in [
331 | "gpt-3.5-turbo",
332 | "MATH_FOBAR",
333 | "GSM_Rephrased",
334 | "MATH_SV",
335 | "MATH_Rephrased",
336 | "GSM_SV",
337 | "GSM_FOBAR",
338 | ]:
339 | pattern = r"The answer is: (.*?)$"
340 | match = re.search(pattern, example["solution"])
341 |
342 | if match:
343 | gt_ans = match.group(1)
344 | # print("The prediction is:", prediction)
345 | else:
346 | print("No prediction found for gpt-3.5-turbo.")
347 |
348 | check_ans = extract_answer(example["solution"])
349 | if check_ans != gt_ans:
350 | # if math_equal(gt_ans, check_ans):
351 | # pass
352 | if check1(gt_ans, check_ans):
353 | pass
354 | else:
355 | # print(example['type'], gt_ans, check_ans, "\n")
356 | cnt += 1
357 | gt_ans = "delete"
358 | elif example["type"] == "gsm8k":
359 | gt_ans = example["solution"].split("####")[-1]
360 |
361 | elif example["type"] == "math":
362 | gt_ans = extract_answer(example["solution"])
363 |
364 | else:
365 | print(example["type"])
366 | raise NotImplementedError()
367 | gt_ans = strip_string(gt_ans)
368 | # if gt_ans.startswith('\frac'):
369 | # gt_ans = '\' + gt_ans
370 | return {"gt": gt_ans}
371 |
372 |
373 | def extract_answer(pred_str):
374 | if "boxed" in pred_str:
375 | ans = pred_str.split("boxed")[-1]
376 | if len(ans) == 0:
377 | return ""
378 | elif ans[0] == "{":
379 | stack = 1
380 | a = ""
381 | for c in ans[1:]:
382 | if c == "{":
383 | stack += 1
384 | a += c
385 | elif c == "}":
386 | stack -= 1
387 | if stack == 0:
388 | break
389 | a += c
390 | else:
391 | a += c
392 | else:
393 | a = ans.split("$")[0].strip()
394 | pred = a
395 | elif "he answer is" in pred_str:
396 | pred = pred_str.split("he answer is")[-1].strip()
397 | elif extract_program_output(pred_str) != "":
398 | # fall back to program
399 | pred = extract_program_output(pred_str)
400 | else: # use the last number
401 |         pattern = r"-?\d*\.?\d+"
402 | pred = re.findall(pattern, pred_str.replace(",", ""))
403 | if len(pred) >= 1:
404 | pred = pred[-1]
405 | else:
406 | pred = ""
407 |
408 | # multiple line
409 | pred = pred.split("\n")[0]
410 | if pred != "" and pred[0] == ":":
411 | pred = pred[1:]
412 | if pred != "" and pred[-1] == ".":
413 | pred = pred[:-1]
414 | if pred != "" and pred[-1] == "/":
415 | pred = pred[:-1]
416 | pred = strip_string(pred)
417 | return pred
418 |
419 |
420 | from sympy.parsing.latex import parse_latex
421 |
422 | # Example LaTeX string
423 |
424 | z = 0
425 | # ds = load_dataset('1231czx/7b_sft_510k_3epoch_gen_data_iter1', split='train')
426 |
427 | # for sample in ds:
428 | # a, b = parse_ground_truth(sample)
429 | # z += b
430 | # print(z)
431 |
432 | ds_new = ds.map(parse_ground_truth, num_proc=32)
433 | ds_new = ds_new.filter(lambda example: example["gt"] != "delete")
434 |
435 | print("#######################################\n", "After deleting the prompts whose answers cannot be verified")
436 | print(ds_new[0])
437 | print(ds, ds_new)
438 |
439 |
440 | ###########################################
441 |
442 | import re
443 |
444 |
445 | def parse_conversation(example):
446 | # Split the data into turns based on the start_of_turn and end_of_turn markers
447 | data = example["my_solu"][0]
448 |     turns = re.split(r"<start_of_turn>|<end_of_turn>", data)
449 |
450 | # Clean and filter out empty entries
451 |     turns = [turn.strip() for turn in turns if turn.strip() and not turn.startswith("<eos>")]
452 |
453 | # Create a list to hold the parsed conversation in the desired format
454 | conversation = []
455 |
456 | # Process each turn, assigning the correct role and content
457 | for turn in turns:
458 | if turn.startswith("user\n"):
459 | # Extract content after the role identifier
460 | content = turn[5:].strip()
461 | conversation.append({"role": "user", "content": content})
462 | elif turn.startswith("model\n"):
463 | content = turn[6:].strip()
464 | conversation.append({"role": "assistant", "content": content})
465 |
466 | return {"messages": conversation}
467 |
468 |
469 | ds_new = ds_new.map(parse_conversation, num_proc=32)
470 |
471 |
472 |
473 | def is_correct(example):
474 | is_correct_label = False
475 | a = example["pred"][0]
476 | b = example["gt"]
477 | if a == b:
478 | is_correct_label = True
479 | return {"is_correct": is_correct_label}
480 | try:
481 | if abs(float(a) - float(b)) < 0.0001:
482 | is_correct_label = True
483 | return {"is_correct": is_correct_label}
484 | except:
485 | pass
486 |
487 | try:
488 |
489 | if math_equal(a, b):
490 | # if z == True:
491 | is_correct_label = True
492 |
493 | return {"is_correct": is_correct_label}
494 |
495 | # elif z == 'time out':
496 | # print("time out")
497 | # return {"is_correct": 'time out'}
498 |
499 | # signal.alarm(0)
500 | # except TimeoutException:
501 | # pass
502 | # if math_equal(a, b):
503 | # is_correct_label = True
504 | except:
505 | pass
506 |
507 | """
508 | try:
509 | if check1(a, b):
510 | is_correct_label = True
511 | return {"is_correct": is_correct_label}
512 | except:
513 | pass
514 | """
515 | # try:
516 | # if math_equal(a, b):
517 | # is_correct_label = True
518 | # except:
519 | # pass
520 |
521 | # print(example['type'])
522 | # raise NotImplementedError()
523 | # gt_ans = strip_string(gt_ans)
524 | return {"is_correct": is_correct_label}
525 |
526 |
527 | def filter_example1(example):
528 | old_messages = example["messages"]
529 |
530 | if len(old_messages) < 4:
531 | return False
532 |
533 | if len(old_messages) % 2 != 0:
534 | return False
535 |
536 | all_mes_len = len(old_messages)
537 | if example["is_correct"] and "error" in old_messages[-2]["content"].lower():
538 | return False
539 |
540 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower():
541 | return False
542 |
543 | k = 0
544 |
545 | for mes in old_messages:
546 | if k % 2 != 0 and k < all_mes_len - 1:
547 | if "python" not in mes["content"]:
548 | return False
549 | k += 1
550 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower():
551 | return False
552 | if mes["content"] == "```output\nExecution error: \n```":
553 | return False
554 | if "```output\n[]" in mes["content"]:
555 | # print(mes['content'])
556 | return False
557 |
558 | return True
559 |
560 |
561 | from transformers import AutoTokenizer, HfArgumentParser
562 |
563 | # ds_new = ds_new.filter(filter_example1, num_proc=32)
564 |
565 | print("############### Answer check")
566 | tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it")
567 |
568 |
569 |
570 | def filter_too_long_pred(example):
571 | try:
572 | if len(example["pred"][0]) > 20:
573 | return False
574 | except:
575 | return False
576 | old_messages = example["messages"]
577 | if len(old_messages) < 4:
578 | return False
579 |
580 | if len(old_messages) % 2 != 0:
581 | return False
582 |
583 | all_mes_len = len(old_messages)
584 |
585 | if "boxed" in old_messages[-1]["content"].lower() and "error" in old_messages[-2]["content"].lower():
586 | return False
587 |
588 | k = 0
589 |
590 | for mes in old_messages:
591 | if k % 2 != 0 and k < all_mes_len - 1:
592 | if "python" not in mes["content"]:
593 | return False
594 | k += 1
595 | if "ipython" in mes["content"].lower() and "error" in mes["content"].lower():
596 | return False
597 | if mes["content"] == "```output\nExecution error: \n```":
598 | return False
599 | if "```output\n[]" in mes["content"]:
600 | # print(mes['content'])
601 | return False
602 |
603 | z = len(tokenizer.apply_chat_template(old_messages, tokenize=True))
604 | if z > 2048:
605 | return False
606 |
607 | # b = len(tokenizer.apply_chat_template(ds_test[i]['rejected'], tokenize=True))
608 |
609 | return True
610 |
611 |
612 | ds_new = ds_new.filter(filter_too_long_pred, num_proc=32)
613 |
614 | ds_new = ds_new.map(is_correct, num_proc=32)
615 |
616 |
617 | ds_new = ds_new.filter(filter_example1, num_proc=32)
618 |
619 | ds_win = ds_new.filter(lambda example: example["is_correct"] == True)
620 | ds_lose = ds_new.filter(lambda example: example["is_correct"] == False)
621 |
622 | print(len(ds_win) + len(ds_lose), len(ds_new))
623 |
624 |
625 | # print("I have win ", len(ds_win), " and lose ", len(ds_lose))
626 | def filter_win(example):
627 | old_messages = example["messages"]
628 |
629 | if "error" in old_messages[-2]["content"].lower():
630 | return False
631 | if "none" in old_messages[-2]["content"].lower():
632 | return False
633 |
634 | return True
635 |
636 |
637 | print("I have win ", len(ds_win), " and lose ", len(ds_lose))
638 | ds_win = ds_win.filter(filter_win, num_proc=32)
639 | print("I have win ", len(ds_win), " and lose ", len(ds_lose))
640 |
641 |
642 | win_ret = defaultdict(list)
643 | lose_ret = defaultdict(list)
644 |
645 |
646 | for sample in ds_win:
647 | idx = sample["idx"]
648 | win_ret[idx].append(sample)
649 |
650 |
651 | for sample in ds_lose:
652 | idx = sample["idx"]
653 | lose_ret[idx].append(sample)
654 |
655 |
656 | new_win_ret = defaultdict(list)
657 |
658 | cnt_win = 0
659 |
660 | for key, value in win_ret.items():
661 | if len(win_ret[key]) == 0:
662 | continue
663 | j = key
664 | all_samples = win_ret[j]
665 | all_texts = []
666 | new_samples = []
667 | for ins in all_samples:
668 | if ins["messages"][1]["content"] in all_texts:
669 | continue
670 | all_texts.append(ins["messages"][1]["content"])
671 | new_samples.append(ins)
672 | cnt_win += 1
673 |
674 | new_win_ret[j].extend(new_samples)
675 |
676 |
677 | cnt_lose = 0
678 | new_lose_ret = defaultdict(list)
679 |
680 |
681 | for key, value in lose_ret.items():
682 | if len(lose_ret[key]) == 0:
683 | continue
684 | j = key
685 | all_samples = lose_ret[j]
686 | all_texts = []
687 | new_samples = []
688 | for ins in all_samples:
689 | if ins["messages"][1]["content"] in all_texts:
690 | continue
691 | all_texts.append(ins["messages"][1]["content"])
692 | new_samples.append(ins)
693 | cnt_lose += 1
694 |
695 | new_lose_ret[j].extend(new_samples)
696 |
697 | print("Before get final pairs, I have win and lose", cnt_win, cnt_lose)
698 | import random
699 |
700 | import numpy as np
701 |
702 | all_comp = []
703 | all_sft = []
704 | all_keys = list(new_lose_ret.keys()) + list(new_win_ret.keys())
705 | all_keys = list(set(all_keys))
706 | import itertools
707 |
708 | """
709 | for ins in new_win_ret[0]:
710 | print(is_correct(ins), ins['pred'][0])
711 |
712 | for ins in new_lose_ret[0]:
713 | print(is_correct(ins), ins['pred'][0], ins)
714 | print(check1(ins['gt'], ins['pred'][0]))
715 | print(check1('\\frac{1}{4}', '0.25'))
716 | """
717 | cnt_comp = 0
718 | # N = 1
719 | for j in all_keys:
720 | if len(new_lose_ret[j]) > 0 and len(new_win_ret[j]) > 0:
721 | cnt_comp += N_pair
722 | else:
723 | continue
724 | all_pos = new_win_ret[j]
725 | all_neg = new_lose_ret[j]
726 | random.shuffle(all_pos)
727 | random.shuffle(all_neg)
728 | if len(all_pos) > N_pair and len(all_neg) > N_pair:
729 | for k in range(N_pair):
730 | all_comp.append(
731 | {
732 | "gt": all_pos[k]["gt"],
733 | "rej": all_neg[k]["pred"],
734 | "chosen": all_pos[k]["messages"],
735 | "rejected": all_neg[k]["messages"],
736 | }
737 | )
738 | all_sft.append({"messages": all_pos[k]["messages"]})
739 | continue
740 |
741 | combinations = list(itertools.product(list(range(len(all_pos))), list(range(len(all_neg)))))
742 |
743 | random.shuffle(combinations)
744 | for k in range(np.min([len(combinations), N_pair])):
745 | all_comp.append(
746 | {
747 | "gt": all_pos[combinations[k][0]]["gt"],
748 | "chosen": all_pos[combinations[k][0]]["messages"],
749 | "rejected": all_neg[combinations[k][1]]["messages"],
750 | "rej": all_neg[combinations[k][1]]["pred"],
751 | }
752 | )
753 | all_sft.append({"messages": all_pos[combinations[k][0]]["messages"]})
754 |
755 |
756 | # print(all_comp[0])
757 | output_eval_dataset = {}
758 | output_eval_dataset["type"] = "text_only"
759 | output_eval_dataset["instances"] = all_comp
760 | print("I collect ", len(all_comp), "samples", len(all_sft))
761 |
762 | import json
763 |
764 | with open("tmp_comp.json", "w", encoding="utf8") as f:
765 | json.dump(output_eval_dataset, f, ensure_ascii=False)
766 |
767 | ds_comp = load_dataset("json", data_files="tmp_comp.json", split="train", field="instances")
768 | ds_comp.push_to_hub(data_comp)
769 |
770 | output_eval_dataset = {}
771 | output_eval_dataset["type"] = "text_only"
772 | output_eval_dataset["instances"] = all_sft
773 | print("I collect ", len(all_comp), "samples", len(all_sft))
774 |
775 | import json
776 |
777 | with open("tmp_sft.json", "w", encoding="utf8") as f:
778 | json.dump(output_eval_dataset, f, ensure_ascii=False)
779 |
780 | ds_sft = load_dataset("json", data_files="tmp_sft.json", split="train", field="instances")
781 | ds_sft.push_to_hub(data_sft)
782 |
783 | # ds_new = ds_new.remove_columns("idx")
784 |
785 | # def add_index(example, idx):
786 | # # Add the current index to the example under a new field 'index'
787 | # example['idx'] = idx
788 | # return example
789 |
790 | # ds_new = ds_new.map(add_index, with_indices=True)
791 | # ds_new.push_to_hub("1231czx/prompts_80K_with_original_MATH_GSM8K_iter1")
792 |
793 | ###################################
794 |
795 | N_pair = 3
796 | data_comp = "1231czx/7B_iter1_dpo_N3_random_pair"
797 | data_sft = "1231czx/7B_iter1_sft_N3"
798 |
799 | all_comp = []
800 | all_sft = []
801 | all_keys = list(new_lose_ret.keys()) + list(new_win_ret.keys())
802 | all_keys = list(set(all_keys))
803 | import itertools
804 |
805 | """
806 | for ins in new_win_ret[0]:
807 | print(is_correct(ins), ins['pred'][0])
808 |
809 | for ins in new_lose_ret[0]:
810 | print(is_correct(ins), ins['pred'][0], ins)
811 | print(check1(ins['gt'], ins['pred'][0]))
812 | print(check1('\\frac{1}{4}', '0.25'))
813 | """
814 | cnt_comp = 0
815 | # N = 1
816 | for j in all_keys:
817 | if len(new_lose_ret[j]) > 0 and len(new_win_ret[j]) > 0:
818 | cnt_comp += N_pair
819 | else:
820 | continue
821 | all_pos = new_win_ret[j]
822 | all_neg = new_lose_ret[j]
823 | random.shuffle(all_pos)
824 | random.shuffle(all_neg)
825 | if len(all_pos) > N_pair and len(all_neg) > N_pair:
826 | for k in range(N_pair):
827 | all_comp.append(
828 | {
829 | "gt": all_pos[k]["gt"],
830 | "rej": all_neg[k]["pred"],
831 | "chosen": all_pos[k]["messages"],
832 | "rejected": all_neg[k]["messages"],
833 | }
834 | )
835 | all_sft.append({"messages": all_pos[k]["messages"]})
836 | continue
837 |
838 | combinations = list(itertools.product(list(range(len(all_pos))), list(range(len(all_neg)))))
839 |
840 | random.shuffle(combinations)
841 | for k in range(np.min([len(combinations), N_pair])):
842 | all_comp.append(
843 | {
844 | "gt": all_pos[combinations[k][0]]["gt"],
845 | "chosen": all_pos[combinations[k][0]]["messages"],
846 | "rejected": all_neg[combinations[k][1]]["messages"],
847 | "rej": all_neg[combinations[k][1]]["pred"],
848 | }
849 | )
850 | all_sft.append({"messages": all_pos[combinations[k][0]]["messages"]})
851 |
852 |
853 | # print(all_comp[0])
854 | output_eval_dataset = {}
855 | output_eval_dataset["type"] = "text_only"
856 | output_eval_dataset["instances"] = all_comp
857 | print("I collect ", len(all_comp), "samples", len(all_sft))
858 |
859 | import json
860 |
861 | with open("tmp_comp.json", "w", encoding="utf8") as f:
862 | json.dump(output_eval_dataset, f, ensure_ascii=False)
863 |
864 | ds_comp = load_dataset("json", data_files="tmp_comp.json", split="train", field="instances")
865 | ds_comp.push_to_hub(data_comp)
866 |
867 | output_eval_dataset = {}
868 | output_eval_dataset["type"] = "text_only"
869 | output_eval_dataset["instances"] = all_sft
870 | print("I collect ", len(all_comp), "samples", len(all_sft))
871 |
872 | import json
873 |
874 | with open("tmp_sft.json", "w", encoding="utf8") as f:
875 | json.dump(output_eval_dataset, f, ensure_ascii=False)
876 |
877 | ds_sft = load_dataset("json", data_files="tmp_sft.json", split="train", field="instances")
878 | ds_sft.push_to_hub(data_sft)
879 |
880 |
881 | ##########################
882 |
883 | N_pair = 8  # sample up to 8 (chosen, rejected) pairs per prompt
884 | data_comp = "1231czx/7B_iter1_dpo_N8_random_pair"
885 | data_sft = "1231czx/7B_iter1_sft_N8"
886 |
887 | all_comp = []
888 | all_sft = []
889 | all_keys = list(new_lose_ret.keys()) + list(new_win_ret.keys())
890 | all_keys = list(set(all_keys))
891 | import itertools
892 |
893 | """
894 | for ins in new_win_ret[0]:
895 | print(is_correct(ins), ins['pred'][0])
896 |
897 | for ins in new_lose_ret[0]:
898 | print(is_correct(ins), ins['pred'][0], ins)
899 | print(check1(ins['gt'], ins['pred'][0]))
900 | print(check1('\\frac{1}{4}', '0.25'))
901 | """
902 | cnt_comp = 0
903 | # N = 1
904 | for j in all_keys:
905 | if len(new_lose_ret[j]) > 0 and len(new_win_ret[j]) > 0:
906 | cnt_comp += N_pair
907 | else:
908 | continue
909 | all_pos = new_win_ret[j]
910 | all_neg = new_lose_ret[j]
911 | random.shuffle(all_pos)
912 | random.shuffle(all_neg)
913 | if len(all_pos) > N_pair and len(all_neg) > N_pair:
914 | for k in range(N_pair):
915 | all_comp.append(
916 | {
917 | "gt": all_pos[k]["gt"],
918 | "rej": all_neg[k]["pred"],
919 | "chosen": all_pos[k]["messages"],
920 | "rejected": all_neg[k]["messages"],
921 | }
922 | )
923 | all_sft.append({"messages": all_pos[k]["messages"]})
924 | continue
925 |
926 | combinations = list(itertools.product(list(range(len(all_pos))), list(range(len(all_neg)))))
927 |
928 | random.shuffle(combinations)
929 | for k in range(np.min([len(combinations), N_pair])):
930 | all_comp.append(
931 | {
932 | "gt": all_pos[combinations[k][0]]["gt"],
933 | "chosen": all_pos[combinations[k][0]]["messages"],
934 | "rejected": all_neg[combinations[k][1]]["messages"],
935 | "rej": all_neg[combinations[k][1]]["pred"],
936 | }
937 | )
938 | all_sft.append({"messages": all_pos[combinations[k][0]]["messages"]})
939 |
940 |
941 | # print(all_comp[0])
942 | output_eval_dataset = {}
943 | output_eval_dataset["type"] = "text_only"
944 | output_eval_dataset["instances"] = all_comp
945 | print("I collect ", len(all_comp), "samples", len(all_sft))
946 |
947 | import json
948 |
949 | with open("tmp_comp.json", "w", encoding="utf8") as f:
950 | json.dump(output_eval_dataset, f, ensure_ascii=False)
951 |
952 | ds_comp = load_dataset("json", data_files="tmp_comp.json", split="train", field="instances")
953 | ds_comp.push_to_hub(data_comp)
954 |
955 | output_eval_dataset = {}
956 | output_eval_dataset["type"] = "text_only"
957 | output_eval_dataset["instances"] = all_sft
958 | print("I collect ", len(all_comp), "samples", len(all_sft))
959 |
960 | import json
961 |
962 | with open("tmp_sft.json", "w", encoding="utf8") as f:
963 | json.dump(output_eval_dataset, f, ensure_ascii=False)
964 |
965 | ds_sft = load_dataset("json", data_files="tmp_sft.json", split="train", field="instances")
966 | ds_sft.push_to_hub(data_sft)
967 |
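968 | # Format of the pushed datasets: each instance of the comparison ("dpo") datasets holds
969 | # "gt" (the ground-truth answer), "chosen" / "rejected" (full multi-turn message lists), and
970 | # "rej" (the prediction extracted from the rejected trajectory); the SFT datasets keep only
971 | # the "messages" of the correct trajectories.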
--------------------------------------------------------------------------------
/useful_codes/interpolate_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # We provide a script to interpolate (weight-average) two checkpoints of the same model.
4 | # See a detailed study in Mitigating the Alignment Tax of RLHF, https://arxiv.org/abs/2309.06256
5 |
6 | import argparse
7 | import os
8 |
9 | import torch
10 | from transformers import AutoModelForCausalLM
11 |
12 |
13 | parser = argparse.ArgumentParser(description="merge model checkpoints")
14 | parser.add_argument("--base_model", type=str, required=True, help="dir of base model")
15 | parser.add_argument("--new_model", type=str, required=True, help="dir of new model")
16 | parser.add_argument("--output_dir", type=str, required=True, help="output dir")
17 | parser.add_argument("--ratio", type=float, required=True, help="the ratio of the new model")
18 |
19 | args = parser.parse_args()
20 | os.makedirs(args.output_dir, exist_ok=True)
21 | print(f"Base model: {args.base_model}")
22 | print(f"New model: {args.new_model}")
23 | print(f"Output directory: {args.output_dir}")
24 |
25 | new_dir = args.new_model
26 | base_dir = args.base_model
27 | weight_ensemble_save_path = args.output_dir
28 | weight_ensemble_ratios = [args.ratio]  # must be a list; the complementary ratio for the base model is appended below
29 |
30 | # Paths and interpolation ratios of the two models:
31 | # output = args.ratio * new_model + (1 - args.ratio) * base_model
32 |
33 | weight_ensemble_names_paths = [new_dir, base_dir]
34 | weight_ensemble_ratios.append(1 - weight_ensemble_ratios[0])
35 | assert len(weight_ensemble_ratios) == 2, 'Only merging two models is supported.'
36 | print('Model Paths:', weight_ensemble_names_paths)
37 | print('Model Ratio:', weight_ensemble_ratios)
38 |
39 | base_model = None
40 | backend_models = []
41 | for model_path in weight_ensemble_names_paths:
42 | #model_args.model_name_or_path = model_path
43 | print('loading:', model_path)
44 | model = AutoModelForCausalLM.from_pretrained(model_path)#, torch_dtype=torch.bfloat16)
45 | backend_models.append(model.to('cpu'))
46 | if base_model is None:
47 | base_model = model
48 | print('Finished loading:', model_path)
49 | base_backend_model = backend_models[0]
50 | print('Finished loading all models:', base_backend_model)
51 |
52 | updated_state_dict = {}
53 | for key in base_backend_model.state_dict():
54 | ensembled_state_dicts = [ratio * backend_model.state_dict()[key] for backend_model, ratio in zip(backend_models, weight_ensemble_ratios)]
55 | updated_state_dict[key] = sum(ensembled_state_dicts)
56 |
57 | base_backend_model.load_state_dict(updated_state_dict)
58 | base_model.save_pretrained(weight_ensemble_save_path)
59 |
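60 | # Example usage (a sketch; the checkpoint paths below are placeholders):
61 | #   python interpolate_model.py --base_model ./sft_checkpoint --new_model ./dpo_checkpoint \
62 | #     --output_dir ./interpolated_model --ratio 0.3
63 | # The saved weights are then 0.3 * dpo_checkpoint + 0.7 * sft_checkpoint.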
--------------------------------------------------------------------------------
/useful_codes/merge.py:
--------------------------------------------------------------------------------
1 | # This script finds all files ending with .jsonl in the given folders and merges them into a single JSON file
2 | import os
3 | from datasets import load_dataset
4 | import json
5 |
6 | # The folders to load data from
7 | all_folder_path = [
8 | './7b_sft3epoch_gen_data/iter1',
9 | './7b_sft1epoch_gen_data/iter1'
10 | ]
11 |
12 | output_dir = 'all_math.json'  # path of the merged output file
13 |
14 | all_data = []
15 | for folder_path in all_folder_path:
16 | jsonl_files = [folder_path + '/' + f for f in os.listdir(folder_path) if f.endswith('.jsonl')]
17 | for dir_ in jsonl_files:
18 | ds_test = load_dataset('json', data_files=dir_, split='train')
19 | for sample in ds_test:
20 | all_data.append(sample)
21 |
22 | output_eval_dataset = {}
23 | output_eval_dataset["type"] = "text_only"
24 | output_eval_dataset["instances"] = all_data
25 | print("I collect ", len(all_data), "samples")
26 |
27 | with open(output_dir, "w", encoding="utf8") as f:
28 | json.dump(output_eval_dataset, f, ensure_ascii=False)
29 |
30 | # You can also push the data to the Hugging Face Hub, which is particularly useful when collecting data on multiple machines (the snippet below additionally requires: from datasets import Dataset, DatasetDict)
31 | """
32 | output_dir = "xxx"
33 | dict_data = {
34 | "idx": [d['idx'] for d in all_data],
35 | "gt": [d['gt'] for d in all_data],
36 | "level": [d['level'] for d in all_data],
37 | "type": [d['type'] for d in all_data],
38 | "my_solu": [d['my_solu'] for d in all_data],
39 | "pred": [d['pred'] for d in all_data],
40 | }
41 |
42 | dataset = Dataset.from_dict(dict_data)
43 | DatasetDict({'train': dataset}).push_to_hub(output_dir)
44 | """
45 |
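46 | # To read the merged file back (a sketch mirroring how the other scripts in this repo load it):
47 | # ds = load_dataset("json", data_files=output_dir, split="train", field="instances")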
--------------------------------------------------------------------------------
/useful_codes/push_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 |
4 | name = 'model_dir'  # local path (or hub id) of the model to upload
5 | output_name = 'output_name'  # target repo id on the Hugging Face Hub
6 | tokenizer_name = name
7 |
8 | model = AutoModelForCausalLM.from_pretrained(
9 | name,
10 | torch_dtype=torch.bfloat16,
11 | )
12 |
13 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
14 |
15 | model.push_to_hub(output_name)
16 | tokenizer.push_to_hub(output_name)
17 |
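18 | # Note: push_to_hub requires authentication with the Hugging Face Hub, e.g. run
19 | # `huggingface-cli login` or set the HF_TOKEN environment variable beforehand.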
--------------------------------------------------------------------------------
/useful_codes/set_padding_token.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 |
4 | name = 'mistralai/Mistral-7B-v0.3'  # this tokenizer ships without a pad token, so we add one below
5 | tokenizer_name = name
6 |
7 | model = AutoModelForCausalLM.from_pretrained(
8 | name,
9 | torch_dtype=torch.bfloat16,
10 | )
11 |
12 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
13 | tokenizer.add_special_tokens({'pad_token': '[PAD]'})
14 | model.config.pad_token_id = tokenizer.pad_token_id
15 |
16 | model.resize_token_embeddings(len(tokenizer))
17 |
18 | model.save_pretrained("output_dir")
19 | tokenizer.save_pretrained("output_dir")
20 |
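21 | # Optional sanity check (a sketch): reload from "output_dir" and confirm the pad token is set.
22 | # tok = AutoTokenizer.from_pretrained("output_dir"); print(tok.pad_token, tok.pad_token_id)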
--------------------------------------------------------------------------------