├── .env
├── alignment-handbook
    ├── tests
    │   ├── __init__.py
    │   ├── fixtures
    │   │   ├── config_dpo_full.yaml
    │   │   └── config_sft_full.yaml
    │   ├── test_configs.py
    │   ├── test_model_utils.py
    │   └── test_data.py
    ├── README.md
    ├── chapters
    │   └── en
    │   │   ├── chapter0
    │   │   └── introduction.mdx
    │   │   └── _toctree.yml
    ├── assets
    │   └── handbook.png
    ├── scripts
    │   ├── setup.sh
    │   ├── test_gamma_bird.sh
    │   ├── train_spider.sh
    │   └── train_bird.sh
    ├── src
    │   └── alignment
    │   │   ├── __init__.py
    │   │   ├── release.py
    │   │   └── model_utils.py
    ├── recipes
    │   ├── accelerate_configs
    │   │   ├── multi_gpu.yaml
    │   │   ├── zero2.yaml
    │   │   ├── deepspeed_dpo.yaml
    │   │   ├── deepspeed_zero3.yaml
    │   │   ├── no_offload_optimizer.yaml
    │   │   └── deepspeed_original.yaml
    │   ├── llama-1b-bird
    │   │   ├── fixed-fft.yaml
    │   │   ├── selection-fft.yaml
    │   │   ├── validator-fixer-fft.yaml
    │   │   ├── orpo-fixed.yaml
    │   │   ├── orpo-validator-fixed.yaml
    │   │   ├── llama-3-bird-planner-fft.yaml
    │   │   ├── validator-fft.yaml
    │   │   ├── llama-3-bird-validator-fft.yaml
    │   │   └── orpo-validator.yaml
    │   ├── llama-3b-bird
    │   │   ├── selection-fft.yaml
    │   │   ├── validator-fixer-fft.yaml
    │   │   ├── orpo-selection.yaml
    │   │   ├── orpo-fixed.yaml
    │   │   ├── orpo-planner-iter-2.yaml
    │   │   ├── orpo-planner-iter-3.yaml
    │   │   ├── orpo-fixed-iter-2.yaml
    │   │   ├── orpo-planner.yaml
    │   │   ├── planner-fft.yaml
    │   │   ├── fixed-fft.yaml
    │   │   ├── orpo-validator-iter-2.yaml
    │   │   ├── llama-3-bird-validator-fft.yaml
    │   │   ├── orpo-llama-3-validator.yaml
    │   │   └── orpo-validator.yaml
    │   ├── llama-1b-spider
    │   │   ├── planner-fft.yaml
    │   │   ├── orpo-fixed.yaml
    │   │   ├── fixed-fft.yaml
    │   │   ├── validator-fft.yaml
    │   │   └── orpo-validator.yaml
    │   └── llama-3b-spider
    │   │   ├── planner-fft.yaml
    │   │   ├── orpo-fixed.yaml
    │   │   ├── llama-3-fixed-fft.yaml
    │   │   ├── orpo-planner-iter-2.yaml
    │   │   ├── orpo-planner.yaml
    │   │   ├── orpo-planner-iter-3.yaml
    │   │   ├── orpo-validator.yaml
    │   │   └── fft-validator.yaml
    ├── .github
    │   └── workflows
    │   │   ├── upload_pr_documentation.yml
    │   │   ├── build_documentation.yml
    │   │   ├── build_pr_documentation.yml
    │   │   ├── quality.yml
    │   │   └── tests.yml
    ├── setup.cfg
    ├── Makefile
    ├── .gitignore
    └── setup.py
├── validator_data
    ├── .env
    ├── generate_validator_order_using_fewshot.py
    ├── generate_validator_select_using_fewshot.py
    ├── generate_fixed_sql_using_fewshot_condition.py
    ├── generate_validator_join_using_fewshot.py
    ├── generate_fixed_sql_using_fewshot_join.py
    ├── generate_validator_condition_using_fewshot.py
    ├── utils.py
    └── generate_fixed_sql_using_fewshot.py
├── requirements.txt
├── .gitignore
├── json2jsonl.py
├── jsonl2json.py
├── visualization
    ├── temperature_sensitivity.py
    ├── planner_prompt_comparison.py
    ├── sql_characteristic_spider.py
    ├── parameter_sensitivity_num_candidates.py
    ├── sql_characteristic_bird.py
    ├── domain_knowledge.py
    ├── rlef_improvement.py
    ├── .ipynb_checkpoints
    │   └── rlef_improvement-checkpoint.py
    └── lambda_sensitivity.py
├── llm_alignment
    ├── merge_rl_data.py
    └── build_rlef_selection_data.py
├── scripts
    ├── evaluate_bird.sh
    └── evaluate_dr_spider.sh
├── utils
    ├── classifier_loss.py
    ├── load_classifier_dataset.py
    ├── load_pt_dataset.py
    ├── lr_scheduler.py
    └── load_sft_dataset.py
├── data_processing
    ├── generate_validator_data.py
    ├── prompts
    │   └── zero_shot_prompt_planner.txt
    ├── generate_planner_data.py
    ├── merge_val_fix_data.ipynb
    ├── generate_validator_fixer_data.py
    ├── generate_sft_data_for_validator.py
    ├── generate_sft_data_for_fix.py
    └── generate_sft_data_for_planner.py
├── bird_evaluation
    └── run_evaluation.sh
└── README.md

/.env:
--------------------------------------------------------------------------------
 1 | OPENAI_API_KEY=
 2 | 
--------------------------------------------------------------------------------
/alignment-handbook/tests/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/validator_data/.env:
--------------------------------------------------------------------------------
 1 | OPENAI_API_KEY=
 2 | 
--------------------------------------------------------------------------------
/alignment-handbook/README.md:
--------------------------------------------------------------------------------
 1 | To train SFT on TinyLlama, run:
 2 | ```
 3 | bash scripts/train_tinyllama.sh
 4 | ```
 5 | 
 6 | 
--------------------------------------------------------------------------------
/alignment-handbook/chapters/en/chapter0/introduction.mdx:
--------------------------------------------------------------------------------
 1 | # Welcome to the RLHF Handbook!
 2 | 
 3 | Stay tuned for more details 🤗
--------------------------------------------------------------------------------
/alignment-handbook/assets/handbook.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thanhdath/mats-sql/HEAD/alignment-handbook/assets/handbook.png
--------------------------------------------------------------------------------
/alignment-handbook/chapters/en/_toctree.yml:
--------------------------------------------------------------------------------
 1 | - title: Unit 0. Welcome to the RLHF Handbook!
 2 |   sections:
 3 |   - local: chapter0/introduction
 4 |     title: What is this about?
--------------------------------------------------------------------------------
/alignment-handbook/scripts/setup.sh:
--------------------------------------------------------------------------------
 1 | apt-get install git-lfs
 2 | python -m pip install .
 3 | #export PATH=/usr/local/cuda/bin:$PATH
 4 | #export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 5 | python -m pip install flash-attn --no-build-isolation
 6 | pip install git+https://github.com/huggingface/trl.git
 7 | 
--------------------------------------------------------------------------------
/alignment-handbook/src/alignment/__init__.py:
--------------------------------------------------------------------------------
 1 | __version__ = "0.2.0.dev0"
 2 | 
 3 | from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
 4 | from .data import apply_chat_template, get_datasets
 5 | from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model
 6 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | func_timeout==4.3.5
 2 | nltk==3.7
 3 | numpy==1.23.5
 4 | pandas==2.0.1
 5 | rapidfuzz==2.0.11
 6 | tqdm==4.63.0
 7 | transformers==4.28.1
 8 | sqlparse==0.4.2
 9 | accelerate==0.18.0
10 | bitsandbytes==0.41.1
11 | pyserini==0.21.0
12 | sql_metadata==2.8.0
13 | datasets==2.11.0
14 | faiss-cpu==1.7.4
15 | deepspeed==0.9.5
16 | tensorboard
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | __pycache__
 2 | pred_sqls.txt
 3 | predict_dev.json
 4 | output/
 5 | test_suite_sql_eval
 6 | seeklhy/
 7 | sic_ckpts/
 8 | data/
 9 | log-tensorboard/
10 | *.json
11 | temp/
12 | *.log
13 | *.csv
14 | db_content_retrieval/volumes
15 | offload/
16 | *.xlsx
17 | test_suite_sql_eval
18 | alignment-handbook/thanhdath/
19 | progress.jsonl
20 | logs/
21 | temp.sh
22 | *.pkl
23 | *.jsonl
24 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/multi_gpu.yaml:
--------------------------------------------------------------------------------
 1 | compute_environment: LOCAL_MACHINE
 2 | debug: false
 3 | distributed_type: MULTI_GPU
 4 | downcast_bf16: 'no'
 5 | gpu_ids: all
 6 | machine_rank: 0
 7 | main_training_function: main
 8 | mixed_precision: bf16
 9 | num_machines: 1
10 | num_processes: 2
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/zero2.yaml:
--------------------------------------------------------------------------------
 1 | compute_environment: LOCAL_MACHINE
 2 | deepspeed_config:
 3 |   offload_optimizer_device: cpu
 4 |   offload_param_device: none
 5 |   zero3_init_flag: true
 6 |   zero_stage: 2
 7 | distributed_type: DEEPSPEED
 8 | fsdp_config: {}
 9 | machine_rank: 0
10 | main_process_ip: null
11 | main_process_port: null
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 2
16 | use_cpu: false
--------------------------------------------------------------------------------
/json2jsonl.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import argparse
 3 | 
 4 | parser = argparse.ArgumentParser()
 5 | parser.add_argument('--json-file', type=str, required=True)
 6 | args = parser.parse_args()
 7 | 
 8 | file = args.json_file
 9 | 
10 | # Read the JSON file
11 | with open(file, 'r') as f:
12 |     data = json.load(f)
13 | 
14 | # Export to JSONL
15 | with open(file.replace('.json', '.jsonl'), 'w') as f:
16 |     for record in data:
17 |         f.write(json.dumps(record) + '\n')
18 | 
--------------------------------------------------------------------------------
/jsonl2json.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import argparse
 3 | 
 4 | parser = argparse.ArgumentParser()
 5 | parser.add_argument('--jsonl-file', type=str, required=True)
 6 | args = parser.parse_args()
 7 | 
 8 | file = args.jsonl_file
 9 | data = []
10 | with open(file, 'r') as f:
11 |     for line in f:
12 |         data.append(json.loads(line))
13 | # export to json, replace jsonl to json
14 | with open(file.replace('jsonl', 'json'), 'w') as f:
15 |     json.dump(data, f, indent=4)
16 | 
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
 1 | name: Upload PR Documentation
 2 | 
 3 | on:
 4 |   workflow_run:
 5 |     workflows: ["Build PR Documentation"]
 6 |     types:
 7 |       - completed
 8 | 
 9 | jobs:
10 |   build:
11 |     uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 |     with:
13 |       package_name: alignment-handbook
14 |     secrets:
15 |       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 |       comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/deepspeed_dpo.yaml:
--------------------------------------------------------------------------------
 1 | compute_environment: LOCAL_MACHINE
 2 | debug: false
 3 | deepspeed_config:
 4 |   deepspeed_config_file: ./recipes/accelerate_configs/dpo.json
 5 |   zero3_init_flag: true
 6 | distributed_type: DEEPSPEED
 7 | downcast_bf16: 'no'
 8 | machine_rank: 0
 9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 2
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
 1 | compute_environment: LOCAL_MACHINE
 2 | debug: false
 3 | deepspeed_config:
 4 |   deepspeed_config_file: ./recipes/accelerate_configs/ds_7b.json
 5 |   zero3_init_flag: true
 6 | distributed_type: DEEPSPEED
 7 | downcast_bf16: 'no'
 8 | machine_rank: 0
 9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 2
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/no_offload_optimizer.yaml:
--------------------------------------------------------------------------------
 1 | compute_environment: LOCAL_MACHINE
 2 | debug: false
 3 | deepspeed_config:
 4 |   deepspeed_config_file: ./recipes/accelerate_configs/ds_7b.json
 5 |   zero3_init_flag: true
 6 | distributed_type: DEEPSPEED
 7 | downcast_bf16: 'no'
 8 | machine_rank: 0
 9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 2
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 | 
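
Note: every SFT/ORPO recipe later in this listing carries the same inline `chat_template`. As a minimal standalone sketch (not part of the repository), the template is an ordinary Jinja2 expression that the training scripts apply to a record with `prompt` and `completion` fields, as the templates' `messages['prompt']` subscripting implies; the question/SQL pair below is hypothetical:

```python
# Standalone sketch of how the recipes' chat_template renders one record.
# Assumptions: the `jinja2` package is installed, and `messages` is a dict
# with "prompt"/"completion" keys (the real pipeline applies the template
# inside the SFT/ORPO training scripts).
from jinja2 import Template

chat_template = (
    "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}"
    "{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
)

example = {
    "prompt": "Question: How many singers do we have?",   # hypothetical
    "completion": "SELECT count(*) FROM singer",          # hypothetical
}

# Prints the user turn followed by the assistant turn with Llama-3 headers.
print(Template(chat_template).render(messages=example))
```
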
--------------------------------------------------------------------------------
/alignment-handbook/scripts/test_gamma_bird.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=src/
 2 | export CUDA_VISIBLE_DEVICES=0,1
 3 | 
 4 | for beta in 0.0 0.25 0.5 1.0 0.75
 5 | do
 6 |     ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-3b-bird/orpo-planner.yaml --model_name_or_path=./output/reproduce/llama-3b-bird-planner-fft-no-filter --output_dir=output/param-sensitivity/orpo-llama-3b-bird-planner-beta-$beta --save_strategy=no
 7 | done
 8 | 
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
 1 | name: Build documentation
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - main
 7 | 
 8 | jobs:
 9 |   build:
10 |     uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
11 |     with:
12 |       commit_sha: ${{ github.sha }}
13 |       package: alignment-handbook
14 |       path_to_docs: alignment-handbook/chapters/
15 |       additional_args: --not_python_module
16 |       languages: en
17 |     secrets:
18 |       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
19 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/deepspeed_original.yaml:
--------------------------------------------------------------------------------
 1 | compute_environment: LOCAL_MACHINE
 2 | debug: false
 3 | deepspeed_config:
 4 |   deepspeed_multinode_launcher: standard
 5 |   offload_optimizer_device: none
 6 |   offload_param_device: none
 7 |   zero3_init_flag: true
 8 |   zero3_save_16bit_model: true
 9 |   zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 2
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
--------------------------------------------------------------------------------
/alignment-handbook/scripts/train_spider.sh:
--------------------------------------------------------------------------------
 1 | export CUDA_VISIBLE_DEVICES=0
 2 | export PYTHONPATH=src/
 3 | #accelerate launch --main_process_port 29502 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_sft.py recipes/llama-3b-spider/planner-fft.yaml
 4 | accelerate launch --main_process_port 29502 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 1 scripts/run_orpo.py recipes/llama-3b-spider/orpo-planner-iter-3.yaml
 5 | # accelerate launch --main_process_port 29502 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_sft.py recipes/llama-3b-spider/sql-gt-fft.yaml
 6 | 
 7 | 
 8 | 
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
 1 | name: Build PR Documentation
 2 | 
 3 | on:
 4 |   pull_request:
 5 | 
 6 | concurrency:
 7 |   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
 8 |   cancel-in-progress: true
 9 | 
10 | jobs:
11 |   build:
12 |     uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13 |     with:
14 |       commit_sha: ${{ github.event.pull_request.head.sha }}
15 |       pr_number: ${{ github.event.number }}
16 |       package: alignment-handbook
17 |       path_to_docs: alignment-handbook/chapters/
18 |       additional_args: --not_python_module
19 |       languages: en
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/quality.yml:
--------------------------------------------------------------------------------
 1 | name: Quality
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - main
 7 |       - v*-release
 8 |   pull_request:
 9 |     branches:
10 |       - main
11 | 
12 | jobs:
13 | 
14 |   check_code_quality:
15 |     name: Check code quality
16 |     runs-on: ubuntu-latest
17 |     steps:
18 |       - name: Checkout code
19 |         uses: actions/checkout@v2
20 |       - name: Setup Python environment
21 |         uses: actions/setup-python@v2
22 |         with:
23 |           python-version: 3.10.10
24 |       - name: Install dependencies
25 |         run: |
26 |           python -m pip install --upgrade pip
27 |           python -m pip install ".[quality]"
28 |       - name: Code quality
29 |         run: |
30 |           make quality
31 | 
32 | 
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
 1 | name: Tests
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - main
 7 |       - v*-release
 8 |   pull_request:
 9 |     branches:
10 |       - main
11 | 
12 | jobs:
13 | 
14 |   unit-tests:
15 |     name: Run unit tests
16 |     env:
17 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
18 |     runs-on: ubuntu-latest
19 |     steps:
20 |       - name: Checkout code
21 |         uses: actions/checkout@v2
22 |       - name: Setup Python environment
23 |         uses: actions/setup-python@v2
24 |         with:
25 |           python-version: 3.10.10
26 |       - name: Install dependencies
27 |         run: |
28 |           python -m pip install --upgrade pip
29 |           python -m pip install ".[dev, torch]"
30 |       - name: Run unit tests
31 |         run: HF_TOKEN=$HF_TOKEN pytest -sv tests/
--------------------------------------------------------------------------------
/alignment-handbook/setup.cfg:
--------------------------------------------------------------------------------
 1 | [isort]
 2 | default_section = FIRSTPARTY
 3 | ensure_newline_before_comments = True
 4 | force_grid_wrap = 0
 5 | include_trailing_comma = True
 6 | known_first_party = alignment
 7 | known_third_party =
 8 |     transformers
 9 |     datasets
10 |     fugashi
11 |     git
12 |     h5py
13 |     matplotlib
14 |     nltk
15 |     numpy
16 |     packaging
17 |     pandas
18 |     psutil
19 |     pytest
20 |     rouge_score
21 |     sacrebleu
22 |     seqeval
23 |     sklearn
24 |     streamlit
25 |     torch
26 |     tqdm
27 | 
28 | line_length = 119
29 | lines_after_imports = 2
30 | multi_line_output = 3
31 | use_parentheses = True
32 | 
33 | [flake8]
34 | ignore = E203, E501, E741, W503, W605
35 | max-line-length = 119
36 | per-file-ignores =
37 |     # imported but unused
38 |     __init__.py: F401
39 | 
40 | [tool:pytest]
41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/visualization/temperature_sensitivity.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | 
 3 | # Given data
 4 | x_values = [0, 0.3, 0.5, 1.0]  # Temperature values
 5 | y_values = [59.12, 63.58, 63.98, 64.13]  # Mean values (A)
 6 | y_errors = [0.09, 0.13, 0.23, 0.25]  # Standard deviation (B)
 7 | 
 8 | # Adjust figure size for 1-column fit in a 2-column research paper (~3.5 inches wide)
 9 | plt.figure(figsize=(3.5, 2.5))
10 | 
11 | # Create error bar plot
12 | plt.errorbar(x_values, y_values, yerr=y_errors, fmt='o-', capsize=3, capthick=1)
13 | 
14 | # Labels with optimized font size for readability
15 | plt.xlabel("Temperature", fontsize=9)
16 | plt.ylabel("EX%", fontsize=9)
17 | plt.xticks(x_values, fontsize=8)
18 | plt.yticks(fontsize=8)
19 | plt.grid(True, linestyle="--", alpha=0.7)
20 | 
21 | # Tight layout for better fit in a research paper column
22 | plt.tight_layout()
23 | 
24 | # Show the plot
25 | plt.show()
26 | 
--------------------------------------------------------------------------------
/alignment-handbook/tests/fixtures/config_dpo_full.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: alignment-handbook/zephyr-7b-sft-full
 3 | 
 4 | # Data training arguments
 5 | # For definitions, see: src/h4/training/config.py
 6 | dataset_mixer:
 7 |   HuggingFaceH4/ultrafeedback_binarized: 1.0
 8 | dataset_splits:
 9 | - train_prefs
10 | - test_prefs
11 | preprocessing_num_workers: 12
12 | 
13 | # DPOTrainer arguments
14 | bf16: true
15 | beta: 0.1
16 | do_eval: true
17 | evaluation_strategy: steps
18 | eval_steps: 100
19 | gradient_accumulation_steps: 1
20 | gradient_checkpointing: true
21 | hub_model_id: zephyr-7b-dpo-full
22 | learning_rate: 5.0e-7
23 | log_level: info
24 | logging_steps: 10
25 | lr_scheduler_type: linear
26 | max_length: 1024
27 | max_prompt_length: 512
28 | num_train_epochs: 3
29 | optim: rmsprop
30 | output_dir: data/zephyr-7b-dpo-full
31 | per_device_train_batch_size: 8
32 | per_device_eval_batch_size: 4
33 | push_to_hub: true
34 | save_strategy: "no"
35 | save_total_limit: null
36 | seed: 42
37 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/alignment-handbook/tests/fixtures/config_sft_full.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: mistralai/Mistral-7B-v0.1
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | 
 7 | # Data training arguments
 8 | dataset_mixer:
 9 |   HuggingFaceH4/ultrachat_200k: 1.0
10 | dataset_splits:
11 | - train_sft
12 | - test_sft
13 | preprocessing_num_workers: 12
14 | 
15 | # SFT trainer config
16 | bf16: true
17 | do_eval: true
18 | evaluation_strategy: epoch
19 | gradient_accumulation_steps: 2
20 | gradient_checkpointing: true
21 | hub_model_id: zephyr-7b-sft-full
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 2048
29 | max_steps: -1
30 | num_train_epochs: 1
31 | output_dir: data/zephyr-7b-sft-full
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 16
34 | per_device_train_batch_size: 32
35 | push_to_hub: true
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: null
41 | seed: 42
--------------------------------------------------------------------------------
/alignment-handbook/scripts/train_bird.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=src/
 2 | export CUDA_VISIBLE_DEVICES=0,1
 3 | # ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 1 scripts/run_orpo.py recipes/llama-3b-bird/orpo-planner.yaml --model_name_or_path=./output/llama-3b-bird-planner-fft --output_dir=output/reproduce/orpo-llama-3b-bird-planner
 4 | 
 5 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-bird/orpo-validator.yaml
 6 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-bird/orpo-fixed.yaml
 7 | 
 8 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-spider/orpo-validator.yaml
 9 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-spider/orpo-fixed.yaml
--------------------------------------------------------------------------------
/visualization/planner_prompt_comparison.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | import numpy as np
 3 | 
 4 | # Data from the table
 5 | strategies = ["No Thought", "Chain-of-Thought", "Few-shot Thoughts"]
 6 | categories = ["simple", "moderate", "challenging", "overall"]
 7 | 
 8 | # Execution accuracy scores
 9 | scores = np.array([
10 |     [59.46, 37.5, 30.34, 50.07],
11 |     [60.11, 40.09, 37.93, 51.96],
12 |     [63.03, 43.75, 38.62, 54.89]
13 | ])
14 | 
15 | # Bar width and positions
16 | bar_width = 0.25
17 | x = np.arange(len(categories))
18 | 
19 | # Creating figure with aspect ratio suitable for a research paper (single-column)
20 | fig, ax = plt.subplots(figsize=(4.5, 2.5))
21 | 
22 | # Plot bars for each strategy
23 | for i, strategy in enumerate(strategies):
24 |     ax.bar(x + i * bar_width, scores[i], width=bar_width, label=strategy)
25 | 
26 | # Labels and formatting
27 | ax.set_xticks(x + bar_width)
28 | ax.set_xticklabels(categories, fontsize=9)
29 | ax.set_ylabel("EX%", fontsize=9)
30 | 
31 | # Move legend to the top
32 | ax.legend(fontsize=8, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3, frameon=True)
33 | 
34 | # Tight layout for better fit
35 | plt.tight_layout()
36 | 
37 | # Show the plot
38 | plt.show()
39 | 
--------------------------------------------------------------------------------
/alignment-handbook/Makefile:
--------------------------------------------------------------------------------
 1 | .PHONY: style quality
 2 | 
 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 4 | export PYTHONPATH = src
 5 | 
 6 | check_dirs := src tests scripts
 7 | 
 8 | style:
 9 | 	black --line-length 119 --target-version py310 $(check_dirs) setup.py
10 | 	isort $(check_dirs) setup.py
11 | 
12 | quality:
13 | 	black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14 | 	isort --check-only $(check_dirs) setup.py
15 | 	flake8 --max-line-length 119 $(check_dirs) setup.py
16 | 
17 | 
18 | # Release stuff
19 | 
20 | pre-release:
21 | 	python src/alignment/release.py
22 | 
23 | pre-patch:
24 | 	python src/alignment/release.py --patch
25 | 
26 | post-release:
27 | 	python src/alignment/release.py --post_release
28 | 
29 | post-patch:
30 | 	python src/alignment/release.py --post_release --patch
31 | 
32 | wheels:
33 | 	python setup.py bdist_wheel && python setup.py sdist
34 | 
35 | wheels_clean:
36 | 	rm -rf build && rm -rf dist
37 | 
38 | pypi_upload:
39 | 	python -m pip install twine
40 | 	twine upload dist/* -r pypi
41 | 
42 | pypi_test_upload:
43 | 	python -m pip install twine
44 | 	twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
45 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/fixed-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 6 | 
 7 | # Data training arguments
 8 | dataset_mixer:
 9 |   /home/datht/codes/data/multi-agents/fixed/sft-fixed-bird_with_evidence: 1
10 | dataset_splits:
11 | - train
12 | - test
13 | preprocessing_num_workers: 24
14 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
15 | 
16 | # SFT trainer config
17 | bf16: true
18 | do_eval: false
19 | evaluation_strategy: "no"
20 | gradient_accumulation_steps: 64
21 | gradient_checkpointing: true
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/llama-1b-bird-fixed-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: 1
41 | seed: 42
42 | tf32: true
43 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/selection-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   ../data/multi-agents/selection/sft_bird: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: false
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 64
22 | gradient_checkpointing: true
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 1
31 | output_dir: output/llama-1b-bird-selection-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "steps"
40 | save_steps: 200
41 | save_total_limit: 10
42 | seed: 42
43 | tf32: true
44 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/selection-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   ../data/multi-agents/selection/sft_ranking_bird: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: false
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 2
31 | output_dir: output/llama-3b-bird-selection-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 2
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "steps"
40 | save_steps: 100
41 | save_total_limit: 1
42 | seed: 42
43 | tf32: true
44 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/validator-fixer-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 6 | 
 7 | # Data training arguments
 8 | dataset_mixer:
 9 |   ../data/multi-agents/fixed/sft-validator-fixer-bird_with_evidence/: 1
10 | dataset_splits:
11 | - train
12 | - test
13 | preprocessing_num_workers: 24
14 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
15 | 
16 | # SFT trainer config
17 | bf16: true
18 | do_eval: false
19 | evaluation_strategy: "no"
20 | gradient_accumulation_steps: 64
21 | gradient_checkpointing: true
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/llama-1b-bird-validator-fixer-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: 1
41 | seed: 42
42 | tf32: true
43 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/validator-fixer-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 6 | 
 7 | # Data training arguments
 8 | dataset_mixer:
 9 |   ../data/multi-agents/fixed/sft-validator-fixer-bird_with_evidence/: 1
10 | dataset_splits:
11 | - train
12 | - test
13 | preprocessing_num_workers: 24
14 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
15 | 
16 | # SFT trainer config
17 | bf16: true
18 | do_eval: false
19 | evaluation_strategy: "no"
20 | gradient_accumulation_steps: 64
21 | gradient_checkpointing: true
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/llama-3b-bird-validator-fixer-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: 1
41 | seed: 42
42 | tf32: true
43 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/planner-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|reserved_special_token_247|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   ../data/multi-agents/planner/sft-gpt-4o-mini-planner_spider_train: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{messages['prompt'] + '<|reserved_special_token_247|>\n'}}{{messages['completion'] + '<|end_of_text|>'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 8
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-1b-spider-planner-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 2048
31 | max_steps: -1
32 | num_train_epochs: 1
33 | output_dir: output/llama-1b-spider-planner-fft-1epoch
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 8
36 | per_device_train_batch_size: 8
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "epoch"
42 | save_steps: 20
43 | save_total_limit: 1
44 | seed: 42
45 | tf32: true
46 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-selection.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/llama-3b-bird-selection-cot-fft/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p1-selection: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 24
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-bird-selection
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 200
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/planner-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   ../data/multi-agents/planner/sft-gpt-4o-mini-planner_spider_train_no_filter/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | learning_rate: 5.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 2048
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/reproduce/llama-3b-spider-planner-fft-no-filter
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 4
34 | per_device_train_batch_size: 2
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "steps"
40 | save_steps: 20
41 | save_total_limit: 4
42 | seed: 42
43 | tf32: true
44 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/orpo-fixed.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/llama-1b-bird-fixed-fft/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p1-fix/dpo-llama-3-end2end-bird_train_fixed_sql: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | # chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 3000
35 | num_train_epochs: 1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-1b-fixed-bird/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.01
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/orpo-fixed.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/llama-1b-spider-fixed-fft/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/spider-p1-fix/dpo-llama-3-end2end-spider_train_dev_fixed_sql: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 5.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 2500
34 | max_prompt_length: 2000
35 | num_train_epochs: -1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/orpo-llama-1b-fixed-spider
39 | per_device_train_batch_size: 2
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.1
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-fixed.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/llama-3b-bird-fixed-fft-follow-validation/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p1-fix/dpo-llama-3-end2end-bird_train_fixed_sql: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 3000
35 | num_train_epochs: 1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3-fixed-bird/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.01
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/fixed-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   /home/datht/codes/data/multi-agents/fixed/sft-fixed-spider/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 64
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-1b-spider-fixed-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 4096
31 | max_steps: -1
32 | num_train_epochs: 2
33 | output_dir: output/llama-1b-spider-fixed-fft
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 2
36 | per_device_train_batch_size: 2
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
45 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-planner-iter-2.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/reproduce/orpo-llama-3b-bird-planner
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p2-planner/dpo-llama-3-end2end-bird_train_planner: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "no"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-iter-2-bird-planner
39 | per_device_train_batch_size: 2
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-fixed.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/llama-3b-spider-fixed-fft-follow-validation/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/spider-p1-fix/dpo-llama-3-end2end-spider_train_dev_fixed_sql: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 0.25
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 5.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 2500
34 | max_prompt_length: 2000
35 | num_train_epochs: -1
36 | max_steps: 800
37 | optim: adamw_torch
38 | output_dir: output/orpo-llama-3-fixed-spider
39 | per_device_train_batch_size: 2
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.1
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-planner-iter-3.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/reproduce/orpo-llama-3b-iter-2-bird-planner
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p3-planner/dpo-llama-3-end2end-bird_train_planner: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "no"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-iter-3-bird-planner
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "no"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-fixed-iter-2.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: output/reproduce/orpo-llama-3-fixed-bird/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird/bird-p2-fix/dpo-llama-3-end2end-bird_train_dev_fixed_sql: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 0.25
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 3000
35 | num_train_epochs: -1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3-iter-2-fixed-bird/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.1
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-planner.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/reproduce/llama-3b-bird-planner-fft-no-filter
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p1-planner/dpo-llama-3-end2end-bird_train_planner: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 24
15 | 
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-bird-planner-no-filter
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/planner-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   /home/datht/codes/data/multi-agents/planner/sft-gpt-4o-mini-planner_combine_with_true_sql_bird_with_evidence_train_no_filter/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/reproduce/llama-3b-bird-planner-fft-no-filter
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 2
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "epoch"
40 | save_steps: 40
41 | save_total_limit: 1
42 | seed: 42
43 | tf32: true
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/fixed-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   /home/datht/codes/data/multi-agents/fixed/sft-fixed-bird_with_evidence: 1
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 64
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-3b-bird-fixed-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 4096
31 | max_steps: -1
32 | num_train_epochs: 2
33 | output_dir: output/llama-3b-bird-fixed-fft-follow-validation
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 2
36 | per_device_train_batch_size: 1
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
45 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/llama-3-fixed-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   /home/datht/codes/data/multi-agents/fixed/sft-fixed-spider/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-3b-spider-fixed-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 4096
31 | max_steps: -1
32 | num_train_epochs: 2
33 | output_dir: output/llama-3b-spider-fixed-fft-follow-validation
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 2
36 | per_device_train_batch_size: 2
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
45 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/orpo-validator-fixed.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: ./output/llama-1b-bird-validator-fixer-fft/
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | 
 6 | # Data training arguments
 7 | # For definitions, see: src/h4/training/config.py
 8 | dataset_mixer:
 9 |   ../data/llm_alignment/bird-p1-validator-fixer/dpo-llama-3-end2end-bird_train_fixed_sql: 1.0
10 | 
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 | 
16 | # chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 | 
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 100
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 |   use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 2000
35 | num_train_epochs: 2
36 | max_steps: 100000
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-1b-validator-fixer-bird-beta0.5/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 2
45 | seed: 42
46 | warmup_ratio: 0.01
47 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-planner-iter-2.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments
 2 | model_name_or_path: output/reproduce/orpo-3b-spider-planner
 3 | torch_dtype: bfloat16
 4 | use_flash_attention_2: true
 5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 6 | 
 7 | # Data training arguments
 8 | # For definitions, see: src/h4/training/config.py
 9 | dataset_mixer:
10 |   ../data/llm_alignment/spider-p2-planner/dpo-llama-3-end2end-spider_train_planner: 1.0
11 | 
12 | dataset_splits:
13 | - train_dpo
14 | - test_dpo
15 | preprocessing_num_workers: 24
16 | 
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 | 
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 1.0
23 | do_eval: true
24 | eval_strategy: "no"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 4
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 |   use_reentrant: False
30 | learning_rate: 8.0e-6
31 | log_level: info
32 | logging_steps: 10
33 | lr_scheduler_type: cosine
34 | max_length: 1600
35 | max_prompt_length: 1200
36 | num_train_epochs: 1
37 | max_steps: -1
38 | optim: adamw_torch
39 | output_dir: output/reproduce/orpo-3b-spider-planner-iter-2
40 | per_device_train_batch_size: 2
41 | per_device_eval_batch_size: 2
42 | push_to_hub: false
43 | save_strategy: "epoch"
44 | save_steps: 100
45 | save_total_limit: 2
46 | seed: 42
47 | warmup_ratio: 0.05
48 | 
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/llama-3-bird-planner-fft.yaml:
--------------------------------------------------------------------------------
 1 | # Model arguments:
 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
 3 | model_revision: main
 4 | torch_dtype: bfloat16
 5 | use_flash_attention_2: true
 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
 7 | 
 8 | # Data training arguments
 9 | dataset_mixer:
10 |   /home/datht/codes/data/multi-agents/planner/sft-gpt-4o-mini-planner_combine_with_true_sql_bird_062024_with_evidence_train/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 | 
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-1b-bird-planner-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 8192
31 | max_steps: -1
32 | 
num_train_epochs: 4 33 | output_dir: output/llama-1b-bird-planner-fft 34 | overwrite_output_dir: true 35 | per_device_eval_batch_size: 4 36 | per_device_train_batch_size: 2 37 | push_to_hub: false 38 | remove_unused_columns: true 39 | report_to: 40 | - tensorboard 41 | save_strategy: "no" 42 | save_total_limit: 1 43 | seed: 42 44 | tf32: true -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-spider/orpo-planner.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: output/reproduce/llama-3b-spider-planner-fft-no-filter 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 6 | 7 | # Data training arguments 8 | # For definitions, see: src/h4/training/config.py 9 | dataset_mixer: 10 | ../data/llm_alignment/spider-p1-planner/dpo-llama-3-end2end-spider_train_dev_planner: 1.0 11 | 12 | dataset_splits: 13 | - train_dpo 14 | - test_dpo 15 | preprocessing_num_workers: 24 16 | 17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 18 | report_to: ["tensorboard"] 19 | 20 | # DPOTrainer arguments 21 | bf16: true 22 | beta: 1.0 23 | do_eval: true 24 | eval_strategy: "no" 25 | eval_steps: 100 26 | gradient_accumulation_steps: 4 27 | gradient_checkpointing: true 28 | gradient_checkpointing_kwargs: 29 | use_reentrant: False 30 | learning_rate: 8.0e-6 31 | log_level: info 32 | logging_steps: 10 33 | lr_scheduler_type: cosine 34 | max_length: 1600 35 | max_prompt_length: 1200 36 | num_train_epochs: 1 37 | max_steps: -1 38 | optim: adamw_torch 39 | output_dir: output/reproduce/orpo-3b-spider-planner 40 | per_device_train_batch_size: 2 41 | per_device_eval_batch_size: 2 42 | push_to_hub: false 43 | save_strategy: "epoch" 44 | save_steps: 100 45 | save_total_limit: 2 46 | seed: 42 47 | warmup_ratio: 0.05 48 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-spider/orpo-planner-iter-3.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: output/reproduce/orpo-3b-spider-planner-iter-2 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 6 | 7 | # Data training arguments 8 | # For definitions, see: src/h4/training/config.py 9 | dataset_mixer: 10 | ../data/llm_alignment/spider-p3-planner/dpo-llama-3-end2end-spider_train_planner: 1.0 11 | 12 | dataset_splits: 13 | - train_dpo 14 | - test_dpo 15 | preprocessing_num_workers: 24 16 | 17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 18 | report_to: ["tensorboard"] 19 | 20 | # DPOTrainer arguments 21 | bf16: true 22 | beta: 1.0 23 | do_eval: true 24 | eval_strategy: "no" 25 | eval_steps: 100 26 | gradient_accumulation_steps: 8 27 | gradient_checkpointing: true 28 | gradient_checkpointing_kwargs: 29 | use_reentrant: False 30 | learning_rate: 8.0e-6 31 | log_level: info 32 | logging_steps: 10 33 | lr_scheduler_type: cosine 34 | max_length: 1600 35 | max_prompt_length: 1200 36 | num_train_epochs: 1 37 | max_steps: -1 38 | optim: 
adamw_torch 39 | output_dir: output/reproduce/orpo-3b-spider-planner-iter-3 40 | per_device_train_batch_size: 2 41 | per_device_eval_batch_size: 2 42 | push_to_hub: false 43 | save_strategy: "epoch" 44 | save_steps: 100 45 | save_total_limit: 2 46 | seed: 42 47 | warmup_ratio: 0.05 48 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-1b-spider/validator-fft.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments: 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | use_flash_attention_2: true 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 7 | 8 | # Data training arguments 9 | dataset_mixer: 10 | ../data/multi-agents/validator/sft-validator_select_spider/: 1.0 11 | ../data/multi-agents/validator/sft-validator_join_spider/: 1.0 12 | ../data/multi-agents/validator/sft-validator_condition_spider/: 1.0 13 | dataset_splits: 14 | - train 15 | - test 16 | preprocessing_num_workers: 24 17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 18 | 19 | # SFT trainer config 20 | bf16: true 21 | do_eval: true 22 | evaluation_strategy: "no" 23 | gradient_accumulation_steps: 64 24 | gradient_checkpointing: true 25 | learning_rate: 2.0e-05 26 | log_level: info 27 | logging_steps: 5 28 | logging_strategy: steps 29 | lr_scheduler_type: cosine 30 | max_seq_length: 2048 31 | max_steps: -1 32 | num_train_epochs: 4 33 | output_dir: output/llama-1b-spider-validator-fft 34 | overwrite_output_dir: true 35 | per_device_eval_batch_size: 4 36 | per_device_train_batch_size: 2 37 | push_to_hub: false 38 | remove_unused_columns: true 39 | report_to: 40 | - tensorboard 41 | save_strategy: "epoch" 42 | save_steps: 100 43 | save_total_limit: 1 44 | seed: 42 45 | tf32: true -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-spider/orpo-validator.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: ./output/llama-3b-spider-validator-fft/ 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | 6 | # Data training arguments 7 | # For definitions, see: src/h4/training/config.py 8 | dataset_mixer: 9 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_select/: 1.0 10 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_condition/: 1.0 11 | 12 | dataset_splits: 13 | - train_dpo 14 | - test_dpo 15 | preprocessing_num_workers: 12 16 | 17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 18 | report_to: ["tensorboard"] 19 | 20 | # DPOTrainer arguments 21 | bf16: true 22 | beta: 0.25 23 | do_eval: true 24 | eval_strategy: "no" 25 | eval_steps: 100 26 | gradient_accumulation_steps: 16 27 | gradient_checkpointing: true 28 | gradient_checkpointing_kwargs: 29 | use_reentrant: False 30 | hub_model_id: th-dpo 31 | learning_rate: 8.0e-6 32 | log_level: info 33 | logging_steps: 10 34 | lr_scheduler_type: cosine 35 | max_length: 1600 36 | max_prompt_length: 1200 37 | num_train_epochs: 1 38 | max_steps: -1 39 | 
optim: adamw_torch 40 | output_dir: output/reproduce/orpo-llama-3-validator-spider 41 | per_device_train_batch_size: 1 42 | per_device_eval_batch_size: 1 43 | push_to_hub: false 44 | save_strategy: "steps" 45 | save_steps: 100 46 | save_total_limit: 1 47 | seed: 42 48 | warmup_ratio: 0.05 49 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-1b-spider/orpo-validator.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: ./output/llama-1b-spider-validator-fft/ 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | 6 | # Data training arguments 7 | # For definitions, see: src/h4/training/config.py 8 | dataset_mixer: 9 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_select/: 1.0 10 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_condition/: 1.0 11 | 12 | dataset_splits: 13 | - train_dpo 14 | - test_dpo 15 | preprocessing_num_workers: 12 16 | 17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 18 | report_to: ["tensorboard"] 19 | 20 | # DPOTrainer arguments 21 | bf16: true 22 | beta: 1.0 23 | do_eval: true 24 | eval_strategy: "no" 25 | eval_steps: 100 26 | gradient_accumulation_steps: 8 27 | gradient_checkpointing: true 28 | gradient_checkpointing_kwargs: 29 | use_reentrant: False 30 | hub_model_id: th-dpo 31 | learning_rate: 8.0e-6 32 | log_level: info 33 | logging_steps: 10 34 | lr_scheduler_type: cosine 35 | max_length: 1600 36 | max_prompt_length: 1200 37 | # num_train_epochs: 1 38 | max_steps: 500 39 | optim: adamw_torch 40 | output_dir: output/reproduce/orpo-llama-1b-validator-spider 41 | per_device_train_batch_size: 1 42 | per_device_eval_batch_size: 1 43 | push_to_hub: false 44 | save_strategy: "steps" 45 | save_steps: 100 46 | save_total_limit: 1 47 | seed: 42 48 | warmup_ratio: 0.05 49 | -------------------------------------------------------------------------------- /llm_alignment/merge_rl_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_from_disk 2 | import argparse 3 | import datasets 4 | import numpy as np 5 | from datasets import Dataset 6 | 7 | # add arguments data/llm_alignment/spider-p1 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--data_dir", type=str, help="path to the data directory") 10 | args = parser.parse_args() 11 | 12 | # read all train data from the data directory, including 13 | # dpo-llama-3-end2end-spider_train_fixed_sql 14 | # dpo-llama-3-end2end-spider_train_planner 15 | # dpo-llama-3-end2end-spider_train_validator_condition 16 | # dpo-llama-3-end2end-spider_train_validator_join 17 | # dpo-llama-3-end2end-spider_train_validator_select 18 | # dpo-llama-3-end2end-spider_train_validator_order 19 | 20 | import glob 21 | import os 22 | data_dirs = glob.glob(args.data_dir + "/*train*") 23 | data_dirs = [x for x in data_dirs if os.path.isdir(x)] 24 | print(data_dirs) 25 | 26 | for data_dir in data_dirs: 27 | dataset_train = load_from_disk(data_dir) 28 | # load dev data 29 | dev_file = data_dir.replace("train", "dev") 30 | if os.path.exists(dev_file): 31 | dataset_dev = load_from_disk(dev_file) 32 | dataset_dev = list(dataset_dev['train_dpo']) 33 | dataset_dev = 
np.random.permutation(dataset_dev)[:2000].tolist() 34 | dataset_train['test_dpo'] = Dataset.from_list(dataset_dev) 35 | 36 | print(data_dir) 37 | print(dataset_train) 38 | 39 | # save the merged data to other directory 40 | dataset_train.save_to_disk(data_dir.replace("train", "train_dev")) 41 | print(data_dir.replace("train", "train_dev")) 42 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-bird/orpo-validator-iter-2.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: ./output/reproduce/orpo-llama-3-validator-bird/ 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | 6 | # Data training arguments 7 | # For definitions, see: src/h4/training/config.py 8 | dataset_mixer: 9 | ../data/llm_alignment/bird/bird-p2-validator/dpo-llama-3-end2end-bird_train_validator_select/: 1.0 10 | ../data/llm_alignment/bird/bird-p2-validator/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0 11 | ../data/llm_alignment/bird/bird-p2-validator/dpo-llama-3-end2end-bird_train_validator_join/: 1.0 12 | 13 | dataset_splits: 14 | - train_dpo 15 | - test_dpo 16 | preprocessing_num_workers: 12 17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 18 | report_to: ["tensorboard"] 19 | 20 | # DPOTrainer arguments 21 | bf16: true 22 | beta: 0.25 23 | do_eval: true 24 | eval_strategy: "steps" 25 | eval_steps: 100 26 | gradient_accumulation_steps: 8 27 | gradient_checkpointing: true 28 | gradient_checkpointing_kwargs: 29 | use_reentrant: False 30 | learning_rate: 8.0e-6 31 | log_level: info 32 | logging_steps: 10 33 | lr_scheduler_type: inverse_sqrt 34 | max_length: 2600 35 | max_prompt_length: 2000 36 | num_train_epochs: -1 37 | max_steps: 500 38 | optim: adamw_torch 39 | output_dir: output/reproduce/orpo-iter-2-llama-3-validator-bird/ 40 | per_device_train_batch_size: 1 41 | per_device_eval_batch_size: 1 42 | push_to_hub: false 43 | save_strategy: "steps" 44 | save_steps: 100 45 | save_total_limit: 1 46 | seed: 42 47 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-spider/fft-validator.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments: 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | use_flash_attention_2: true 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 7 | 8 | # Data training arguments 9 | dataset_mixer: 10 | ../data/multi-agents/validator/sft-validator_select_spider/: 1.0 11 | ../data/multi-agents/validator/sft-validator_join_spider/: 1.0 12 | ../data/multi-agents/validator/sft-validator_condition_spider/: 1.0 13 | ../data/multi-agents/validator/sft-validator_order_spider/: 1.0 14 | dataset_splits: 15 | - train 16 | - test 17 | preprocessing_num_workers: 24 18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 19 | 20 | # SFT trainer config 21 | bf16: true 22 | do_eval: true 23 | evaluation_strategy: "no" 24 | gradient_accumulation_steps: 32 25 | gradient_checkpointing: true 26 | hub_model_id: 
griffith-bigdata/llama-3b-spider-validator-fft 27 | hub_strategy: every_save 28 | learning_rate: 2.0e-05 29 | log_level: info 30 | logging_steps: 5 31 | logging_strategy: steps 32 | lr_scheduler_type: cosine 33 | max_seq_length: 2048 34 | max_steps: -1 35 | num_train_epochs: 4 36 | output_dir: output/llama-3b-spider-validator-fft 37 | overwrite_output_dir: true 38 | per_device_eval_batch_size: 4 39 | per_device_train_batch_size: 2 40 | push_to_hub: false 41 | remove_unused_columns: true 42 | report_to: 43 | - tensorboard 44 | save_strategy: "epoch" 45 | save_steps: 100 46 | save_total_limit: 3 47 | seed: 42 48 | tf32: true -------------------------------------------------------------------------------- /visualization/sql_characteristic_spider.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | 5 | # Sample dataset (Ensure to replace with actual data) 6 | data = { 7 | "Method": ["MATS (Ours)", "DAILSQL(SC)", "CodeS-15B", "CodeS-7B", "REDSQL-3B\n+NatSQL", "REDSQL-3B", "Graphix\n+PICARD"], 8 | "w/o JOIN": [92.49, 89.1, 90.6, 89.6, 89.0, 90.1, 88.3], 9 | "w/ JOIN": [79.9, 75.0, 76.2, 78.9, 76.7, 69.1, 69.6], 10 | "w/o Subquery": [88.85, 84.2, 86.0, 86.7, 85.8, 83.2, 82.1], 11 | "w/ Subquery": [72.29, 63.6, 51.5, 45.5, 33.3, 39.4, 45.5], 12 | "w/o Logical\nConnector": [88.98, 85.3, 86.4, 87.0, 85.8, 83.9, 83.1], 13 | "w/ Logical\nConnector": [72.22, 65.6, 68.9, 68.9, 66.7, 60.0, 58.9], 14 | "w/o ORDER-BY": [88.08, 84.3, 85.1, 86.3, 83.6, 81.7, 80.9], 15 | "w/ ORDER-BY": [85.65, 81.0, 84.4, 82.3, 86.1, 82.3, 81.0], 16 | "Overall": [87.1, 83.6, 84.9, 85.4, 84.1, 81.8, 80.9] 17 | } 18 | 19 | # Convert to DataFrame 20 | df = pd.DataFrame(data) 21 | df.set_index("Method", inplace=True) 22 | 23 | # Transpose DataFrame to swap axes 24 | df = df.T 25 | 26 | # Remove duplicates by stripping subset names 27 | df.index = df.index.str.strip() 28 | df = df.loc[~df.index.duplicated(keep='first')] 29 | 30 | # Set up the figure size 31 | plt.figure(figsize=(4.5, 3.5)) 32 | 33 | # Create the heatmap 34 | sns.heatmap(df, annot=True, cmap="YlGnBu", linewidths=0.5, fmt=".1f", cbar=False) 35 | 36 | # Labels 37 | plt.xlabel("", fontsize=8) 38 | plt.ylabel("Subset", fontsize=8) 39 | 40 | # Rotate x-axis labels for better readability 41 | plt.xticks(rotation=90, ha="right", fontsize=6) 42 | plt.yticks(fontsize=6) 43 | 44 | # Show the plot 45 | plt.tight_layout() 46 | plt.show() 47 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-1b-bird/validator-fft.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments: 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | use_flash_attention_2: true 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 7 | 8 | # Data training arguments 9 | dataset_mixer: 10 | /home/datht/codes/data/multi-agents/validator/sft-validator_select_bird_with_evidence/: 1 11 | /home/datht/codes/data/multi-agents/validator/sft-validator_condition_bird_with_evidence/: 1 12 | /home/datht/codes/data/multi-agents/validator/sft-validator_join_bird_with_evidence/: 1 13 | # /home/datht/codes/data/multi-agents/validator/sft-validator_order_bird_with_evidence/: 1 14 | dataset_splits: 15 | - train 16 | - test 17 | preprocessing_num_workers: 24 18 | chat_template: 
"{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 19 | 20 | # SFT trainer config 21 | bf16: true 22 | do_eval: false 23 | evaluation_strategy: "no" 24 | gradient_accumulation_steps: 32 25 | gradient_checkpointing: true 26 | hub_model_id: griffith-bigdata/llama-3b-bird-validator-fft 27 | hub_strategy: every_save 28 | learning_rate: 2.0e-05 29 | log_level: info 30 | logging_steps: 5 31 | logging_strategy: steps 32 | lr_scheduler_type: cosine 33 | max_seq_length: 4096 34 | max_steps: -1 35 | num_train_epochs: 4 36 | output_dir: output/llama-1b-bird-validator-fft 37 | overwrite_output_dir: true 38 | per_device_eval_batch_size: 2 39 | per_device_train_batch_size: 4 40 | push_to_hub: false 41 | remove_unused_columns: true 42 | report_to: 43 | - tensorboard 44 | save_strategy: "epoch" 45 | save_total_limit: 1 46 | seed: 42 -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-bird/llama-3-bird-validator-fft.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments: 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | use_flash_attention_2: true 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 7 | 8 | # Data training arguments 9 | dataset_mixer: 10 | # /home/datht/codes/data/multi-agents/validator/sft-validator_select_bird_with_evidence/: 1 11 | /home/datht/codes/data/multi-agents/validator/sft-validator_condition_bird_with_evidence/: 1 12 | # /home/datht/codes/data/multi-agents/validator/sft-validator_join_bird_with_evidence/: 1 13 | # /home/datht/codes/data/multi-agents/validator/sft-validator_order_bird_with_evidence/: 1 14 | dataset_splits: 15 | - train 16 | - test 17 | preprocessing_num_workers: 24 18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 19 | 20 | # SFT trainer config 21 | bf16: true 22 | do_eval: true 23 | evaluation_strategy: "no" 24 | gradient_accumulation_steps: 32 25 | gradient_checkpointing: true 26 | hub_model_id: griffith-bigdata/llama-3b-bird-validator-fft 27 | hub_strategy: every_save 28 | learning_rate: 2.0e-05 29 | log_level: info 30 | logging_steps: 5 31 | logging_strategy: steps 32 | lr_scheduler_type: cosine 33 | max_seq_length: 4096 34 | max_steps: -1 35 | num_train_epochs: 4 36 | output_dir: output/llama-3b-bird-validator-fft-2 37 | overwrite_output_dir: true 38 | per_device_eval_batch_size: 2 39 | per_device_train_batch_size: 2 40 | push_to_hub: false 41 | remove_unused_columns: true 42 | report_to: 43 | - tensorboard 44 | save_strategy: "no" 45 | save_total_limit: 1 46 | seed: 42 47 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-1b-bird/llama-3-bird-validator-fft.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments: 2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | use_flash_attention_2: true 6 | response_template: "<|start_header_id|>assistant<|end_header_id|>" 7 | 8 | # Data training arguments 9 | dataset_mixer: 10 | 
/home/datht/codes/data/multi-agents/validator/sft-validator_select_bird_with_evidence/: 1 11 | /home/datht/codes/data/multi-agents/validator/sft-validator_condition_bird_with_evidence/: 1 12 | /home/datht/codes/data/multi-agents/validator/sft-validator_join_bird_with_evidence/: 1 13 | /home/datht/codes/data/multi-agents/validator/sft-validator_order_bird_with_evidence/: 1 14 | dataset_splits: 15 | - train 16 | - test 17 | preprocessing_num_workers: 24 18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 19 | 20 | # SFT trainer config 21 | bf16: true 22 | do_eval: true 23 | evaluation_strategy: "no" 24 | gradient_accumulation_steps: 32 25 | gradient_checkpointing: true 26 | hub_model_id: griffith-bigdata/llama-1b-bird-validator-fft 27 | hub_strategy: every_save 28 | learning_rate: 2.0e-05 29 | log_level: info 30 | logging_steps: 5 31 | logging_strategy: steps 32 | lr_scheduler_type: cosine 33 | max_seq_length: 8192 34 | max_steps: -1 35 | num_train_epochs: 4 36 | output_dir: output/llama-1b-bird-validator-fft 37 | overwrite_output_dir: true 38 | per_device_eval_batch_size: 2 39 | per_device_train_batch_size: 2 40 | push_to_hub: false 41 | remove_unused_columns: true 42 | report_to: 43 | - tensorboard 44 | save_strategy: "no" 45 | save_total_limit: 1 46 | seed: 42 47 | tf32: true 48 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-bird/orpo-llama-3-validator.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | # model_name_or_path: ./output/llama-3b-bird-validator-fft-1epoch/ 3 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct 4 | torch_dtype: bfloat16 5 | use_flash_attention_2: true 6 | 7 | # Data training arguments 8 | # For definitions, see: src/h4/training/config.py 9 | dataset_mixer: 10 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_select/: 1.0 11 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0 12 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_join/: 1.0 13 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_order/: 1.0 14 | 15 | 16 | dataset_splits: 17 | - train_dpo 18 | - test_dpo 19 | preprocessing_num_workers: 12 20 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 21 | report_to: ["tensorboard"] 22 | 23 | # DPOTrainer arguments 24 | bf16: true 25 | beta: 0.25 26 | do_eval: true 27 | eval_strategy: "no" 28 | eval_steps: 40 29 | gradient_accumulation_steps: 32 30 | gradient_checkpointing: true 31 | gradient_checkpointing_kwargs: 32 | use_reentrant: False 33 | hub_model_id: th-dpo 34 | learning_rate: 5.0e-6 35 | log_level: info 36 | logging_steps: 10 37 | lr_scheduler_type: cosine 38 | max_length: 2600 39 | max_prompt_length: 2000 40 | num_train_epochs: 3 41 | optim: adamw_torch 42 | output_dir: output/orpo-llama-3-validator-bird 43 | per_device_train_batch_size: 1 44 | per_device_eval_batch_size: 1 45 | push_to_hub: false 46 | save_strategy: "steps" 47 | save_steps: 40 48 | save_total_limit: 1 49 | seed: 42 50 | warmup_ratio: 0.1 51 | 
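Every recipe above wires the same single-turn `chat_template` string into the trainer. As a minimal sketch of what that Jinja expression renders to (assuming `jinja2` is installed; the sample prompt/completion strings below are made up for illustration):

```python
from jinja2 import Template

# chat_template copied from the recipes above: it wraps one prompt/completion
# pair in Llama-3 role headers and end-of-turn tokens.
CHAT_TEMPLATE = (
    "{{'<|start_header_id|>user<|end_header_id|>\\n' + messages['prompt'] + '<|eot_id|>\\n'}}"
    "{{'<|start_header_id|>assistant<|end_header_id|>\\n' + messages['completion'] + '<|eot_id|>\\n'}}"
)

# Hypothetical sample pair, for illustration only.
rendered = Template(CHAT_TEMPLATE).render(
    messages={
        "prompt": "Question: How many heads of the departments are older than 56?",
        "completion": "SELECT count(*) FROM head WHERE age > 56",
    }
)
print(rendered)
```

The `response_template` key in the SFT recipes ("<|start_header_id|>assistant<|end_header_id|>") points at the assistant header inside this rendered string, presumably so the data collator can mask out everything before the assistant turn when computing the loss.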
-------------------------------------------------------------------------------- /visualization/parameter_sensitivity_num_candidates.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | # Data for plotting 5 | candidates = np.array([1, 5, 10, 20, 30]) 6 | achieved_accuracy = np.array([59.17, 63.75, 64.73, 62.91, 62.78]) 7 | upper_bound = np.array([59.17, 69.6, 72.7, 76, 77.8]) 8 | lower_bound = np.array([59.17, 59.17, 59.17, 59.17, 59.17]) 9 | 10 | # Define figure size to fit a research paper column 11 | fig_width = 4 # Adjusted for single-column fit 12 | fig_height = fig_width * 0.75 # Maintain aspect ratio 13 | 14 | # Create the figure 15 | plt.figure(figsize=(fig_width, fig_height), dpi=300) # High DPI for publication quality 16 | 17 | # Plot the curves 18 | plt.plot(candidates, upper_bound, label="Upper Bound", linestyle="--", color="red") 19 | plt.plot(candidates, achieved_accuracy, label="MATS", marker="o", color="blue") 20 | plt.plot(candidates, lower_bound, label="Lower Bound", linestyle="--", color="green") 21 | 22 | # Fill between (shading) without adding legend 23 | plt.fill_between(candidates, achieved_accuracy, upper_bound, color="gray", alpha=0.3) 24 | plt.fill_between(candidates, lower_bound, achieved_accuracy, color="gray", alpha=0.3) 25 | 26 | # Labels and formatting 27 | plt.xlabel("Number of Candidates", fontsize=10) 28 | plt.ylabel(r"EX\%", fontsize=10) # LaTeX-style notation for EX% 29 | plt.xticks(candidates, fontsize=9) 30 | plt.yticks(fontsize=9) 31 | plt.legend(fontsize=9, loc="upper left") # Move legend to top left 32 | plt.grid(True, linewidth=0.5) 33 | 34 | # Remove unnecessary borders for a cleaner publication look 35 | plt.gca().spines["top"].set_visible(False) 36 | plt.gca().spines["right"].set_visible(False) 37 | 38 | # Save the figure in a high-quality format for LaTeX insertion 39 | plt.savefig("accuracy_vs_bounds.pdf", bbox_inches="tight", format="pdf") 40 | 41 | # Show the figure 42 | plt.show() 43 | -------------------------------------------------------------------------------- /visualization/sql_characteristic_bird.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | 5 | # Define the dataset with all extracted EX (%) values including MATS (Ours) 6 | data = { 7 | "Subset": [ 8 | "w/o JOIN", "w/ JOIN", "w/o Subquery", "w/ Subquery", 9 | "w/o Logical\nConnector", "w/ Logical\nConnector", 10 | "w/o ORDER-BY", "w/ ORDER-BY", "Overall" 11 | ], 12 | "MATS (Ours)": [68.02, 59.21, 63.12, 40.71, 65.33, 56.04, 63.67, 52.75, 64.74], 13 | "DAILSQL(SC)": [61.4, 53.9, 56.9, 37.9, 59.1, 51.3, 58.3, 46.3, 55.9], 14 | # "DAILSQL": [60.4, 52.2, 55.4, 35.6, 56.6, 51.0, 57.1, 43.0, 54.3], 15 | # "C3SQL": [55.3, 48.4, 51.2, 33.3, 54.5, 44.1, 53.5, 37.2, 50.2], 16 | "CodeS-15B": [63.5, 56.8, 59.8, 36.8, 62.2, 53.2, 61.1, 48.2, 58.5], 17 | "CodeS-7B": [63.2, 54.8, 58.5, 32.2, 60.6, 51.8, 59.6, 46.6, 57.0], 18 | # "SFT CodeS-3B": [61.2, 52.7, 56.3, 31.0, 59.8, 48.0, 57.9, 43.0, 54.9], 19 | # "SFT CodeS-1B": [57.1, 47.9, 51.6, 28.7, 55.3, 43.2, 53.4, 37.9, 50.3], 20 | "REDSQL-3B": [52.0, 41.1, 44.7, 31.0, 49.4, 36.3, 47.2, 31.1, 43.9], 21 | "REDSQL-L Large": [45.9, 36.1, 39.6, 21.8, 45.7, 28.6, 41.9, 25.6, 38.6], 22 | "REDSQL-L Base": [40.9, 30.4, 33.9, 19.5, 38.7, 25.3, 35.5, 23.6, 33.1], 23 | 24 | } 25 | 26 | # Convert to DataFrame and set index 27 | df = 
pd.DataFrame(data) 28 | df.set_index("Subset", inplace=True) 29 | 30 | # Create a figure with a single heatmap 31 | fig = plt.figure(figsize=(4.5, 3.5)) 32 | 33 | # Create the heatmap 34 | sns.heatmap(df, annot=True, cmap="YlGnBu", linewidths=0.5, fmt=".1f", cbar=False) 35 | 36 | plt.xlabel("") 37 | plt.ylabel("Subset", fontsize=8) 38 | plt.xticks(rotation=90, ha="right", fontsize=6) 39 | plt.yticks(fontsize=6) 40 | 41 | # Adjust layout 42 | plt.tight_layout() 43 | 44 | # Show the plot 45 | plt.show() 46 | -------------------------------------------------------------------------------- /alignment-handbook/tests/test_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import unittest 17 | 18 | from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig 19 | 20 | 21 | class H4ArgumentParserTest(unittest.TestCase): 22 | def setUp(self): 23 | self.parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) 24 | self.yaml_file_path = "tests/fixtures/config_sft_full.yaml" 25 | 26 | def test_load_yaml(self): 27 | model_args, data_args, training_args = self.parser.parse_yaml_file(os.path.abspath(self.yaml_file_path)) 28 | self.assertEqual(model_args.model_name_or_path, "mistralai/Mistral-7B-v0.1") 29 | 30 | def test_load_yaml_and_args(self): 31 | command_line_args = [ 32 | "--model_name_or_path=test", 33 | "--use_peft=true", 34 | "--lora_r=16", 35 | "--lora_dropout=0.5", 36 | ] 37 | model_args, data_args, training_args = self.parser.parse_yaml_and_args( 38 | os.path.abspath(self.yaml_file_path), command_line_args 39 | ) 40 | self.assertEqual(model_args.model_name_or_path, "test") 41 | self.assertEqual(model_args.use_peft, True) 42 | self.assertEqual(model_args.lora_r, 16) 43 | self.assertEqual(model_args.lora_dropout, 0.5) 44 | -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-3b-bird/orpo-validator.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: ./output/llama-3b-bird-validator-fft 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | 6 | # Data training arguments 7 | # For definitions, see: src/h4/training/config.py 8 | dataset_mixer: 9 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_select/: 1.0 10 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_condition/: 1.0 11 | # ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_join/: 1.0 12 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_select/: 1.0 13 |
../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0 14 | # ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_join/: 1.0 15 | 16 | dataset_splits: 17 | - train_dpo 18 | - test_dpo 19 | preprocessing_num_workers: 12 20 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 21 | report_to: ["tensorboard"] 22 | 23 | # DPOTrainer arguments 24 | bf16: true 25 | beta: 1.0 26 | do_eval: true 27 | eval_strategy: "steps" 28 | eval_steps: 100 29 | gradient_accumulation_steps: 8 30 | gradient_checkpointing: true 31 | gradient_checkpointing_kwargs: 32 | use_reentrant: False 33 | learning_rate: 8.0e-6 34 | log_level: info 35 | logging_steps: 10 36 | lr_scheduler_type: inverse_sqrt 37 | max_length: 2600 38 | max_prompt_length: 2000 39 | num_train_epochs: 1 40 | max_steps: -1 41 | optim: adamw_torch 42 | output_dir: output/reproduce/orpo-llama-3-validator-bird/ 43 | per_device_train_batch_size: 1 44 | per_device_eval_batch_size: 1 45 | push_to_hub: false 46 | save_strategy: "steps" 47 | save_steps: 100 48 | save_total_limit: 1 49 | seed: 42 50 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /alignment-handbook/recipes/llama-1b-bird/orpo-validator.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: ./output/llama-1b-bird-validator-fft 3 | torch_dtype: bfloat16 4 | use_flash_attention_2: true 5 | 6 | # Data training arguments 7 | # For definitions, see: src/h4/training/config.py 8 | dataset_mixer: 9 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_select/: 1.0 10 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_condition/: 1.0 11 | # ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_join/: 1.0 12 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_select/: 1.0 13 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0 14 | # ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_join/: 1.0 15 | 16 | dataset_splits: 17 | - train_dpo 18 | - test_dpo 19 | preprocessing_num_workers: 12 20 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}" 21 | report_to: ["tensorboard"] 22 | 23 | # DPOTrainer arguments 24 | bf16: true 25 | beta: 1.0 26 | do_eval: true 27 | eval_strategy: "steps" 28 | eval_steps: 100 29 | gradient_accumulation_steps: 8 30 | gradient_checkpointing: true 31 | gradient_checkpointing_kwargs: 32 | use_reentrant: False 33 | learning_rate: 8.0e-6 34 | log_level: info 35 | logging_steps: 10 36 | lr_scheduler_type: inverse_sqrt 37 | max_length: 2600 38 | max_prompt_length: 2000 39 | num_train_epochs: -1 40 | max_steps: 600 41 | optim: adamw_torch 42 | output_dir: output/reproduce/orpo-llama-1b-validator-bird/ 43 | per_device_train_batch_size: 1 44 | per_device_eval_batch_size: 1 45 | push_to_hub: false 46 | save_strategy: "steps" 47 | save_steps: 100 48 | save_total_limit: 1 49 | seed: 42 50 | warmup_ratio: 0.1 
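One number worth sanity-checking when editing the recipes above is the effective global batch size: per_device_train_batch_size × gradient_accumulation_steps × number of processes. A small sketch of the arithmetic (the 4-GPU count is an assumption; the files under recipes/accelerate_configs/ govern the real process count):

```python
def effective_batch_size(per_device_bs: int, grad_accum_steps: int, num_gpus: int) -> int:
    # Examples consumed per optimizer step across all devices.
    return per_device_bs * grad_accum_steps * num_gpus

# Values from the llama-1b-bird orpo-validator.yaml recipe above,
# assuming a hypothetical 4-GPU node:
print(effective_batch_size(per_device_bs=1, grad_accum_steps=8, num_gpus=4))  # -> 32
```

This is why recipes that drop per_device_train_batch_size to 1 (to fit longer max_length sequences in memory) raise gradient_accumulation_steps in tandem.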
-------------------------------------------------------------------------------- /scripts/evaluate_bird.sh: -------------------------------------------------------------------------------- 1 | 2 | # CUDA_VISIBLE_DEVICES=0 vllm serve orpo-llama-3b-iter-2-bird-planner-no-filter/ --host 0.0.0.0 --port 8003 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name planner --gpu-memory-utilization 0.9 --enable-prefix-caching 3 | 4 | # CUDA_VISIBLE_DEVICES=1 vllm serve llama-1b-bird-validator-fft/ --host 0.0.0.0 --port 8004 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name validator --gpu-memory-utilization 0.5 --enable-prefix-caching 5 | 6 | # CUDA_VISIBLE_DEVICES=1 vllm serve llama-1b-bird-fixed-fft/ --host 0.0.0.0 --port 8005 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name fixed --gpu-memory-utilization 0.4 --enable-prefix-caching 7 | 8 | 9 | mkdir -p logs 10 | mkdir -p logs/bird-dev 11 | # Evaluate 12 | n=10 13 | temperature=1.0 14 | seed=100 15 | eval_file=data/evaluate/orpo-llama-3-iter-2-planner-bird_dev-greedy-and-sampling-seed$seed.jsonl 16 | rm -f $eval_file 17 | PYTHONPATH=. python evaluate_end2end.py \ 18 | --input_file data/full_value_matching_schema_insight_bird_062024_with_evidence_dev_text2sql.json \ 19 | --output_file $eval_file \ 20 | --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 --only_planner --seed $seed --n_processes 32 21 | 22 | python jsonl2json.py --jsonl-file data/evaluate/orpo-llama-3-iter-2-planner-bird_dev-greedy-and-sampling-seed$seed.jsonl 23 | 24 | # Evaluate Selection 25 | eval_file=data/evaluate/orpo-llama-3-iter-2-selection-bird_dev-greedy-and-sampling-seed$seed.jsonl 26 | rm -f $eval_file 27 | PYTHONPATH=.
python evaluate_end2end.py \ 28 | --input_file data/evaluate/orpo-llama-3-iter-2-planner-bird_dev-greedy-and-sampling-seed$seed.json \ 29 | --output_file $eval_file \ 30 | --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 \ 31 | --skip_planner --seed $seed --n_processes 8 --skip_validator_join 32 | 33 | python compute_acc.py --pred_file $eval_file 34 | 35 | -------------------------------------------------------------------------------- /utils/classifier_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # CrossEntropyLoss = softmax + log + NLLLoss 6 | 7 | class FocalLoss(nn.Module): 8 | def __init__(self, weight=None, gamma=0.5, reduction='mean'): 9 | super(FocalLoss, self).__init__() 10 | 11 | self.weight = weight 12 | self.gamma = gamma 13 | self.reduction = reduction 14 | 15 | def forward(self, input_tensor, target_tensor): 16 | assert input_tensor.shape[0] == target_tensor.shape[0] 17 | 18 | prob = F.softmax(input_tensor, dim = -1) 19 | log_prob = torch.log(prob + 1e-8) 20 | 21 | loss = F.nll_loss( 22 | ((1 - prob) ** self.gamma) * log_prob, 23 | target_tensor, 24 | weight=self.weight, 25 | reduction=self.reduction 26 | ) 27 | 28 | return loss 29 | 30 | class ClassifierLoss(): 31 | def __init__(self, alpha, gamma): 32 | weight = torch.FloatTensor([1-alpha, alpha]) 33 | if torch.cuda.is_available(): 34 | weight = weight.cuda() 35 | 36 | self.focal_loss = FocalLoss( 37 | weight = weight, 38 | gamma = gamma, 39 | reduction = 'mean' 40 | ) 41 | 42 | # self.ce_loss = nn.CrossEntropyLoss(weight = weight, reduction = "mean") 43 | 44 | def compute_batch_loss(self, batch_logits, batch_labels, batch_size): 45 | loss = 0 46 | for logits, labels in zip(batch_logits, batch_labels): 47 | loss += self.focal_loss(logits, labels) 48 | 49 | return loss/batch_size 50 | 51 | def compute_loss( 52 | self, 53 | batch_table_name_cls_logits, 54 | batch_table_labels, 55 | batch_column_info_cls_logits, 56 | batch_column_labels 57 | ): 58 | batch_size = len(batch_table_labels) 59 | 60 | table_loss = self.compute_batch_loss(batch_table_name_cls_logits, batch_table_labels, batch_size) 61 | column_loss = self.compute_batch_loss(batch_column_info_cls_logits, batch_column_labels, batch_size) 62 | 63 | return table_loss + column_loss -------------------------------------------------------------------------------- /data_processing/generate_validator_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sqlite3 3 | import multiprocessing.pool 4 | import functools 5 | from tqdm import tqdm 6 | import pandas as pd 7 | from validator import ValidatorJOIN, _execute_sql, _make_str_response, is_execution_correct 8 | import argparse 9 | import os 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json') 13 | parser.add_argument('--output_file', type=str, default='bird_validator_join.jsonl') 14 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai']) 15 | args = parser.parse_args() 16 | 17 | data = json.load(open(args.input_file)) 18 | 19 | if os.path.exists(args.output_file): 20 | old_output = [json.loads(line) for line in open(args.output_file)]  # the output file is JSONL, so read it back line by line 21 | data[:len(old_output)] = old_output 22 | else: 23 | old_output = [] 24 | 25 | # open jsonl file for append contents 26 |
output_file = open(args.output_file, 'a+') 27 | 28 | validator = ValidatorJOIN(endpoint_type=args.endpoint_type) 29 | 30 | for isample in tqdm(range(len(old_output), len(data)), total=len(data) - len(old_output)):  # resume after any samples already written 31 | sample = data[isample] 32 | 33 | true_execution_result = _execute_sql("../" + sample['db_path'], sample['sql']) 34 | 35 | sql = sample['predict_sql'] 36 | 37 | answer, execution_result = validator.validate(sample) 38 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 39 | 40 | print("-"*20) 41 | print("Is correct: ", is_correct) 42 | print(answer) 43 | 44 | sample['is_correct'] = is_correct 45 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer 46 | sample['validator_join'] = answer 47 | 48 | sample['true_result'] = _make_str_response(*true_execution_result) 49 | sample['pred_result'] = _make_str_response(*execution_result) 50 | 51 | del sample['table_labels'] 52 | del sample['column_labels'] 53 | del sample['schema'] 54 | del sample['matched_contents'] 55 | 56 | # json.dump(data[:isample+1], open(args.output_file, 'w+'), ensure_ascii=False, indent=4) 57 | # write new sample in jsonl file 58 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 59 | -------------------------------------------------------------------------------- /scripts/evaluate_dr_spider.sh: -------------------------------------------------------------------------------- 1 | # n=9 2 | # temperature=1.0 3 | # seed=100 4 | # mkdir logs/ 5 | # mkdir logs/dr-spider 6 | # for test_set in DB_schema_synonym DB_DBcontent_equivalence DB_schema_abbreviation NLQ_column_value SQL_DB_text SQL_DB_number SQL_comparison NLQ_column_carrier NLQ_multitype NLQ_others NLQ_keyword_synonym SQL_NonDB_number NLQ_column_synonym NLQ_keyword_carrier NLQ_value_synonym NLQ_column_attribute SQL_sort_order 7 | # do 8 | # eval_file=data/evaluate/orpo-llama-3-iter-3-planner-dr_spider_${test_set}.jsonl 9 | # # rm $eval_file 10 | # PYTHONPATH=. python evaluate_end2end.py \ 11 | # --input_file data/schema_insight_dr_spider_text2sql_${test_set}.json \ 12 | # --output_file $eval_file \ 13 | # --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 --only_planner --seed $seed --n_processes 32 14 | 15 | # rm progress_dr.pkl 16 | # python -u check_correct_recall.py --pred_file $eval_file --progress_file progress_dr.pkl > logs/dr-spider/log-orpo-planner-dr_spider-$test_set-n$n-temp$temperature-seed$seed.txt 17 | # done 18 | 19 | n=9 20 | temperature=1.0 21 | seed=100 22 | mkdir -p logs/ 23 | mkdir -p logs/dr-spider 24 | for test_set in SQL_comparison NLQ_column_carrier NLQ_multitype NLQ_others NLQ_keyword_synonym SQL_NonDB_number NLQ_column_synonym NLQ_keyword_carrier NLQ_value_synonym NLQ_column_attribute SQL_sort_order 25 | do 26 | python jsonl2json.py --jsonl-file data/evaluate/orpo-llama-3-iter-3-planner-dr_spider_${test_set}.jsonl 27 | 28 | eval_file=data/evaluate/orpo-llama-3-iter-3-end2end-dr_spider_${test_set}.jsonl 29 | # rm $eval_file 30 | PYTHONPATH=.
python evaluate_end2end.py \ 31 | --input_file data/evaluate/orpo-llama-3-iter-3-planner-dr_spider_${test_set}.json \ 32 | --output_file $eval_file \ 33 | --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 --skip_planner --seed $seed --n_processes 32 --skip_validator_join --skip_selection 34 | 35 | rm progress_dr.pkl 36 | python -u check_correct_recall.py --pred_file $eval_file --progress_file progress_dr.pkl > logs/dr-spider/log-orpo-llama-3-iter-3-end2end-dr_spider-$test_set-n$n-temp$temperature-seed$seed.txt 37 | mv results_spider.pkl logs/results-dr-spider-end2end-$test_set.pkl 38 | done 39 | 40 | -------------------------------------------------------------------------------- /visualization/domain_knowledge.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | 5 | # Sample dataset (Ensure to replace with actual data) 6 | data = { 7 | "Method": ["MATS (Ours)", "DAILSQL(SC)", "CodeS-15B", "CodeS-7B", "REDSQL-3B\n+NatSQL", "REDSQL-3B", "Graphix\n+PICARD"], 8 | "College": [84.0, 79.6, 82.4, 83.3, 80.6, 83.3, 78.7], 9 | "Competition": [92.0, 79.0, 85.5, 82.3, 80.6, 83.9, 82.3], 10 | "Geography": [71.0, 76.7, 75.0, 75.8, 52.5, 65.0, 64.2], 11 | "Social": [95.0, 83.9, 83.9, 82.1, 76.8, 80.4, 82.1], 12 | "Transportation": [97.0, 85.0, 88.8, 87.5, 86.3, 80.0, 98.8], 13 | "Overall": [87.1, 83.6, 84.9, 85.4, 84.1, 81.8, 80.9] 14 | } 15 | 16 | # DB Count Data (Ensure to replace with actual data) 17 | db_count = {"College": 10, "Competition": 5, "Geography": 3, "Social": 2, "Transportation": 12} 18 | 19 | # Convert to DataFrame 20 | df = pd.DataFrame(data) 21 | df.set_index("Method", inplace=True) 22 | 23 | # Transpose DataFrame to swap axes 24 | df = df.T 25 | 26 | # Create a figure with two side-by-side subplots, adjusting colors and layout 27 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 3.5), gridspec_kw={'width_ratios': [4, 0.8]}) 28 | 29 | # Create the heatmap in the first subplot with a colormap similar to the reference image 30 | sns.heatmap(df, annot=True, cmap="YlGnBu", linewidths=0.5, fmt=".1f", cbar=False, ax=axes[0]) 31 | axes[0].set_xlabel("", fontsize=8) 32 | axes[0].set_ylabel("DB Domain", fontsize=8) 33 | axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=90, ha="right", fontsize=6) 34 | axes[0].set_yticklabels(axes[0].get_yticklabels(), fontsize=6) 35 | 36 | # Create the DB count bar plot in the second subplot 37 | domains = list(db_count.keys()) 38 | db_values = list(db_count.values()) 39 | axes[1].barh(domains, db_values, color="#1f77b4", alpha=0.8) # Adjusted to a similar blue tone 40 | axes[1].set_xlabel("#DB Count", fontsize=6) # Reduce x-axis title size 41 | axes[1].set_yticklabels([]) # Remove y-axis ticks 42 | axes[1].set_xticks(range(0, max(db_values) + 1, max(2, max(db_values) // 4))) # Keep sparse x-axis ticks 43 | axes[1].tick_params(axis='x', labelsize=6) # Reduce x-axis tick label size 44 | 45 | # Adjust layout for better fitting 46 | plt.tight_layout() 47 | 48 | # Show the plot 49 | plt.show() -------------------------------------------------------------------------------- /visualization/rlef_improvement.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # Data for Spider Dev and BIRD Dev 4 | steps = [0, 1, 2, 3] # RLEF iterations 5 | labels = ['SFT', 'Iter 1', 'Iter 2', 'Iter 3'] 6 | 7 | # 
MATS 3B Planner Data 8 | spider_3b = [83.3, 84, 85.4, 85.2] 9 | bird_3b = [53.65, 56.32, 59.32, 58.6] 10 | 11 | # MATS Data 12 | spider_mats = [85.5, 86.3, 87.1, 87] 13 | bird_mats = [59.06, 60.82, 64.73, 62.58] 14 | 15 | # Reference model performance for horizontal lines 16 | ref_lines = { 17 | "GPT-4": {"color": "cyan", "style": "-.", "spider": 76.8, "bird": 49.15}, 18 | "CodeS-15B": {"color": "red", "style": ":", "spider": 84.9, "bird": 58.47}, 19 | "MAC-SQL + GPT-4": {"color": "purple", "style": "--", "spider": 86.75, "bird": 59.59}, 20 | "DIN-SQL + GPT-4": {"color": "green", "style": "--", "spider": 82.8, "bird": 50.72} 21 | } 22 | 23 | # Adjust figure size for a single-column layout 24 | fig, axes = plt.subplots(2, 1, figsize=(3.5, 5), sharex=True) 25 | 26 | # Spider Dev subplot 27 | axes[0].plot(steps, spider_3b, marker='o', color='blue', label='MATS 3B Planner', linewidth=1.5) 28 | axes[0].plot(steps, spider_mats, marker='s', color='red', label='MATS', linewidth=1.5) 29 | for label, data in ref_lines.items(): 30 | axes[0].axhline(y=data["spider"], color=data["color"], linestyle=data["style"], label=label, linewidth=1) 31 | axes[0].set_title('Spider Dev', fontsize=8, pad=8) 32 | axes[0].grid(alpha=0.3) 33 | 34 | # BIRD Dev subplot 35 | axes[1].plot(steps, bird_3b, marker='o', color='blue', linewidth=1.5) 36 | axes[1].plot(steps, bird_mats, marker='s', color='red', linewidth=1.5) 37 | for label, data in ref_lines.items(): 38 | axes[1].axhline(y=data["bird"], color=data["color"], linestyle=data["style"], linewidth=1) 39 | axes[1].set_title('BIRD Dev', fontsize=8, pad=8) 40 | axes[1].set_xticks(steps) 41 | axes[1].set_xticklabels(labels, rotation=45, fontsize=8) 42 | axes[1].grid(alpha=0.3) 43 | 44 | # Remove axis titles (labels) 45 | axes[0].set_ylabel('') 46 | axes[1].set_ylabel('') 47 | axes[1].set_xlabel('') 48 | 49 | # Add a single legend outside the subplots 50 | handles, labels = axes[0].get_legend_handles_labels() 51 | fig.legend(handles, labels, loc='lower center', fontsize=7, ncol=2, frameon=False, bbox_to_anchor=(0.5, -0.09)) 52 | 53 | # Adjust layout for compactness 54 | plt.tight_layout(pad=1.0) 55 | plt.show() 56 | -------------------------------------------------------------------------------- /validator_data/generate_validator_order_using_fewshot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | from validator_data.validator import ValidatorOrder, _execute_sql, _make_str_response, is_execution_correct 5 | import re 6 | import argparse 7 | 8 | # add parse for input data file (train, dev) and output_file 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json') 11 | parser.add_argument('--output_file', type=str, default='bird_validator_select.jsonl') 12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai']) 13 | args = parser.parse_args() 14 | 15 | data = [] 16 | with open(args.input_file) as fp: 17 | for line in fp: 18 | data.append(json.loads(line)) 19 | 20 | # load saved output file if exists 21 | if os.path.exists(args.output_file): 22 | old_output = [] 23 | with open(args.output_file, 'r') as f: 24 | for line in f: 25 | old_output.append(json.loads(line)) 26 | data[:len(old_output)] = old_output 27 | else: 28 | old_output = [] 29 | 30 | # open jsonl file for append contents 31 | output_file = open(args.output_file, 'a+') 32 | 33 | validator = ValidatorOrder(endpoint_type=args.endpoint_type) 34 | 35 | for isample in tqdm(range(len(old_output), len(data)), total=len(data) - len(old_output)): 36 | sample = data[isample] 37 | 38 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql']) 39 | 40 | pred_sql_match = re.search(r"(?<=Final SQL query:).*", sample['planner'], re.DOTALL) 41 | if pred_sql_match is None: continue 42 | 43 | pred_sql = pred_sql_match.group().replace("sql", "").replace("```", "").strip() 44 | sample['predict_sql'] = pred_sql 45 | 46 | answer, execution_result = validator.validate(sample) 47 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 48 | 49 | print("-"*20) 50 | print("Is correct: ", is_correct) 51 | print(answer) 52 | 53 | sample['is_correct'] = is_correct 54 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer 55 | sample['validator_order'] = answer 56 | 57 |
sample['true_result'] = _make_str_response(*true_execution_result) 58 | sample['pred_result'] = _make_str_response(*execution_result) 59 | 60 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 61 | -------------------------------------------------------------------------------- /bird_evaluation/run_evaluation.sh: -------------------------------------------------------------------------------- 1 | db_root_path='data/bird-062024/dev/dev_databases/' 2 | data_mode='dev' 3 | diff_json_path='data/bird-062024/dev/dev.json' 4 | predicted_sql_path=$1 5 | ground_truth_path='data/bird-062024/dev/' 6 | num_cpus=16 7 | meta_time_out=30.0 8 | mode_gt='gt' 9 | mode_predict='gpt' 10 | 11 | # db_root_path='./data/sft_data_collections/bird/dev/dev_databases/' 12 | # data_mode='dev' 13 | # diff_json_path='./data/sft_data_collections/bird/dev/dev.json' 14 | # # predicted_sql_path=$1 15 | # ground_truth_path='./data/sft_data_collections/bird/dev/' 16 | # num_cpus=16 17 | # meta_time_out=30.0 18 | # mode_gt='gt' 19 | # mode_predict='gpt' 20 | 21 | echo '''starting to compare with knowledge for ex''' 22 | python3 -u ./bird_evaluation/evaluation.py --db_root_path ${db_root_path} --predicted_sql_path $1 --data_mode ${data_mode} \ 23 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \ 24 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} 25 | 26 | echo '''starting to compare with knowledge for ves''' 27 | python3 -u ./bird_evaluation/evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path $1 --data_mode ${data_mode} --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} 28 | 29 | 30 | 31 | # db_root_path='./data/dev/dev_databases/' 32 | # data_mode='dev' 33 | # diff_json_path='./data/dev/dev.json' 34 | # predicted_sql_path=$1 35 | # ground_truth_path='./data/dev/' 36 | # num_cpus=16 37 | # meta_time_out=30.0 38 | # mode_gt='gt' 39 | # mode_predict='gpt' 40 | 41 | # echo '''starting to compare with knowledge for ex''' 42 | # python3 -u ./bird_evaluation/evaluation_062024.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \ 43 | # --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \ 44 | # --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} 45 | 46 | # echo '''starting to compare with knowledge for ves''' 47 | # python3 -u ./bird_evaluation/evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path $1 --data_mode ${data_mode} --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} 48 | -------------------------------------------------------------------------------- /utils/load_classifier_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import itertools 3 | from torch.utils.data import Dataset 4 | 5 | class SchemaItemClassifierDataset(Dataset): 6 | def __init__(self, dataset_dir): 7 | super(SchemaItemClassifierDataset, self).__init__() 8 | 9 | self.texts: list[str] = [] 10 | self.all_column_names: list[list[list[str]]] = [] 11 | self.all_column_labels: list[list[list[int]]] = [] 12 | self.all_table_names: list[list[str]] = [] 13 | 
self.all_table_labels: list[list[int]] = [] 14 | 15 | dataset = json.load(open(dataset_dir)) 16 | 17 | assert type(dataset) == list 18 | 19 | for data in dataset: 20 | table_names_in_one_db = [] 21 | column_names_in_one_db = [] 22 | 23 | for table in data["schema"]["schema_items"]: 24 | # table_names_in_one_db.append(table["table_name"]) 25 | # column_names_in_one_db.append(table["column_names"]) 26 | table_names_in_one_db.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ 27 | if table["table_comment"] != "" else table["table_name"]) 28 | column_names_in_one_db.append([column_name + " ( " + column_comment + " ) " \ 29 | if column_comment != "" else column_name \ 30 | for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) 31 | 32 | self.texts.append(data["text"]) 33 | self.all_table_names.append(table_names_in_one_db) 34 | self.all_column_names.append(column_names_in_one_db) 35 | self.all_table_labels.append(data["table_labels"]) 36 | self.all_column_labels.append(list(itertools.chain(*data["column_labels"]))) 37 | 38 | def __len__(self): 39 | return len(self.texts) 40 | 41 | def __getitem__(self, index): 42 | text = self.texts[index] 43 | table_names_in_one_db = self.all_table_names[index] 44 | table_labels_in_one_db = self.all_table_labels[index] 45 | column_infos_in_one_db = self.all_column_names[index] 46 | column_labels_in_one_db = self.all_column_labels[index] 47 | 48 | return { 49 | "text": text, 50 | "table_names_in_one_db": table_names_in_one_db, 51 | "table_labels_in_one_db": table_labels_in_one_db, 52 | "column_infos_in_one_db": column_infos_in_one_db, 53 | "column_labels_in_one_db": column_labels_in_one_db 54 | } 55 | -------------------------------------------------------------------------------- /data_processing/prompts/zero_shot_prompt_planner.txt: -------------------------------------------------------------------------------- 1 | database schema : 2 | table pages , columns = [ pages.words ( integer | values : 1081 , 68 ) , pages.page ( integer | values : 1 , 2 ) , pages.pid ( integer | primary key | comment : page id | values : 1 , 2 ) , pages.title ( text | values : Àbac , Abadia ) , pages.lid ( integer | comment : language id | values : 1 ) , pages.revision ( integer | values : 28236978 , 24086480 ) ] 3 | table words , columns = [ words.word ( text | values : +,2 , +,33 ) , words.wid ( integer | primary key | comment : word id | values : 2148990 , 2506463 ) , words.occurrences ( integer | values : 242 , 16841 ) ] 4 | table langs , columns = [ langs.pages ( integer | values : 1129144 ) , langs.words ( integer | values : 2764996 ) , langs.lid ( integer | primary key | comment : language id | values : 1 ) , langs.lang ( text | comment : language | values : ca ) , langs.locale ( text | values : ca_ES ) ] 5 | table pages_words , columns = [ pages_words.pid ( integer | primary key | comment : page id | values : 1 , 2 ) , pages_words.wid ( integer | primary key | comment : word id | values : 1 , 2 ) , pages_words.occurrences ( integer | values : 30 , 8 ) ] 6 | table langs_words , columns = [ langs_words.wid ( integer | primary key | comment : word id | values : 1 , 2 ) , langs_words.occurrences ( integer | values : 242 , 16841 ) , langs_words.lid ( integer | primary key | comment : language id | values : 1 ) ] 7 | table biwords , columns = [ biwords.occurrences ( integer | values : 4 , 3 ) , biwords.lid ( integer | primary key | comment : language id | values : 1 ) , biwords.w1st ( integer | primary key | comment : word id 
of the first word | values : 1 , 2 ) , biwords.w2nd ( integer | primary key | comment : word id of the second word | values : 2 , 4 ) ] 8 | foreign keys : 9 | pages.lid = langs.lid 10 | langs_words.wid = words.wid 11 | langs_words.lid = langs.lid 12 | pages_words.wid = words.wid 13 | pages_words.pid = pages.pid 14 | biwords.w2nd = words.wid 15 | biwords.w1st = words.wid 16 | biwords.lid = langs.lid 17 | 18 | matched contents : 19 | pages.words ( 1500 ) 20 | pages.page ( 1500 ) 21 | pages.pid ( 1500 ) 22 | pages.title ( Pages , 1500 ) 23 | words.word ( pages , words , calculates , differents , divides , percentages , counts , page's , wordes ) 24 | pages_words.occurrences ( 1500 ) 25 | langs_words.wid ( 1500 ) 26 | langs_words.occurrences ( 1500 ) 27 | biwords.occurrences ( 1500 ) 28 | biwords.w1st ( 1500 ) 29 | biwords.w2nd ( 1500 ) 30 | 31 | Question: DIVIDE(COUNT(pages WHERE words = 1500), COUNT(pages)) as percentage; Calculate the percentage of pages that have 1500 different words. 32 | 33 | Planning: -------------------------------------------------------------------------------- /validator_data/generate_validator_select_using_fewshot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | from validator_data.validator import ValidatorSelect, _execute_sql, _make_str_response, is_execution_correct 5 | import argparse 6 | import re 7 | 8 | # add parse for input data file (train, dev) and output_file 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--input_file', type=str, default='./data/evaluate/phi-3-planner_combine_bird_062024_with_evidence_train.jsonl') 11 | parser.add_argument('--output_file', type=str, default='bird_validator_select.jsonl') 12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', 13 | choices=['vllm', 'llamacpp', 'openai']) 14 | args = parser.parse_args() 15 | 16 | data = [] 17 | with open(args.input_file) as fp: 18 | for line in fp: 19 | data.append(json.loads(line)) 20 | 21 | # load saved output file if exists 22 | if os.path.exists(args.output_file): 23 | old_output = [] 24 | with open(args.output_file, 'r') as f: 25 | for line in f: 26 | old_output.append(json.loads(line)) 27 | data[:len(old_output)] = old_output 28 | else: 29 | old_output = [] 30 | 31 | # open jsonl file for append contents 32 | output_file = open(args.output_file, 'a+') 33 | 34 | validator = ValidatorSelect(endpoint_type=args.endpoint_type) 35 | 36 | for isample in tqdm(range(len(old_output), len(data)), total=len(data)-len(old_output)): 37 | sample = data[isample] 38 | 39 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql']) 40 | 41 | pred_sql_match = re.search(r"(?<=Final SQL query:).*", sample['planner'], re.DOTALL) 42 | if pred_sql_match is None: continue 43 | 44 | pred_sql = pred_sql_match.group().replace("sql", "").replace("```", "").strip() 45 | sample['predict_sql'] = pred_sql 46 | 47 | 48 | answer, execution_result = validator.validate(sample) 49 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 50 | 51 | print("-"*20) 52 | print("Is correct: ", is_correct) 53 | print(answer) 54 | 55 | sample['is_correct'] = is_correct 56 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer 57 | sample['validator_select'] = answer 58 | 59 | sample['true_result'] = _make_str_response(*true_execution_result) 60 | sample['pred_result'] = _make_str_response(*execution_result) 61 | 62 | # 
json.dump(data[:isample+1], open(args.output_file, 'w+'), ensure_ascii=False, indent=4) 63 | # write new sample in jsonl file 64 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 65 | -------------------------------------------------------------------------------- /validator_data/generate_fixed_sql_using_fewshot_condition.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sqlite3 3 | import os 4 | import multiprocessing.pool 5 | import functools 6 | from tqdm import tqdm 7 | import pandas as pd 8 | from validator import FixAgent, _make_str_response, _execute_sql, is_execution_correct 9 | 10 | PROMPT = open('./few_shot_prompt_fix_join.txt').read().strip() + """ 11 | ========= 12 | {schema} 13 | 14 | Matched contents are written in this format table.column (some values can be found in that column) 15 | {matched_content} 16 | 17 | Question: {question} 18 | 19 | SQL query: {sql_query} 20 | 21 | Feedback:{feedback} 22 | 23 | FIXED SQL:""" 24 | 25 | data = [] 26 | with open('../data/llm_alignment/validator_condition_bird_with_evidence_dev.jsonl') as fp: 27 | for line in fp: 28 | data.append(json.loads(line)) 29 | 30 | output_file = './bird_fixed_sql_condition.jsonl' 31 | 32 | output_fp = open(output_file, 'w+') 33 | 34 | fix_agent = FixAgent(PROMPT, endpoint_type='vllm') 35 | 36 | for isample in tqdm(range(len(data)), total=len(data)): 37 | sample = data[isample] 38 | 39 | is_correct = sample['is_correct'] 40 | if sample['validator_condition'] is None or "Conclude: correct" in sample['validator_condition']: 41 | output_fp.write(json.dumps(sample) + '\n') 42 | continue 43 | 44 | prompt = PROMPT.format( 45 | schema=sample['schema_sequence'], 46 | matched_content=sample['content_sequence'], 47 | question=sample['text'], 48 | sql_query=sample['predict_sql'], 49 | # execution_response=sample['pred_result'], 50 | feedback=sample['validator_condition'] 51 | ) 52 | # print(prompt) 53 | answer = fix_agent.get_answer([{"role": "user", "content": prompt}]) 54 | execution_result = _execute_sql("../" + sample['db_path'], answer) 55 | 56 | print("-"*20) 57 | print(answer) 58 | # break 59 | sample['fixed_sql'] = answer 60 | sample['fixed_pred_result'] = _make_str_response(*execution_result) 61 | 62 | true_execution_result = _execute_sql("../" + sample['db_path'], sample['sql']) 63 | is_fixed_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 64 | sample['is_fixed_correct'] = is_fixed_correct 65 | 66 | output_fp.write(json.dumps(sample) + '\n') 67 | 68 | bird_results_dict = dict() 69 | for idx, sample in enumerate(data): 70 | if 'fixed_sql' in sample: 71 | predicted_sql = sample['fixed_sql'] 72 | else: 73 | predicted_sql = sample['predict_sql'] 74 | bird_results_dict[idx] = predicted_sql + "\t----- bird -----\t" + sample["db_id"] 75 | with open("predict_dev.json", "w", encoding = 'utf-8') as f: 76 | f.write(json.dumps(bird_results_dict, indent = 2, ensure_ascii = False)) 77 | output_fp.close() 78 | 79 | -------------------------------------------------------------------------------- /visualization/lambda_sensitivity.py: -------------------------------------------------------------------------------- 1 | # import matplotlib.pyplot as plt 2 | 3 | # # Data 4 | # gamma_values = [0, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 1] 5 | # bird_dev_values = [54.95, 55.54, 56, 55.54, 56.39, 56.19, 57.11, 57.24, 57.24, 56.52, 56.32] 6 | # count_values = [14, 14, 16, 15, 14, 12, 10, 5, 8, 6, 8] 7 | 8 | # # 
Adjust figure size for single-column fit in a 2-column research paper 9 | # fig_width = 3.5 # Typical width for a single column in inches 10 | # fig_height = 2.5 # Adjusted height for readability 11 | 12 | # # Create figure and axis objects with adjusted size 13 | # fig, ax1 = plt.subplots(figsize=(fig_width, fig_height)) 14 | 15 | # # First Y-axis (EX%) 16 | # ax1.set_xlabel(r"$\lambda$", fontsize=10) # Using LaTeX formatting for lambda 17 | # ax1.set_ylabel("EX%", color="tab:blue", fontsize=10) 18 | # ax1.plot(gamma_values, bird_dev_values, marker="o", linestyle="-", color="tab:blue") 19 | # ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=8) 20 | # ax1.tick_params(axis="x", labelsize=8) 21 | 22 | # # Second Y-axis (No. Syntax Error) 23 | # ax2 = ax1.twinx() 24 | # ax2.set_ylabel("No. Syntax Error", color="tab:red", fontsize=10) 25 | # ax2.plot(gamma_values, count_values, marker="s", linestyle="--", color="tab:red") 26 | # ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=8) 27 | 28 | # # Grid 29 | # ax1.grid(True, linestyle="--", alpha=0.6) 30 | 31 | # # Adjust layout for better spacing 32 | # plt.tight_layout() 33 | 34 | # # Show plot 35 | # plt.show() 36 | 37 | 38 | import matplotlib.pyplot as plt 39 | 40 | # Data 41 | gamma_values = [0, 0.05, 0.1, 0.25, 0.5, 0.75, 1] 42 | bird_dev_values = [54.95, 56.19, 57.11, 57.24, 57.24, 56.52, 56.32] 43 | count_values = [14, 12, 10, 5, 8, 6, 8] 44 | 45 | # Adjust figure size for single-column fit in a 2-column research paper 46 | fig_width = 3.5 # Typical width for a single column in inches 47 | fig_height = 2.5 # Adjusted height for readability 48 | 49 | # Create figure and axis objects with adjusted size 50 | fig, ax1 = plt.subplots(figsize=(fig_width, fig_height)) 51 | 52 | # First Y-axis (EX%) 53 | ax1.set_xlabel(r"$\lambda$", fontsize=10) # Using LaTeX formatting for lambda 54 | ax1.set_ylabel("EX%", color="tab:blue", fontsize=10) 55 | ax1.plot(gamma_values, bird_dev_values, marker="o", linestyle="-", color="tab:blue") 56 | ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=8) 57 | ax1.tick_params(axis="x", labelsize=8) 58 | 59 | # Second Y-axis (No. Syntax Error) 60 | ax2 = ax1.twinx() 61 | ax2.set_ylabel("No. 
Syntax Error", color="tab:red", fontsize=10) 62 | ax2.plot(gamma_values, count_values, marker="s", linestyle="--", color="tab:red") 63 | ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=8) 64 | 65 | # Grid 66 | ax1.grid(True, linestyle="--", alpha=0.6) 67 | 68 | # Adjust layout for better spacing 69 | plt.tight_layout() 70 | 71 | # Show plot 72 | plt.show() 73 | -------------------------------------------------------------------------------- /validator_data/generate_validator_join_using_fewshot.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from validator_data.validator import ValidatorJOIN, _execute_sql, _make_str_response, is_execution_correct, ValidatorJOINWithTrueSQL 4 | import argparse 5 | import os 6 | import re 7 | from multiprocessing import Pool 8 | 9 | def process_sample(sample): 10 | """Process a single sample.""" 11 | 12 | validator = Validator(endpoint_type=args.endpoint_type) 13 | 14 | try: 15 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql']) 16 | 17 | sample['predict_sql'] = sample['predict_sqls'][0] 18 | prompt, answer, execution_result = validator.validate(sample) 19 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 20 | 21 | sample['prompt_validator_join'] = prompt 22 | sample['is_correct'] = is_correct 23 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer 24 | sample['validator_join'] = answer 25 | sample['true_result'] = _make_str_response(*true_execution_result) 26 | sample['pred_result'] = _make_str_response(*execution_result) 27 | 28 | return sample 29 | except Exception as e: 30 | print(f"Error processing sample: {e}") 31 | return None 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json') 36 | parser.add_argument('--output_file', type=str, default='bird_validator_join.jsonl') 37 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai']) 38 | parser.add_argument('--use_hidden_sql', action='store_true') 39 | args = parser.parse_args() 40 | 41 | if args.use_hidden_sql: 42 | Validator = ValidatorJOINWithTrueSQL 43 | else: 44 | Validator = ValidatorJOIN 45 | 46 | data = [] 47 | with open(args.input_file) as fp: 48 | for line in fp: 49 | data.append(json.loads(line)) 50 | 51 | # Load saved output file if exists 52 | processed_keys = set() 53 | if os.path.exists(args.output_file): 54 | with open(args.output_file, 'r') as f: 55 | for line in f: 56 | sample = json.loads(line) 57 | processed_keys.add((sample['db_id'], sample['question'])) 58 | 59 | # Determine samples to process 60 | samples_to_process = [sample for sample in data if (sample['db_id'], sample['question']) not in processed_keys] 61 | 62 | # Open output file in append mode 63 | with open(args.output_file, 'a') as output_file: 64 | with Pool(8) as pool: 65 | for result in tqdm(pool.imap(process_sample, samples_to_process), total=len(samples_to_process)): 66 | if result is not None: 67 | output_file.write(json.dumps(result, ensure_ascii=False) + '\n') 68 | -------------------------------------------------------------------------------- /validator_data/generate_fixed_sql_using_fewshot_join.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import pandas as pd 4 | import argparse 5 | 
from validator import FixAgent, _execute_sql, _make_str_response, is_execution_correct 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--input_file', type=str, default='../data/llm_alignment/validator_join_bird_with_evidence_train.jsonl') 9 | parser.add_argument('--output_file', type=str, default='../data/llm_alignment/fixed_join_bird_with_evidence_train.jsonl') 10 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp']) 11 | args = parser.parse_args() 12 | 13 | PROMPT = open('./few_shot_prompt_fix_join.txt').read().strip() + """ 14 | ========= 15 | {schema} 16 | 17 | Matched contents are written in this format table.column (some values can be found in that column) 18 | {matched_content} 19 | 20 | Question: {question} 21 | 22 | SQL query: {sql_query} 23 | 24 | Feedback:{feedback} 25 | 26 | FIXED SQL:""" 27 | 28 | # load data from jsonl 29 | data = [] 30 | with open(args.input_file, 'r') as f: 31 | for line in f: 32 | data.append(json.loads(line)) 33 | 34 | fix_agent = FixAgent(prompt_template=PROMPT, endpoint_type=args.endpoint_type) 35 | 36 | output_file = open(args.output_file, 'a+') 37 | 38 | for isample in tqdm(range(0, len(data)), total=len(data)): 39 | sample = data[isample] 40 | 41 | sql = sample['predict_sql'] 42 | is_correct = sample['is_correct'] 43 | if sample['validator_join'] is None or "Conclude: correct" in sample['validator_join']: 44 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 45 | continue 46 | 47 | prompt = PROMPT.format( 48 | schema=sample['schema_sequence'], 49 | matched_content=sample['content_sequence'], 50 | question=sample['text'], 51 | sql_query=sql, 52 | feedback=sample['validator_join'] 53 | ) 54 | # print(prompt) 55 | answer = fix_agent.get_answer([{"role": "user", "content": prompt}]) 56 | 57 | execution_result = _execute_sql("../" + sample['db_path'], answer) 58 | 59 | print("-"*20) 60 | print(answer) 61 | # break 62 | sample['fixed_sql'] = answer 63 | sample['fixed_pred_result'] = _make_str_response(*execution_result) 64 | 65 | true_execution_result = _execute_sql("../" + sample['db_path'], sample['sql']) 66 | is_fixed_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 67 | sample['is_fixed_correct'] = is_fixed_correct 68 | 69 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 70 | 71 | bird_results_dict = dict() 72 | for idx, sample in enumerate(data): 73 | if 'fixed_sql' in sample: 74 | predicted_sql = sample['fixed_sql'] 75 | else: 76 | predicted_sql = sample['predict_sql'] 77 | bird_results_dict[idx] = predicted_sql + "\t----- bird -----\t" + sample["db_id"] 78 | with open("predict_dev.json", "w", encoding = 'utf-8') as f: 79 | f.write(json.dumps(bird_results_dict, indent = 2, ensure_ascii = False)) 80 | -------------------------------------------------------------------------------- /data_processing/generate_planner_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | from planner import PlannerCombine, PlannerCombineWithTrueSQL 5 | import argparse 6 | from multiprocessing import Pool 7 | 8 | # add parse for input data file (train, dev) and output_file 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--input_file', type=str, default='../data/sft_bird_with_evidence_train_text2sql.json') 11 | parser.add_argument('--output_file', type=str, default='../data/planner/planner_select_bird_with_evidence_train.jsonl') 12 | 
parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai']) 13 | parser.add_argument('--mode', type=str, choices=['select', 'condition', 'combine', 'combine_with_true_sql'], default='combine') 14 | parser.add_argument('--prompt', choices=['few-shot', 'cot'], default='few-shot') 15 | args = parser.parse_args() 16 | 17 | if args.input_file.endswith('.json'): 18 | data = json.load(open(args.input_file)) 19 | elif args.input_file.endswith('.jsonl'): 20 | data = [] 21 | with open(args.input_file, 'r') as f: 22 | for line in f: 23 | data.append(json.loads(line)) 24 | 25 | # load saved output file if exists 26 | if os.path.exists(args.output_file): 27 | old_output = [] 28 | with open(args.output_file, 'r') as f: 29 | for line in f: 30 | old_output.append(json.loads(line)) 31 | data[:len(old_output)] = old_output 32 | else: 33 | old_output = [] 34 | 35 | # open jsonl file for append contents 36 | # makedirs if not exists 37 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 38 | output_file = open(args.output_file, 'a+') 39 | 40 | if args.mode == 'combine': planner = PlannerCombine(endpoint_type=args.endpoint_type) 41 | elif args.mode == 'combine_with_true_sql': planner = PlannerCombineWithTrueSQL(endpoint_type=args.endpoint_type) 42 | else: # 'select' and 'condition' pass argparse validation but have no planner class wired up here; fail early instead of hitting a NameError later 43 | raise NotImplementedError(f"--mode {args.mode} is not implemented in this script") 44 | 45 | if args.prompt == 'cot': 46 | planner.prompt_template = """{schema} 47 | 48 | Question: {question} 49 | External knowledge: {evidence} 50 | 51 | Use this hidden True SQL query to write correct analysis that derives to the correct answer. The True SQL query cannot be used in the analysis. 52 | Hidden True SQL query: {true_sql_query} 53 | 54 | Write your thought in short then write the final SQL query, answer in this format: 55 | [your short thought step-by-step] 56 | Final SQL query: 57 | ``` 58 | [SQL query] 59 | ``` 60 | """ 61 | 62 | def process_sample(sample): 63 | answer = planner.generate(sample) 64 | sample[f'planner_{args.mode}'] = answer 65 | return sample 66 | 67 | def main(): 68 | chunk_size = 4 69 | with open(args.output_file, 'a') as output_file: 70 | for i in tqdm(range(len(old_output), len(data), chunk_size), total=(len(data) - len(old_output))//chunk_size): 71 | chunk = data[i:i+chunk_size] 72 | pool = Pool(chunk_size) 73 | processed_samples = pool.map(process_sample, chunk) 74 | pool.close() 75 | 76 | if len(processed_samples) > 0: 77 | print(processed_samples[0][f'planner_{args.mode}']) 78 | 79 | for sample in processed_samples: 80 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /utils/load_pt_dataset.py: -------------------------------------------------------------------------------- 1 | # saves the SQL corpus to several binary files for pre-training CodeGen. The 
following was helpful: 2 | # https://github.com/karpathy/nanoGPT/blob/master/data/openwebtext/prepare.py 3 | 4 | import numpy as np 5 | import torch 6 | import random 7 | from torch.utils.data import IterableDataset, Dataset, DataLoader 8 | 9 | # class PretrainDataset(IterableDataset): 10 | # def __init__(self, pt_data_dir, block_size, epochs): 11 | # super().__init__() 12 | # self.corpus = np.memmap(pt_data_dir, dtype = np.uint16, mode = 'r') 13 | # self.block_size = block_size 14 | # self.epochs = epochs 15 | # self.length = len(self.corpus) // self.block_size 16 | 17 | # # return a tokenized sequence 18 | # def __iter__(self): 19 | # for _ in range(self.epochs): 20 | # start_idx_list = list(range(0, len(self.corpus), self.block_size)) 21 | # # for each epoch, shuffle the order of sequences 22 | # random.shuffle(start_idx_list) 23 | 24 | # for start_idx in start_idx_list: 25 | # input_ids = self.corpus[start_idx: start_idx + self.block_size] 26 | # # skip the sequence whose length is not equal to `block_size` 27 | # if len(input_ids) != self.block_size: 28 | # continue 29 | 30 | # input_ids = torch.from_numpy(input_ids.astype(np.int64)) 31 | # attention_mask = torch.ones(len(input_ids)) 32 | 33 | # yield {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} 34 | 35 | # # # return a tokenized sequence 36 | # # def __iter__(self): 37 | # # for _ in range(self.dataset_length): 38 | # # # randomly select a sequence of token ids from the tokenized corpus 39 | # # idx = random.randint(0, len(self.corpus) - self.block_size) 40 | # # input_ids = self.corpus[idx: idx + self.block_size] 41 | # # input_ids = torch.from_numpy(input_ids.astype(np.int64)) 42 | # # attention_mask = torch.ones(len(input_ids)) 43 | 44 | # # yield {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} 45 | 46 | # def __len__(self): 47 | # return self.length 48 | 49 | class PretrainDataset(Dataset): 50 | def __init__(self, pt_data_dir, block_size): 51 | super().__init__() 52 | self.corpus = np.memmap(pt_data_dir, dtype = np.uint16, mode = 'r') 53 | self.block_size = block_size 54 | self.length = len(self.corpus) // self.block_size 55 | 56 | # return a list of token ids in the corpus 57 | def __getitem__(self, index): 58 | input_ids = self.corpus[index * self.block_size : (index + 1) * self.block_size] 59 | 60 | input_ids = torch.from_numpy(input_ids.astype(np.int64)) 61 | attention_mask = torch.ones(len(input_ids)) 62 | 63 | return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} 64 | 65 | def __len__(self): 66 | return self.length 67 | 68 | if __name__ == "__main__": 69 | dataset = PretrainDataset("./data/pt_corpus/starcoder_corpus.bin", 6144) 70 | dataloader = DataLoader(dataset, batch_size = 4, shuffle = False, drop_last = True) 71 | for batch in dataloader: 72 | print("-"*20) 73 | print(len(dataset)) -------------------------------------------------------------------------------- /alignment-handbook/tests/test_model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import unittest 16 | 17 | import torch 18 | 19 | from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer 20 | from alignment.data import DEFAULT_CHAT_TEMPLATE 21 | 22 | 23 | class GetQuantizationConfigTest(unittest.TestCase): 24 | def test_4bit(self): 25 | model_args = ModelArguments(load_in_4bit=True) 26 | quantization_config = get_quantization_config(model_args) 27 | self.assertTrue(quantization_config.load_in_4bit) 28 | self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16) 29 | self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4") 30 | self.assertFalse(quantization_config.bnb_4bit_use_double_quant) 31 | 32 | def test_8bit(self): 33 | model_args = ModelArguments(load_in_8bit=True) 34 | quantization_config = get_quantization_config(model_args) 35 | self.assertTrue(quantization_config.load_in_8bit) 36 | 37 | def test_no_quantization(self): 38 | model_args = ModelArguments() 39 | quantization_config = get_quantization_config(model_args) 40 | self.assertIsNone(quantization_config) 41 | 42 | 43 | class GetTokenizerTest(unittest.TestCase): 44 | def setUp(self) -> None: 45 | self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha") 46 | 47 | def test_right_truncation_side(self): 48 | tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right")) 49 | self.assertEqual(tokenizer.truncation_side, "right") 50 | 51 | def test_left_truncation_side(self): 52 | tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left")) 53 | self.assertEqual(tokenizer.truncation_side, "left") 54 | 55 | def test_default_chat_template(self): 56 | tokenizer = get_tokenizer(self.model_args, DataArguments()) 57 | self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE) 58 | 59 | def test_chatml_chat_template(self): 60 | chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" 61 | tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template)) 62 | self.assertEqual(tokenizer.chat_template, chat_template) 63 | 64 | 65 | class GetPeftConfigTest(unittest.TestCase): 66 | def test_peft_config(self): 67 | model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99) 68 | peft_config = get_peft_config(model_args) 69 | self.assertEqual(peft_config.r, 42) 70 | self.assertEqual(peft_config.lora_alpha, 0.66) 71 | self.assertEqual(peft_config.lora_dropout, 0.99) 72 | 73 | def test_no_peft_config(self): 74 | model_args = ModelArguments(use_peft=False) 75 | peft_config = get_peft_config(model_args) 76 | self.assertIsNone(peft_config) 77 | -------------------------------------------------------------------------------- /data_processing/merge_val_fix_data.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Updated fixed SQL data saved to ../data/multi-agents/fixed/gpt-4o-mini-validator-fixer-bird_with_evidence_train.jsonl\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import json\n", 18 | "from copy import deepcopy\n", 19 | "\n", 20 | "# File paths\n", 21 | "fixed_sql_bird_file = '../data/multi-agents/fixed/gpt-4o-mini-fixed-bird_with_evidence_train.jsonl'\n", 22 | "validator_select_file = '../data/multi-agents/validator/gpt-4o-mini-validator_select_bird_with_evidence_train.jsonl'\n", 23 | "validator_condition_file = '../data/multi-agents/validator/gpt-4o-mini-validator_condition_bird_with_evidence_train.jsonl'\n", 24 | "validator_join_file = '../data/multi-agents/validator/gpt-4o-mini-validator_join_bird_with_evidence_train.jsonl'\n", 25 | "\n", 26 | "# Function to load JSONL files\n", 27 | "def load_jsonl(file_path):\n", 28 | " data = []\n", 29 | " with open(file_path, 'r', encoding='utf-8') as file:\n", 30 | " for line in file:\n", 31 | " data.append(json.loads(line))\n", 32 | " return data\n", 33 | "\n", 34 | "# Load all datasets\n", 35 | "fixed_sql_bird_data = load_jsonl(fixed_sql_bird_file)\n", 36 | "validator_select_data = load_jsonl(validator_select_file)\n", 37 | "validator_condition_data = load_jsonl(validator_condition_file)\n", 38 | "validator_join_data = load_jsonl(validator_join_file)\n", 39 | "\n", 40 | "# Process and add valid samples\n", 41 | "for sample_select, sample_condition, sample_join in zip(validator_select_data, validator_condition_data, validator_join_data):\n", 42 | "\n", 43 | " # Extract correctness feedback\n", 44 | " select_correct = sample_select.get('feedback_conclude')\n", 45 | " condition_correct = sample_condition.get('feedback_conclude')\n", 46 | " join_correct = sample_join.get('feedback_conclude')\n", 47 | "\n", 48 | " # If all are correct, add a new sample to fixed_sql_bird_data\n", 49 | " if select_correct and condition_correct and join_correct:\n", 50 | " # keep only the three feedback strings; a full copy of the sample is not needed\n", 51 | " new_sample = {\n", 52 | " \"validator_select\": sample_select['validator_select'],\n", 53 | " \"validator_condition\": sample_condition['validator_condition'],\n", 54 | " \"validator_join\": sample_join['validator_join'],\n", 55 | " \"fixed_sql\": [\"None\"] # placeholder: no fix is needed when all validators conclude correct\n", 56 | " }\n", 57 | " fixed_sql_bird_data.append(new_sample)\n", 58 | "\n", 59 | "# Save the updated fixed SQL data\n", 60 | "output_file = '../data/multi-agents/fixed/gpt-4o-mini-validator-fixer-bird_with_evidence_train.jsonl'\n", 61 | "with open(output_file, 'w', encoding='utf-8') as file:\n", 62 | " for entry in fixed_sql_bird_data:\n", 63 | " file.write(json.dumps(entry, ensure_ascii=False) + '\\n')\n", 64 | "\n", 65 | "print(f\"Updated fixed SQL data saved to {output_file}\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [] 74 | } 75 | ], 76 | "metadata": { 77 | "kernelspec": { 78 | "display_name": "handbook", 79 | "language": "python", 80 | "name": "python3" 81 | }, 82 | "language_info": { 83 | "codemirror_mode": { 84 | "name": "ipython", 85 | "version": 3 86 | }, 87 | "file_extension": ".py", 88 | "mimetype": "text/x-python", 89 | "name": "python", 90 | "nbconvert_exporter": "python", 91 | "pygments_lexer": 
"ipython3", 92 | "version": "3.10.14" 93 | } 94 | }, 95 | "nbformat": 4, 96 | "nbformat_minor": 2 97 | } 98 | -------------------------------------------------------------------------------- /validator_data/generate_validator_condition_using_fewshot.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from validator import ValidatorCondition, _execute_sql, _make_str_response, is_execution_correct, ValidatorConditionWithTrueSQL 4 | import argparse 5 | import re 6 | import os 7 | from multiprocessing import Pool, Manager 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json') 11 | parser.add_argument('--output_file', type=str, default='bird_validator_join.jsonl') 12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai']) 13 | parser.add_argument('--use_hidden_sql', action='store_true') 14 | args = parser.parse_args() 15 | 16 | def process_sample(args): 17 | idx, sample, endpoint_type, output_file_lock, output_file_path = args 18 | 19 | try: 20 | validator = Validator(endpoint_type=endpoint_type) 21 | 22 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql']) 23 | 24 | pred_sql_match = re.search(r"(?<=Final SQL query:).*", sample['planners'][0], re.DOTALL) 25 | if pred_sql_match is None: 26 | return None 27 | 28 | pred_sql = pred_sql_match.group().replace("sql", "").replace("```", "").strip() 29 | sample['predict_sql'] = pred_sql 30 | 31 | prompt, answer, execution_result = validator.validate(sample) 32 | answer = answer[0] 33 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0]) 34 | 35 | sample['is_correct'] = is_correct 36 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer 37 | sample['validator_condition'] = answer 38 | sample['true_result'] = _make_str_response(*true_execution_result) 39 | sample['pred_result'] = _make_str_response(*execution_result) 40 | 41 | # Write the result to the file 42 | with output_file_lock: 43 | with open(output_file_path, 'a') as output_file: 44 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 45 | 46 | return idx 47 | 48 | except Exception as e: 49 | print(f"Error processing sample {idx}: {e}") 50 | return None 51 | 52 | if __name__ == "__main__": 53 | if args.use_hidden_sql: 54 | Validator = ValidatorConditionWithTrueSQL 55 | else: 56 | Validator = ValidatorCondition 57 | 58 | # Load data 59 | data = [] 60 | with open(args.input_file) as fp: 61 | for line in fp: 62 | data.append(json.loads(line)) 63 | 64 | # Load old results 65 | if os.path.exists(args.output_file): 66 | processed_indices = set() 67 | with open(args.output_file, 'r') as f: 68 | for line in f: 69 | result = json.loads(line) 70 | processed_indices.add(f"{result['db_id']} {result['question']}") 71 | print(f"Loaded {len(processed_indices)} previously processed samples.") 72 | else: 73 | processed_indices = set() 74 | 75 | # Filter data to process only unprocessed samples 76 | unprocessed_data = [ 77 | (i, sample, args.endpoint_type, None, args.output_file) 78 | for i, sample in enumerate(data) 79 | if f"{sample['db_id']} {sample['question']}" not in processed_indices 80 | ] 81 | 82 | # Set up multiprocessing 83 | with Manager() as manager: 84 | output_file_lock = manager.Lock() 85 | 86 | # Add output_file_lock to each task 87 | tasks = [ 88 | (idx, sample, args.endpoint_type, 
output_file_lock, args.output_file) 89 | for idx, sample, _, _, _ in unprocessed_data 90 | ] 91 | 92 | with Pool(processes=12) as pool: 93 | # Use tqdm to monitor progress 94 | for _ in tqdm(pool.imap_unordered(process_sample, tasks), total=len(tasks)): 95 | pass 96 | 97 | print("Processing completed.") 98 | -------------------------------------------------------------------------------- /alignment-handbook/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # Temp folders 163 | data/ 164 | wandb/ 165 | output/ 166 | -------------------------------------------------------------------------------- /alignment-handbook/src/alignment/release.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import argparse 17 | import re 18 | 19 | import packaging.version 20 | 21 | 22 | REPLACE_PATTERNS = { 23 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), 24 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), 25 | } 26 | REPLACE_FILES = { 27 | "init": "src/alignment/__init__.py", 28 | "setup": "setup.py", 29 | } 30 | README_FILE = "README.md" 31 | 32 | 33 | def update_version_in_file(fname, version, pattern): 34 | """Update the version in one file using a specific pattern.""" 35 | with open(fname, "r", encoding="utf-8", newline="\n") as f: 36 | code = f.read() 37 | re_pattern, replace = REPLACE_PATTERNS[pattern] 38 | replace = replace.replace("VERSION", version) 39 | code = re_pattern.sub(replace, code) 40 | with open(fname, "w", encoding="utf-8", newline="\n") as f: 41 | f.write(code) 42 | 43 | 44 | def global_version_update(version, patch=False): 45 | """Update the version in all needed files.""" 46 | for pattern, fname in REPLACE_FILES.items(): 47 | update_version_in_file(fname, version, pattern) 48 | 49 | 50 | def get_version(): 51 | """Reads the current version in the __init__.""" 52 | with open(REPLACE_FILES["init"], "r") as f: 53 | code = f.read() 54 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] 55 | return packaging.version.parse(default_version) 56 | 57 | 58 | def pre_release_work(patch=False): 59 | """Do all the necessary pre-release steps.""" 60 | # First let's get the default version: base version if we are in dev, bump minor otherwise. 61 | default_version = get_version() 62 | if patch and default_version.is_devrelease: 63 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") 64 | if default_version.is_devrelease: 65 | default_version = default_version.base_version 66 | elif patch: 67 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" 68 | else: 69 | default_version = f"{default_version.major}.{default_version.minor + 1}.0" 70 | 71 | # Now let's ask nicely if that's the right one. 72 | version = input(f"Which version are you releasing? [{default_version}]") 73 | if len(version) == 0: 74 | version = default_version 75 | 76 | print(f"Updating version to {version}.") 77 | global_version_update(version, patch=patch) 78 | 79 | 80 | def post_release_work(): 81 | """Do all the necessary post-release steps.""" 82 | # First let's get the current version 83 | current_version = get_version() 84 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" 85 | current_version = current_version.base_version 86 | 87 | # Check with the user we got that right. 88 | version = input(f"Which version are we developing now? 
[{dev_version}]") 89 | if len(version) == 0: 90 | version = dev_version 91 | 92 | print(f"Updating version to {version}.") 93 | global_version_update(version) 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") 99 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") 100 | args = parser.parse_args() 101 | if not args.post_release: 102 | pre_release_work(patch=args.patch) 103 | elif args.patch: 104 | print("Nothing to do after a patch :-)") 105 | else: 106 | post_release_work() 107 | -------------------------------------------------------------------------------- /alignment-handbook/src/alignment/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import os 16 | from typing import Dict 17 | 18 | import torch 19 | from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer 20 | 21 | from accelerate import Accelerator 22 | from huggingface_hub import list_repo_files 23 | from huggingface_hub.utils._validators import HFValidationError 24 | from peft import LoraConfig, PeftConfig 25 | 26 | from .configs import DataArguments, ModelArguments 27 | from .data import DEFAULT_CHAT_TEMPLATE 28 | 29 | 30 | def get_current_device() -> int: 31 | """Get the current device. 
For GPU we return the local process index to enable multiple GPU training.""" 32 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" 33 | 34 | 35 | def get_kbit_device_map() -> Dict[str, int] | None: 36 | """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`""" 37 | return {"": get_current_device()} if torch.cuda.is_available() else None 38 | 39 | 40 | def get_quantization_config(model_args) -> BitsAndBytesConfig | None: 41 | if model_args.load_in_4bit: 42 | quantization_config = BitsAndBytesConfig( 43 | load_in_4bit=True, 44 | bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models 45 | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, 46 | bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, 47 | ) 48 | elif model_args.load_in_8bit: 49 | quantization_config = BitsAndBytesConfig( 50 | load_in_8bit=True, 51 | ) 52 | else: 53 | quantization_config = None 54 | 55 | return quantization_config 56 | 57 | 58 | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: 59 | """Get the tokenizer for the model.""" 60 | tokenizer = AutoTokenizer.from_pretrained( 61 | model_args.model_name_or_path, 62 | revision=model_args.model_revision, 63 | trust_remote_code=model_args.trust_remote_code 64 | ) 65 | if tokenizer.pad_token_id is None: 66 | tokenizer.pad_token_id = tokenizer.eos_token_id 67 | 68 | if data_args.truncation_side is not None: 69 | tokenizer.truncation_side = data_args.truncation_side 70 | 71 | # Set reasonable default for models without max length 72 | # if tokenizer.model_max_length > 100_000: 73 | # tokenizer.model_max_length = 4096 74 | 75 | if data_args.chat_template is not None: 76 | tokenizer.chat_template = data_args.chat_template 77 | elif tokenizer.chat_template is None: 78 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE 79 | 80 | return tokenizer 81 | 82 | 83 | def get_peft_config(model_args: ModelArguments) -> PeftConfig | None: 84 | if model_args.use_peft is False: 85 | return None 86 | 87 | peft_config = LoraConfig( 88 | r=model_args.lora_r, 89 | lora_alpha=model_args.lora_alpha, 90 | lora_dropout=model_args.lora_dropout, 91 | bias="none", 92 | task_type="CAUSAL_LM", 93 | target_modules=model_args.lora_target_modules, 94 | modules_to_save=model_args.lora_modules_to_save, 95 | ) 96 | 97 | return peft_config 98 | 99 | 100 | def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: 101 | try: 102 | # Try first if model on a Hub repo 103 | repo_files = list_repo_files(model_name_or_path, revision=revision) 104 | except HFValidationError: 105 | # If not, check local repo 106 | repo_files = os.listdir(model_name_or_path) 107 | return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files 108 | -------------------------------------------------------------------------------- /validator_data/utils.py: -------------------------------------------------------------------------------- 1 | from sql_metadata import Parser 2 | import re 3 | import os 4 | import sqlparse 5 | 6 | def remove_table_alias(s): 7 | try: 8 | tables_aliases = Parser(s).tables_aliases 9 | except Exception as e: 10 | return s 11 | 12 | new_tables_aliases = {} 13 | for i in range(1,11): 14 | if "t{}".format(i) in tables_aliases.keys(): 15 | new_tables_aliases["t{}".format(i)] = tables_aliases["t{}".format(i)] 16 | 17 | tables_aliases = new_tables_aliases 18 | for 
k, v in tables_aliases.items(): 19 | # remove AS clauses 20 | s = s.replace("AS " + k + " ", "") 21 | # replace each table alias with its original name 22 | s = s.replace(k, v) 23 | 24 | return s 25 | 26 | def extract_select_clause(sql_query): 27 | # Define a regex pattern to match the SELECT clause up to the FROM keyword 28 | pattern = re.compile(r"SELECT\s.*?\s(?=FROM)", re.IGNORECASE | re.DOTALL) 29 | 30 | # Search for the pattern in the SQL query 31 | match = pattern.search(sql_query) 32 | 33 | if match: 34 | # Return the matched portion (SELECT clause) 35 | return match.group(0).strip() 36 | else: 37 | # Return None if no match is found 38 | return None 39 | 40 | def get_table_columns_list(schema): 41 | columns = [] 42 | for table_data in schema['schema_items']: 43 | table_name = table_data['table_name'].lower() 44 | column_names = table_data['column_names'] 45 | column_names = [x.lower() for x in column_names] 46 | 47 | for column in column_names: 48 | columns.append(f"{table_name}.{column}") 49 | columns.append(f"{table_name}.`{column}`") 50 | columns.append(column) 51 | 52 | columns = list(set(columns)) 53 | return columns 54 | 55 | def get_columns_in_select_clause(sql_query, schema): 56 | column_list = get_table_columns_list(schema) 57 | select_clause = extract_select_clause(sql_query) 58 | 59 | if select_clause is None: 60 | return [] 61 | 62 | select_clause = remove_table_alias(sqlparse.format(select_clause, keyword_case = "upper", identifier_case = "lower")) 63 | try: 64 | sql_tokens = [token.value for token in Parser(select_clause.lower()).tokens] 65 | except Exception as e: 66 | print(e) 67 | sql_tokens = sql_query.lower().split() 68 | 69 | select_columns = [] 70 | for token in sql_tokens: 71 | if token in column_list: 72 | select_columns.append(token) 73 | return select_columns 74 | 75 | 76 | def get_equation_function_in_select_clause(sql_query): 77 | """ 78 | Equation functions include min, max, avg, sum, count, divide, +, /, case when. 79 | """ 80 | select_clause = extract_select_clause(sql_query) 81 | if select_clause is None: 82 | return [] 83 | norm_select_clause = remove_table_alias(sqlparse.format(select_clause, keyword_case = "upper", identifier_case = "lower")) 84 | 85 | try: 86 | sql_tokens = [token.value for token in Parser(norm_select_clause.lower()).tokens] 87 | except Exception as e: 88 | sql_tokens = norm_select_clause.lower().split() 89 | 90 | equation_functions = [] 91 | for token in sql_tokens: 92 | if token in ["min", "max", "avg", "sum", "count", "divide", "+", "/", "case", "when"]: 93 | equation_functions.append(token) 94 | 95 | return equation_functions 96 | 97 | def remove_tables_from_columns(columns): 98 | new_columns = [] 99 | for col in columns: 100 | new_columns.append(col.split('.')[-1]) 101 | return new_columns 102 | 103 | def check_columns_match(true_columns, pred_columns): 104 | true_columns = remove_tables_from_columns(true_columns) 105 | pred_columns = remove_tables_from_columns(pred_columns) 106 | 107 | # classify error types (unnecessary columns, missing columns, wrong order) and return a string describing the error type 108 | if true_columns == pred_columns: 109 | return 'correct' 110 | else: 111 | if set(true_columns) == set(pred_columns): 112 | return 'incorrect: wrong order' 113 | elif set(true_columns) - set(pred_columns): 114 | return 'incorrect: missing columns ' + str(set(true_columns) - set(pred_columns)) 115 | elif set(pred_columns) - set(true_columns): 116 | return 'incorrect: unnecessary columns ' + str(set(pred_columns) - set(true_columns))
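# --- Illustrative usage sketch (not part of the original module) ---
# check_columns_match strips table prefixes from both column lists, then
# classifies the mismatch between gold and predicted SELECT columns.
if __name__ == "__main__":
    true_cols = ["singer.name", "singer.age"]
    print(check_columns_match(true_cols, true_cols))                      # 'correct'
    print(check_columns_match(true_cols, ["singer.age", "singer.name"]))  # 'incorrect: wrong order'
    print(check_columns_match(true_cols, ["singer.name"]))                # "incorrect: missing columns {'age'}"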
-------------------------------------------------------------------------------- /data_processing/generate_validator_fixer_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from datasets import Dataset, DatasetDict 5 | from planner import get_answer_llamacpp, get_answer_vllm, get_answer_openai 6 | from openai import OpenAI 7 | from dotenv import load_dotenv 8 | from planner import _make_str_response, _execute_sql, is_execution_correct 9 | import re 10 | from utils import norm_sql_query 11 | from tqdm import tqdm 12 | from multiprocessing import Pool 13 | 14 | # Set up argument parser 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--input_file', type=str, default='./data/multi-agents/fixed/gpt-4o-mini-validator-fixer-bird_with_evidence_train.jsonl') 17 | parser.add_argument('--output_dir', type=str, default='./data/multi-agents/fixed/sft-gpt-4o-mini-validator-fixer-bird_with_evidence_train') 18 | parser.add_argument('--num_processes', type=int, default=16) 19 | args = parser.parse_args() 20 | 21 | # Define the prompt template 22 | PROMPT = """{schema} 23 | 24 | Question: {question} 25 | External knowledge: {evidence} 26 | 27 | Generated SQL query: {sql_query} 28 | 29 | Execution response: 30 | {execution_response} 31 | 32 | Feedback for the SQL query: 33 | """ 34 | 35 | COMPLETION = """ 36 | {feedback_select} 37 | 38 | 39 | 40 | {feedback_condition} 41 | 42 | 43 | FIXED SQL: {fixed_sql}""" 44 | 45 | def norm_feedback(feedback, token): 46 | feedback = token + feedback.split(token)[-1] 47 | return feedback 48 | 49 | def extract_sql_in_code_block(pred_sql_text): 50 | sql_block_match = re.search(r"```(.+?)```", pred_sql_text, re.DOTALL) 51 | if sql_block_match: 52 | sql_query = sql_block_match.group(1).strip() 53 | if sql_query.startswith("sql"): 54 | sql_query = sql_query[3:].strip() # drop only the leading "sql" tag, not every "sql" substring 55 | return sql_query 56 | else: 57 | return pred_sql_text 58 | 59 | def process_sample(index_sample): 60 | index, sample = index_sample 61 | feedback_select = sample['validator_select'] or 'SELECT.\nNone' 62 | feedback_condition = sample['validator_condition'] or "CONDITION.\nNone" 63 | feedback_join = sample['validator_join'] or "JOIN.\nNone" 64 | feedback_join = "JOIN."
+ feedback_join.split("JOIN.")[-1] 65 | 66 | feedback_select = norm_feedback(feedback_select, "SELECT.") 67 | feedback_condition = norm_feedback(feedback_condition, "CONDITION.") 68 | feedback_join = norm_feedback(feedback_join, "JOIN.") 69 | 70 | prompt = PROMPT.format( 71 | schema=sample['schema_sequence'], 72 | question=sample['question'], 73 | evidence=sample['evidence'], 74 | sql_query=sample['predict_sql'], 75 | execution_response=sample['pred_result'] 76 | ) 77 | 78 | fixed_sql = sample['fixed_sql'] 79 | if type(fixed_sql) == list: 80 | fixed_sql = fixed_sql[0] 81 | 82 | fixed_sql = extract_sql_in_code_block(fixed_sql) 83 | 84 | if fixed_sql != "None": 85 | true_result, has_error = _execute_sql("./" + sample["db_path"], sample["sql"]) 86 | pred_result, has_error = _execute_sql("./" + sample["db_path"], fixed_sql) 87 | 88 | if not is_execution_correct(true_result, pred_result): 89 | print("-"*20) 90 | print('True:', true_result) 91 | print('Pred:', pred_result) 92 | # completion = norm_sql_query(sample['sql'], sample['schema']) 93 | fixed_sql = sample['sql'] 94 | 95 | completion = COMPLETION.format( 96 | feedback_select=feedback_select, 97 | feedback_condition=feedback_condition, 98 | # feedback_join=feedback_join, 99 | fixed_sql=fixed_sql 100 | ) 101 | 102 | return { 103 | 'prompt_id': str(index), 104 | 'messages': { 105 | 'prompt': prompt, 106 | 'completion': completion 107 | } 108 | } 109 | 110 | def main(): 111 | with open(args.input_file) as fp: 112 | data = [json.loads(line) for line in fp] 113 | 114 | with Pool(processes=args.num_processes) as pool: 115 | results = list(tqdm(pool.imap(process_sample, enumerate(data)), total=len(data))) 116 | 117 | sft_data = [result for result in results if result is not None] 118 | 119 | dataset = DatasetDict({ 120 | 'train': Dataset.from_list(sft_data), 121 | 'test': Dataset.from_list(sft_data[:100]), 122 | }) 123 | dataset.save_to_disk(args.output_dir) 124 | print(dataset) 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MATS: A Multi-agent Text2SQL Framework using Small Language Models and Execution Feedback 2 | 3 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=thanhdath.mats-sql) 4 | 5 | MATS is a multi-agent framework for Text2SQL that uses small language models and execution feedback to improve query accuracy. It employs multiple specialized agents, including a schema insight agent, a planner, a validator, a fix agent, and a selection agent. Some components of this framework are adapted from [CodeS](https://github.com/RUCKBReasoning/codes) (for schema filtering) and [alignment-handbook](https://github.com/huggingface/alignment-handbook) (for supervised fine-tuning and ORPO training). 6 | 7 | **1. To set up the environment** 8 | ``` 9 | conda env create -n mats -f environment.yml 10 | conda activate mats 11 | ``` 12 | 13 | **2. To run evaluation on BIRD**: 14 | 15 | First, serve the models with vLLM.
16 | ``` 17 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-bird-planner --host 0.0.0.0 --port 8003 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name planner --gpu-memory-utilization 0.3 --enable-prefix-caching 18 | 19 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-bird-validator --host 0.0.0.0 --port 8004 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name validator --gpu-memory-utilization 0.2 --enable-prefix-caching 20 | 21 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-bird-fixed --host 0.0.0.0 --port 8005 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name fixed --gpu-memory-utilization 0.2 --enable-prefix-caching 22 | 23 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-bird-selection --host 0.0.0.0 --port 8006 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name selection --gpu-memory-utilization 0.3 --enable-prefix-caching 24 | ``` 25 | 26 | Run evaluation: 27 | ``` 28 | eval_file=data/evaluate/orpo-llama-3-iter-2-end2end-bird_dev.jsonl 29 | rm $eval_file 30 | PYTHONPATH=. python evaluate_end2end.py \ 31 | --input_file data/schema_insight_bird_with_evidence_dev_text2sql.json \ 32 | --output_file $eval_file \ 33 | --model-name llama --mode test --n_return 10 --temperature 1.0 --api_host http://localhost:8003 --n_processes 16 34 | 35 | python compute_acc.py --pred_file $eval_file 36 | ``` 37 | 38 | 39 | 40 | 41 | **3. To run evaluation on Spider**: 42 | 43 | First, serve the models with vLLM. 44 | ``` 45 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-spider-planner --host 0.0.0.0 --port 8003 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name planner --gpu-memory-utilization 0.3 --enable-prefix-caching 46 | 47 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-spider-validator --host 0.0.0.0 --port 8004 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name validator --gpu-memory-utilization 0.2 --enable-prefix-caching 48 | 49 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-spider-fixed --host 0.0.0.0 --port 8005 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name fixed --gpu-memory-utilization 0.2 --enable-prefix-caching 50 | 51 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-spider-selection --host 0.0.0.0 --port 8006 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name selection --gpu-memory-utilization 0.3 --enable-prefix-caching 52 | ``` 53 | 54 | Run evaluation: 55 | ``` 56 | eval_file=data/evaluate/orpo-llama-3-iter-2-end2end-spider_dev.jsonl 57 | rm $eval_file 58 | PYTHONPATH=. python evaluate_end2end.py \ 59 | --input_file data/schema_insight_spider_dev_text2sql.json \ 60 | --output_file $eval_file \ 61 | --model-name llama --mode test --n_return 10 --temperature 1.0 --api_host http://localhost:8003 --n_processes 16 62 | 63 | python compute_acc.py --pred_file $eval_file 64 | ``` 65 | 66 | 67 | **4. To train the agents** 68 | 69 | Schema filtering is inherited from [CodeS](https://github.com/RUCKBReasoning/codes). 70 | 71 | To train the other agents, see the code in ***alignment-handbook/***; we modified the [alignment-handbook](https://github.com/huggingface/alignment-handbook) repository to run supervised fine-tuning and ORPO on the completion part only, as sketched below. The config files can be found in **alignment-handbook/recipes/**.
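The "completion part only" detail means that prompt tokens are masked out of the loss, so gradients come from the completion tokens alone. A minimal sketch of the idea (illustrative only, not the actual training entrypoint; the checkpoint name below is a placeholder):
```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint

prompt = "Question: How many singers do we have?\nSQL:"
completion = " SELECT count(*) FROM singer"

prompt_ids = tokenizer(prompt)["input_ids"]
completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]

input_ids = prompt_ids + completion_ids
labels = [-100] * len(prompt_ids) + completion_ids  # -100 is ignored by the cross-entropy loss
```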
72 | 73 | **Note**: Currently this work is under review. The model and training dataset will be publicly available upon acceptance. 74 | 75 | 76 | ----------- 77 | **Backup Statistics** 78 | 79 | ![Visitors](https://margherita-gustatory-zane.ngrok-free.dev/badge/thanhdath%2Fmats-sql.svg?ngrok-skip-browser-warning=true) 80 | -------------------------------------------------------------------------------- /data_processing/generate_sft_data_for_validator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from datasets import Dataset, DatasetDict 4 | import numpy as np 5 | from planner import _make_str_response, _execute_sql, is_execution_correct 6 | from multiprocessing import Pool, cpu_count 7 | from functools import partial 8 | from tqdm import tqdm 9 | from multiprocessing import Manager 10 | 11 | # Set seed 12 | np.random.seed(100) 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--input_file', type=str, default='../data/llm_alignment/validator_join_bird_with_evidence_train.jsonl') 16 | parser.add_argument('--output_dir', type=str, default='../data/llm_alignment/sft_data_for_validator_join_bird_with_evidence_train') 17 | args = parser.parse_args() 18 | 19 | def norm_completion(completion): 20 | """Normalize the completion by removing unwanted lines.""" 21 | lines = completion.split('\n') 22 | filter_lines = [line for line in lines if "broaden the criteria" not in line] 23 | completion = "\n".join(filter_lines) 24 | return completion 25 | 26 | PROMPT = """Generate feedbacks to fix the following SQL query: 27 | {schema} 28 | 29 | Question: {question} 30 | External knowledge: {evidence} 31 | 32 | SQL query: {sql_query} 33 | 34 | Execution response: 35 | {execution_response} 36 | 37 | Feedback:""" 38 | 39 | def process_sample(sample, index, input_file): 40 | """Process a single sample from the dataset.""" 41 | key, token = None, None 42 | if '_join' in input_file: 43 | key = 'validator_join' 44 | token = "JOIN." 45 | elif '_select' in input_file: 46 | key = 'validator_select' 47 | token = "SELECT." 48 | elif '_order' in input_file: 49 | key = 'validator_order' 50 | token = "ORDER BY." 51 | elif '_condition' in input_file: 52 | key = 'validator_condition' 53 | token = "CONDITION." 
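# The suffix of --input_file ('_join', '_select', '_order', '_condition') selects
# which validator field to read and which prefix token marks the start of the
# feedback text to keep.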
54 | 55 | feedback = sample.get(key) 56 | if not feedback: 57 | return None 58 | 59 | if type(feedback) == list: 60 | feedback = feedback[0] 61 | 62 | if f"prompt_{key}" in sample: 63 | prompt = sample[f"prompt_{key}"] 64 | prompt_completion = prompt + feedback 65 | else: 66 | prompt_completion = "\n" + feedback 67 | 68 | feedback = token + prompt_completion.split(token)[-1] 69 | prompt = PROMPT.format(schema=sample['schema_sequence'], 70 | question=sample['question'], 71 | evidence=sample['evidence'], 72 | sql_query=sample['predict_sql'], 73 | execution_response=sample['pred_result']) 74 | 75 | completion = feedback 76 | if isinstance(completion, list): 77 | completion = completion[0] 78 | 79 | completion = norm_completion(completion) 80 | prompt_id = f"{index}" 81 | 82 | true_result, _ = _execute_sql("./" + sample["db_path"], sample["sql"]) 83 | pred_result, _ = _execute_sql("./" + sample["db_path"], sample['predict_sql']) 84 | 85 | is_pred_sql_correct = is_execution_correct(true_result, pred_result) 86 | feedback_conclude_correct = completion is None or 'Conclude: correct' in completion 87 | 88 | if is_pred_sql_correct and not feedback_conclude_correct: # bad case 89 | return None 90 | 91 | return { 92 | 'prompt_id': prompt_id, 93 | 'messages': { 94 | 'prompt': prompt, 95 | 'completion': completion 96 | } 97 | } 98 | 99 | def main(): 100 | # Load JSONL data 101 | with open(args.input_file, 'r') as f: 102 | data = [json.loads(line) for line in f] 103 | 104 | # Use multiprocessing to process samples with tqdm 105 | print('Start processing samples...') 106 | with Manager() as manager: 107 | results_list = manager.list() # Shared list for results 108 | total_samples = len(data) 109 | 110 | with tqdm(total=total_samples, desc="Processing Samples") as pbar: 111 | def update_progress(result): 112 | if result is not None: 113 | results_list.append(result) 114 | pbar.update() 115 | 116 | with Pool(processes=24) as pool: 117 | partial_process_sample = partial(process_sample, input_file=args.input_file) 118 | for idx, sample in enumerate(data): 119 | pool.apply_async(partial_process_sample, args=(sample, idx), callback=update_progress) 120 | 121 | pool.close() 122 | pool.join() 123 | 124 | # Convert results to a normal list 125 | sft_data = list(results_list) 126 | 127 | # Shuffle and create DatasetDict 128 | np.random.shuffle(sft_data) 129 | 130 | dataset = DatasetDict({ 131 | 'train': Dataset.from_list(sft_data), 132 | 'test': Dataset.from_list(sft_data[:100]), 133 | }) 134 | 135 | print(dataset) 136 | dataset.save_to_disk(args.output_dir) 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /data_processing/generate_sft_data_for_fix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from datasets import Dataset, DatasetDict 5 | from planner import get_answer_llamacpp, get_answer_vllm, get_answer_openai 6 | from openai import OpenAI 7 | from dotenv import load_dotenv 8 | from planner import _make_str_response, _execute_sql, is_execution_correct 9 | import re 10 | from utils import norm_sql_query 11 | from tqdm import tqdm 12 | from multiprocessing import Pool 13 | 14 | # Set up argument parser 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--input_file', type=str, default='data/multi-agents/fixed/gpt-4o-mini-fixed-bird_with_evidence_train.jsonl') 17 | parser.add_argument('--output_dir', type=str, 
default='./data/multi-agents/fixed/sft-gpt-4o-mini-fixed-bird_with_evidence_train') 18 | parser.add_argument('--num_processes', type=int, default=16) 19 | args = parser.parse_args() 20 | 21 | # Define the prompt template 22 | PROMPT = """{schema} 23 | 24 | Question: {question} 25 | External knowledge: {evidence} 26 | 27 | Generated SQL query: {sql_query} 28 | 29 | Execution response: 30 | {execution_response} 31 | 32 | Feedback for the SQL query: 33 | {feedback_select} 34 | 35 | {feedback_condition} 36 | 37 | {feedback_join} 38 | 39 | FIXED SQL:""" 40 | 41 | def norm_feedback(feedback, token): 42 | feedback = token + feedback.split(token)[-1] 43 | return feedback 44 | 45 | def extract_sql_in_code_block(pred_sql_text): 46 | sql_block_match = re.search(r"```(.+?)```", pred_sql_text, re.DOTALL) 47 | if sql_block_match: 48 | sql_query = sql_block_match.group(1).strip() 49 | if sql_query.startswith("sql"): 50 | sql_query = sql_query[3:].strip() # drop only the leading "sql" tag, not every "sql" substring 51 | return sql_query 52 | else: 53 | return pred_sql_text 54 | 55 | def process_sample(index_sample): 56 | index, sample = index_sample 57 | feedback_select = sample['validator_select'] or 'SELECT.\nNone' 58 | feedback_condition = sample['validator_condition'] or "CONDITION.\nNone" 59 | feedback_join = sample['validator_join'] or "JOIN.\nNone" 60 | feedback_join = "JOIN." + feedback_join.split("JOIN.")[-1] 61 | feedback_order = sample['validator_order'] or "ORDER BY.\nNone" 62 | 63 | feedback_select = norm_feedback(feedback_select, "SELECT.") 64 | feedback_condition = norm_feedback(feedback_condition, "CONDITION.") 65 | feedback_join = norm_feedback(feedback_join, "JOIN.") 66 | feedback_order = norm_feedback(feedback_order, "ORDER BY.") 67 | 68 | select_correct = 'Conclude: correct' in feedback_select or feedback_select == 'SELECT.\nNone' 69 | condition_correct = 'Conclude: correct' in feedback_condition or feedback_condition == 'CONDITION.\nNone' 70 | join_correct = 'Conclude: correct' in feedback_join or feedback_join == 'JOIN.\nNone' 71 | order_correct = 'Conclude: correct' in feedback_order or feedback_order == 'ORDER BY.\nNone' 72 | 73 | if select_correct: 74 | feedback_select = "" 75 | if condition_correct: 76 | feedback_condition = "" 77 | if join_correct: 78 | feedback_join = "" 79 | if order_correct: 80 | feedback_order = "" 81 | 82 | # if select_correct and condition_correct and join_correct and order_correct: 83 | # return None 84 | 85 | prompt = PROMPT.format( 86 | schema=sample['schema_sequence'], 87 | question=sample['question'], 88 | evidence=sample['evidence'], 89 | sql_query=sample['predict_sql'], 90 | execution_response=sample['pred_result'], 91 | feedback_select=feedback_select, 92 | feedback_condition=feedback_condition, 93 | feedback_join=feedback_join, 94 | # feedback_order=feedback_order 95 | ) 96 | 97 | completion = sample['fixed_sql'] 98 | if type(completion) == list: 99 | completion = completion[0] 100 | 101 | fixed_sql = extract_sql_in_code_block(completion) 102 | # completion = norm_sql_query(fixed_sql, sample['schema']) 103 | completion = fixed_sql 104 | 105 | true_result, has_error = _execute_sql("./" + sample["db_path"], sample["sql"]) 106 | pred_result, has_error = _execute_sql("./" + sample["db_path"], fixed_sql) 107 | 108 | if not is_execution_correct(true_result, pred_result): 109 | print("-"*20) 110 | print('True:', true_result) 111 | print('Pred:', pred_result) 112 | # completion = norm_sql_query(sample['sql'], sample['schema']) 113 | completion = sample['sql'] 114 | # return None 115 | 116 | return {
117 | 'prompt_id': str(index), 118 | 'messages': { 119 | 'prompt': prompt, 120 | 'completion': completion 121 | } 122 | } 123 | 124 | def main(): 125 | with open(args.input_file) as fp: 126 | data = [json.loads(line) for line in fp] 127 | 128 | with Pool(processes=args.num_processes) as pool: 129 | results = list(tqdm(pool.imap(process_sample, enumerate(data)), total=len(data))) 130 | 131 | sft_data = [result for result in results if result is not None] 132 | 133 | dataset = DatasetDict({ 134 | 'train': Dataset.from_list(sft_data), 135 | 'test': Dataset.from_list(sft_data[:100]), 136 | }) 137 | dataset.save_to_disk(args.output_dir) 138 | print(dataset) 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /alignment-handbook/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py 16 | 17 | 18 | import re 19 | import shutil 20 | from pathlib import Path 21 | 22 | from setuptools import find_packages, setup 23 | 24 | 25 | # Remove stale alignment.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 26 | stale_egg_info = Path(__file__).parent / "alignment.egg-info" 27 | if stale_egg_info.exists(): 28 | print( 29 | ( 30 | "Warning: {} exists.\n\n" 31 | "If you recently updated alignment, this is expected,\n" 32 | "but it may prevent alignment from installing in editable mode.\n\n" 33 | "This directory is automatically generated by Python's packaging tools.\n" 34 | "I will remove it now.\n\n" 35 | "See https://github.com/pypa/pip/issues/5466 for details.\n" 36 | ).format(stale_egg_info) 37 | ) 38 | shutil.rmtree(stale_egg_info) 39 | 40 | 41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any. 42 | # * If a dependency is fast-moving (e.g. 
transformers), pin to the exact version 43 | _deps = [ 44 | "accelerate==0.23.0", 45 | "bitsandbytes==0.41.2.post2", 46 | "black==23.1.0", 47 | "datasets==2.14.6", 48 | "deepspeed==0.12.2", 49 | "einops>=0.6.1", 50 | "evaluate==0.4.0", 51 | "flake8>=6.0.0", 52 | "hf-doc-builder>=0.4.0", 53 | "huggingface-hub>=0.14.1,<1.0", 54 | "isort>=5.12.0", 55 | "ninja>=1.11.1", 56 | "numpy>=1.24.2", 57 | "packaging>=23.0", 58 | "parameterized>=0.9.0", 59 | "peft==0.6.1", 60 | "protobuf<=3.20.2", # Needed to avoid conflicts with `transformers` 61 | "pytest", 62 | "safetensors>=0.3.3", 63 | "scipy", 64 | "tensorboard", 65 | "torch==2.1.0", 66 | "transformers==4.35.0", 67 | "trl==0.7.4", 68 | "jinja2>=3.0.0", 69 | "tqdm>=4.64.1", 70 | ] 71 | 72 | # this is a lookup table with items like: 73 | # 74 | # tokenizers: "tokenizers==0.9.4" 75 | # packaging: "packaging" 76 | # 77 | # some of the values are versioned whereas others aren't. 78 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} 79 | 80 | 81 | def deps_list(*pkgs): 82 | return [deps[pkg] for pkg in pkgs] 83 | 84 | 85 | extras = {} 86 | extras["tests"] = deps_list("pytest", "parameterized") 87 | extras["torch"] = deps_list("torch") 88 | extras["quality"] = deps_list("black", "isort", "flake8") 89 | extras["docs"] = deps_list("hf-doc-builder") 90 | extras["dev"] = extras["docs"] + extras["quality"] + extras["tests"] 91 | 92 | # core dependencies shared across the whole project - keep this to a bare minimum :) 93 | install_requires = [ 94 | deps["accelerate"], 95 | deps["bitsandbytes"], 96 | deps["einops"], 97 | deps["evaluate"], 98 | deps["datasets"], 99 | deps["deepspeed"], 100 | deps["huggingface-hub"], 101 | deps["jinja2"], 102 | deps["ninja"], 103 | deps["numpy"], 104 | deps["packaging"], # utilities from PyPA to e.g., compare versions 105 | deps["peft"], 106 | deps["protobuf"], 107 | deps["safetensors"], 108 | deps["scipy"], 109 | deps["tensorboard"], 110 | deps["tqdm"], # progress bars in model download and training scripts 111 | deps["transformers"], 112 | deps["trl"], 113 | ] 114 | 115 | setup( 116 | name="alignment-handbook", 117 | version="0.2.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 118 | author="The Hugging Face team (past and future)", 119 | author_email="lewis@huggingface.co", 120 | description="The Alignment Handbook", 121 | long_description=open("README.md", "r", encoding="utf-8").read(), 122 | long_description_content_type="text/markdown", 123 | keywords="nlp deep learning rlhf llm", 124 | license="Apache", 125 | url="https://github.com/huggingface/alignment-handbook", 126 | package_dir={"": "src"}, 127 | packages=find_packages("src"), 128 | zip_safe=False, 129 | extras_require=extras, 130 | python_requires=">=3.10.9", 131 | install_requires=install_requires, 132 | classifiers=[ 133 | "Development Status :: 3 - Alpha", 134 | "Intended Audience :: Developers", 135 | "Intended Audience :: Education", 136 | "Intended Audience :: Science/Research", 137 | "License :: OSI Approved :: Apache Software License", 138 | "Operating System :: OS Independent", 139 | "Programming Language :: Python :: 3", 140 | "Programming Language :: Python :: 3.10", 141 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 142 | ], 143 | ) 144 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 
import math 2 | import warnings 3 | from typing import List 4 | 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | from torch.optim import Optimizer 7 | 8 | ''' 9 | copy from the source code of pl_bolts 10 | ''' 11 | 12 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 13 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr 14 | and base_lr followed by a cosine annealing schedule between base_lr and eta_min. 15 | 16 | .. warning:: 17 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 18 | after each iteration as calling it after each epoch will keep the starting lr at 19 | warmup_start_lr for the first epoch which is 0 in most cases. 20 | 21 | .. warning:: 22 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 23 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 24 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 25 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 26 | train and validation methods. 27 | 28 | Example: 29 | >>> import torch.nn as nn 30 | >>> from torch.optim import Adam 31 | >>> # 32 | >>> layer = nn.Linear(10, 1) 33 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 34 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 35 | >>> # the default case 36 | >>> for epoch in range(40): 37 | ... # train(...) 38 | ... # validate(...) 39 | ... scheduler.step() 40 | >>> # passing epoch param case 41 | >>> for epoch in range(40): 42 | ... scheduler.step(epoch) 43 | ... # train(...) 44 | ... # validate(...) 45 | """ 46 | 47 | def __init__( 48 | self, 49 | optimizer: Optimizer, 50 | warmup_epochs: int, 51 | max_epochs: int, 52 | warmup_start_lr: float = 0.0, 53 | eta_min: float = 0.0, 54 | last_epoch: int = -1, 55 | ) -> None: 56 | """ 57 | Args: 58 | optimizer (Optimizer): Wrapped optimizer. 59 | warmup_epochs (int): Maximum number of iterations for linear warmup 60 | max_epochs (int): Maximum number of iterations 61 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 62 | eta_min (float): Minimum learning rate. Default: 0. 63 | last_epoch (int): The index of last epoch. Default: -1. 
64 | """ 65 | self.warmup_epochs = warmup_epochs 66 | self.max_epochs = max_epochs 67 | self.warmup_start_lr = warmup_start_lr 68 | self.eta_min = eta_min 69 | 70 | super().__init__(optimizer, last_epoch) 71 | 72 | def get_lr(self) -> List[float]: 73 | """Compute learning rate using chainable form of the scheduler.""" 74 | if not self._get_lr_called_within_step: 75 | warnings.warn( 76 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", 77 | UserWarning, 78 | ) 79 | 80 | if self.last_epoch == 0: 81 | return [self.warmup_start_lr] * len(self.base_lrs) 82 | if self.last_epoch < self.warmup_epochs: 83 | return [ 84 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 85 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 86 | ] 87 | if self.last_epoch == self.warmup_epochs: 88 | return self.base_lrs 89 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 90 | return [ 91 | group["lr"] 92 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 93 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 94 | ] 95 | 96 | return [ 97 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 98 | / ( 99 | 1 100 | + math.cos( 101 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) 102 | ) 103 | ) 104 | * (group["lr"] - self.eta_min) 105 | + self.eta_min 106 | for group in self.optimizer.param_groups 107 | ] 108 | 109 | def _get_closed_form_lr(self) -> List[float]: 110 | """Called when epoch is passed as a param to the `step` function of the scheduler.""" 111 | if self.last_epoch < self.warmup_epochs: 112 | return [ 113 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 114 | for base_lr in self.base_lrs 115 | ] 116 | 117 | return [ 118 | self.eta_min 119 | + 0.5 120 | * (base_lr - self.eta_min) 121 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 122 | for base_lr in self.base_lrs 123 | ] 124 | -------------------------------------------------------------------------------- /utils/load_sft_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import gc 4 | 5 | # from datasets import Dataset  # unused: shadowed by torch.utils.data.Dataset below 6 | from torch.utils.data import Dataset 7 | from schema_item_filter import SchemaItemClassifierInference, filter_schema, filter_schema_purple 8 | from utils.db_utils import get_db_schema_sequence, get_matched_content_sequence 9 | 10 | def prepare_text2sql_prefix_sequence(data): 11 | prompt = f"""Convert the question to SQL query.
12 | {data["schema_sequence"]} 13 | {data["content_sequence"]} 14 | question: {data["text"]}""" 15 | return prompt 16 | 17 | def prepare_inputs_and_labels(prefix_seq, target_seq, tokenizer, max_tokens): 18 | prefix_ids = [tokenizer.bos_token_id] + tokenizer(prefix_seq , truncation = False)["input_ids"] 19 | target_ids = tokenizer(target_seq, truncation = False)["input_ids"] + [tokenizer.eos_token_id] 20 | 21 | seq_length = len(prefix_ids) + len(target_ids) 22 | if seq_length <= max_tokens: # pad inputs with pad_token_id 23 | pad_length = max_tokens - seq_length 24 | input_ids = prefix_ids + target_ids + [tokenizer.pad_token_id] * pad_length 25 | # tell the model to ignore the padding tokens when performing (masked) self-attention 26 | attention_mask = [1] * seq_length + [0] * pad_length 27 | # only target_ids produces gradients 28 | labels = [-100] * len(prefix_ids) + target_ids + [-100] * pad_length 29 | else: # no padding 30 | print("the current input sequence exceeds the max_tokens, we will truncate it.") 31 | input_ids = prefix_ids + target_ids 32 | # pre-truncate input ids 33 | input_ids = [tokenizer.bos_token_id] + input_ids[-(max_tokens-1):] 34 | attention_mask = [1] * max_tokens 35 | # only target_ids produces gradients 36 | labels = [-100] * len(prefix_ids) + target_ids 37 | # pre-truncate labels 38 | labels = labels[-max_tokens:] 39 | 40 | return { 41 | "input_ids": torch.tensor(input_ids, dtype = torch.int64), 42 | "attention_mask": torch.tensor(attention_mask, dtype = torch.int64), 43 | "labels": torch.tensor(labels, dtype = torch.int64) 44 | } 45 | 46 | # def prepare_inputs(prefix_seq, tokenizer, max_prefix_length): 47 | # input_ids = [tokenizer.bos_token_id] + tokenizer(prefix_seq , truncation = False)["input_ids"] 48 | 49 | # if len(input_ids) > max_prefix_length: 50 | # print("the current input sequence exceeds the max_tokens, we will truncate it.") 51 | # input_ids = [tokenizer.bos_token_id] + input_ids[-(max_prefix_length-1):] 52 | 53 | # attention_mask = [1] * len(input_ids) 54 | 55 | # return { 56 | # "input_ids": torch.tensor(input_ids, dtype = torch.int64), 57 | # "attention_mask": torch.tensor(attention_mask, dtype = torch.int64) 58 | # } 59 | 60 | def prepare_inputs(prefix_seq, tokenizer, max_prefix_length): 61 | messages = [{ 62 | 'role': 'user', 63 | 'content': prefix_seq 64 | }] 65 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 66 | 67 | input_ids = tokenizer(prompt , truncation = False)["input_ids"] 68 | 69 | if len(input_ids) > max_prefix_length: 70 | print("the current input sequence exceeds the max_tokens, we will truncate it.") 71 | input_ids = input_ids[-(max_prefix_length-1):] 72 | 73 | attention_mask = [1] * len(input_ids) 74 | 75 | return { 76 | "input_ids": torch.tensor(input_ids, dtype = torch.int64), 77 | "attention_mask": torch.tensor(attention_mask, dtype = torch.int64) 78 | } 79 | 80 | class SFTSQLGenerationDataset(Dataset): 81 | def __init__(self, text2sql_data_dir, tokenizer, max_tokens, mode, table_num, column_num, threshold, sic_path, do_filter_schema=True): 82 | super().__init__() 83 | dataset = json.load(open(text2sql_data_dir)) 84 | 85 | print("apply filtering strategies...") 86 | if do_filter_schema: 87 | if mode == "train": 88 | dataset = filter_schema(dataset, "train", None, table_num, column_num, threshold=threshold) 89 | elif mode == "eval": 90 | sic = SchemaItemClassifierInference(sic_path) 91 | dataset = filter_schema(dataset, "eval", sic, table_num, column_num, threshold=threshold) 92 | 
# dataset = filter_schema_purple(dataset, "eval", "/home/datht/llmsql/selector/spider-dev/spider-dev-selector-t0.02-value-samples.json") 93 | del sic 94 | torch.cuda.empty_cache() 95 | 96 | # prepare schema sequence and content sequence 97 | for data in dataset: 98 | data["schema_sequence"] = get_db_schema_sequence(data["schema"]) 99 | # data["content_sequence"] = get_matched_content_sequence(data["matched_contents"]) 100 | 101 | self.mode = mode 102 | self.dataset = dataset 103 | self.tokenizer = tokenizer 104 | self.max_tokens = max_tokens 105 | 106 | def __getitem__(self, index): 107 | data = self.dataset[index] 108 | prefix_seq = prepare_text2sql_prefix_sequence(data) 109 | if index < 2: 110 | print(prefix_seq) 111 | 112 | if self.mode == "train": 113 | target_seq = data["sql"] 114 | return prepare_inputs_and_labels(prefix_seq, target_seq, self.tokenizer, self.max_tokens) 115 | elif self.mode == "eval": 116 | return prepare_inputs(prefix_seq, self.tokenizer, self.max_tokens) 117 | 118 | def __len__(self): 119 | return len(self.dataset) -------------------------------------------------------------------------------- /data_processing/generate_sft_data_for_planner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import re 5 | from datasets import Dataset, DatasetDict 6 | from tqdm import tqdm 7 | import sqlite3 8 | from func_timeout import func_timeout, FunctionTimedOut 9 | from planner import _make_str_response, _execute_sql, is_execution_correct 10 | from utils import norm_sql_query 11 | from multiprocessing import Pool 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--input_file', type=str, default='../data/multi-agents/planner/gpt-4o-mini-planner_combine_bird_with_evidence_train.jsonl') 15 | parser.add_argument('--raw_train_file', type=str, default='../data/multi-agents/planner/gpt-4o-mini-planner_combine_bird_with_evidence_train.jsonl') 16 | parser.add_argument('--output_dir', type=str, default='../data/multi-agents/planner/sft-gpt-4o-mini-planner_combine_bird_with_evidence_train/') 17 | parser.add_argument('--error_file', type=str, default='../data/multi-agents/planner/gpt-4o-mini-planner_combine_bird_with_evidence_train-error-turn-1.jsonl') 18 | parser.add_argument('--use_groundtruth', action='store_true') 19 | parser.add_argument('--no_filter', action='store_true') 20 | args = parser.parse_args() 21 | 22 | PROMPT = """{schema} 23 | 24 | Question: {question} 25 | External knowledge: {evidence} 26 | 27 | Planning: 28 | """ 29 | # PROMPT = """{schema} 30 | 31 | # Question: {question} 32 | # """ 33 | 34 | # Helper function for processing each sample 35 | def process_sample(args): 36 | isample, sample, raw_sample, use_groundtruth, no_filter = args 37 | schema = raw_sample['schema_sequence'] 38 | question = sample['question'] 39 | evidence = sample['evidence'] 40 | 41 | key = 'planner_combine_with_true_sql' 42 | feedback = sample[key] 43 | if feedback is None or len(feedback) == 0: 44 | return None, None # Indicate empty result 45 | 46 | if isinstance(feedback, list): 47 | feedback = feedback[0] 48 | 49 | prompt = PROMPT.format(schema=schema, question=question, evidence=evidence) 50 | 51 | if use_groundtruth: 52 | completion = sample['sql'] 53 | # completion = norm_sql_query(sample['sql'], raw_sample['schema']) 54 | else: 55 | # Extract SQL query using regex 56 | pred_sql_match = re.search(r"(?<=Final SQL query:).*?```(.*?)```", feedback, re.DOTALL) 57 | if pred_sql_match is 
None: 58 | pred_sql = " " 59 | else: 60 | pred_sql = pred_sql_match.group(1).strip() 61 | if pred_sql.startswith("sql"): 62 | pred_sql = pred_sql[3:].strip() 63 | 64 | # norm_pred_sql = norm_sql_query(pred_sql, raw_sample['schema']) 65 | # feedback = feedback.replace(pred_sql, norm_pred_sql) 66 | 67 | if not no_filter: 68 | true_result, has_error_true = _execute_sql("./" + sample["db_path"], sample["sql"]) 69 | pred_result, has_error_pred = _execute_sql("./" + sample["db_path"], pred_sql) 70 | # norm_pred_result, has_error_pred = _execute_sql("./" + sample["db_path"], norm_pred_sql) 71 | 72 | # if not is_execution_correct(pred_result, norm_pred_result): 73 | # # print to debug 74 | # print("-" * 20) 75 | # print("Norm SQL:", norm_pred_sql) 76 | # print("Pred SQL:", pred_sql) 77 | # print("Norm Result:", norm_pred_result) 78 | # print("Pred Result:", pred_result) 79 | 80 | if not is_execution_correct(true_result, pred_result): 81 | # sample['true_result'] = _make_str_response(true_result, has_error_true) 82 | # sample['pred_result'] = _make_str_response(pred_result, has_error_pred) 83 | return None, sample # Return sample with error 84 | 85 | completion = feedback if not isinstance(feedback, list) else feedback[0] 86 | prompt_id = f"{isample}" 87 | 88 | return { 89 | 'prompt_id': prompt_id, 90 | 'messages': { 91 | 'prompt': prompt, 92 | 'completion': completion 93 | } 94 | }, None # Indicate valid result 95 | 96 | 97 | if __name__ == "__main__": 98 | # Load data from input files 99 | data = [] 100 | with open(args.input_file, 'r') as f: 101 | for line in f: 102 | data.append(json.loads(line)) 103 | 104 | raw_data = json.load(open(args.raw_train_file)) 105 | 106 | # Prepare arguments for each sample to process 107 | samples_args = [(i, data[i], raw_data[i], args.use_groundtruth, args.no_filter) for i in range(len(data))] 108 | 109 | # Run parallel processing with 24 processes 110 | sft_data = [] 111 | error_data = [] 112 | with Pool(24) as pool: 113 | for result, error in tqdm(pool.imap_unordered(process_sample, samples_args), total=len(data)): 114 | if result: 115 | sft_data.append(result) 116 | if error: 117 | error_data.append(error) 118 | # for sample_arg in tqdm(samples_args): 119 | # result, error = process_sample(sample_arg) 120 | # if result: 121 | # sft_data.append(result) 122 | # if error: 123 | # error_data.append(error) 124 | 125 | # Create datasets 126 | dataset = DatasetDict({ 127 | 'train': Dataset.from_list(sft_data), 128 | 'test': Dataset.from_list(sft_data[:100]), 129 | }) 130 | print(dataset) 131 | 132 | # Save the dataset 133 | dataset.save_to_disk(args.output_dir) 134 | 135 | # Write error data to JSONL file 136 | with open(args.error_file, 'w') as output_file: 137 | for sample in error_data: 138 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n') 139 | -------------------------------------------------------------------------------- /llm_alignment/build_rlef_selection_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle as pkl 3 | import argparse 4 | import numpy as np 5 | import requests 6 | from tqdm import tqdm 7 | from copy import deepcopy 8 | from multiprocessing import Pool, cpu_count 9 | from data_processing.planner import SelectionAgentWithSchema 10 | import os 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--pred_file", type=str, default='logs/results-orpo-iter-2-bird-train-top-20-temperature-1.0.pkl') 14 | parser.add_argument("--max_candidates", 
type=int, default=3) 15 | parser.add_argument("--progress_file", type=str, default='temp/bird_selection_dpo.jsonl') 16 | args = parser.parse_args() 17 | 18 | # Initialize Selection Agent 19 | selection_agent = SelectionAgentWithSchema() 20 | 21 | def get_answer_selection(messages): 22 | response = requests.post( 23 | "http://192.168.1.108:8006/v1/completions", 24 | json={ 25 | "model": 'selection', 26 | "prompt": messages[0]['content'], 27 | "max_tokens": 512, 28 | "use_beam_search": False, 29 | "n": 20, 30 | "temperature": 1.0, 31 | "stop": ['<|eot_id|>', '<|end|>', '<|end_header_id|>', '<|end_of_text|>', '<|end▁of▁sentence|>'] 32 | } 33 | ).json() 34 | 35 | try: 36 | return [x['text'] for x in response['choices']] 37 | except: 38 | print(response) 39 | return [] 40 | 41 | selection_agent.get_answer = get_answer_selection 42 | 43 | # Load predictions 44 | preds = pkl.load(open(args.pred_file, 'rb')) 45 | 46 | # Load progress from previous runs 47 | processed_keys = {} 48 | if os.path.exists(args.progress_file): 49 | with open(args.progress_file, 'r', encoding='utf-8') as f: 50 | for line in f: 51 | sample = json.loads(line.strip()) 52 | key = (sample["db_id"], sample["question"]) 53 | processed_keys[key] = processed_keys.get(key, 0) + 1 54 | 55 | # Expand preds 4 times and filter already processed ones 56 | all_preds = preds * 4 57 | filtered_preds = [] 58 | for sample in all_preds: 59 | key = (sample["db_id"], sample["question"]) 60 | if processed_keys.get(key, 0) < 4: 61 | filtered_preds.append(sample) 62 | processed_keys[key] = processed_keys.get(key, 0) + 1 # Track count 63 | 64 | def build_dpo_data(sample): 65 | """Process a single sample and return DPO data.""" 66 | sample = deepcopy(sample) 67 | 68 | # Filter out samples with execution failures 69 | valid_sqls, valid_results, valid_corrects = [], [], [] 70 | for i in range(min(len(sample['predict_sqls']), 20)): 71 | if 'Execution failed' not in sample['pred_results'][i] and 'too much time' not in sample['pred_results'][i]: 72 | valid_sqls.append(sample['predict_sqls'][i]) 73 | valid_results.append(sample['pred_results'][i]) 74 | valid_corrects.append(sample['is_execution_corrects'][i]) 75 | 76 | sample['predict_sqls'] = valid_sqls 77 | sample['pred_results'] = valid_results 78 | sample['is_execution_corrects'] = valid_corrects 79 | 80 | # Shuffle valid results 81 | indices = np.random.permutation(len(sample['predict_sqls'])).tolist() 82 | sample['predict_sqls'] = [sample['predict_sqls'][i] for i in indices] 83 | sample['pred_results'] = [sample['pred_results'][i] for i in indices] 84 | sample['is_execution_corrects'] = [sample['is_execution_corrects'][i] for i in indices] 85 | 86 | # Select a random number of candidates 87 | n_candidates = np.random.randint(2, 6) 88 | sample['predict_sqls'] = sample['predict_sqls'][:n_candidates] 89 | sample['pred_results'] = sample['pred_results'][:n_candidates] 90 | sample['is_execution_corrects'] = sample['is_execution_corrects'][:n_candidates] 91 | sample['candidate_sqls'] = sample['predict_sqls'] 92 | sample['candidate_pred_results'] = sample['pred_results'] 93 | 94 | # Generate prompt and answers 95 | prompt, answers = selection_agent.generate(sample) 96 | 97 | dpo_data = { 98 | 'db_path': sample['db_path'], 99 | 'db_id': sample['db_id'], 100 | 'question': sample['question'], 101 | 'sql': sample['sql'], 102 | 'true_result': str(sample['true_result']).strip(), 103 | 'predict_sqls': sample['predict_sqls'], 104 | 'pred_results': [str(x).strip() for x in sample['pred_results']], 105 | 
'is_execution_corrects': sample['is_execution_corrects'], 106 | 'reward_data': [] 107 | } 108 | 109 | for answer in answers: 110 | answer_index = selection_agent.extract_answer_index(answer) 111 | 112 | if answer_index == -1 and sum(sample['is_execution_corrects']) > 0: 113 | reward = 0 114 | elif answer_index == -1 and sum(sample['is_execution_corrects']) == 0: 115 | reward = 1 116 | elif answer_index > len(sample['is_execution_corrects']): 117 | reward = 0 118 | elif answer_index > 0: 119 | reward = int(sample['is_execution_corrects'][answer_index - 1]) 120 | else: 121 | reward = -2 122 | 123 | dpo_data['reward_data'].append({ 124 | 'prompt': prompt, 125 | 'completion': answer, 126 | 'reward': reward 127 | }) 128 | 129 | return dpo_data 130 | 131 | if __name__ == "__main__": 132 | num_processes = min(32, cpu_count()) # Use up to 32 processes 133 | 134 | # Track progress and write every 50 samples 135 | processed_count = 0 136 | with Pool(num_processes) as pool, open(args.progress_file, 'a', encoding='utf-8') as f: 137 | for dpo_data in tqdm(pool.imap_unordered(build_dpo_data, filtered_preds), total=len(filtered_preds)): 138 | f.write(json.dumps(dpo_data, ensure_ascii=False) + "\n") 139 | processed_count += 1 140 | 141 | # Save every 50 samples 142 | if processed_count % 50 == 0: 143 | f.flush() 144 | -------------------------------------------------------------------------------- /alignment-handbook/tests/test_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import unittest 16 | 17 | import pytest 18 | from datasets import Dataset 19 | 20 | from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer 21 | 22 | 23 | class GetDatasetsTest(unittest.TestCase): 24 | """Each of these test datasets has 100 examples""" 25 | 26 | def test_loading_data_args(self): 27 | dataset_mixer = { 28 | "HuggingFaceH4/testing_alpaca_small": 0.5, 29 | "HuggingFaceH4/testing_self_instruct_small": 0.3, 30 | "HuggingFaceH4/testing_codealpaca_small": 0.2, 31 | } 32 | data_args = DataArguments(dataset_mixer=dataset_mixer) 33 | datasets = get_datasets(data_args) 34 | self.assertEqual(len(datasets["train"]), 100) 35 | self.assertEqual(len(datasets["test"]), 300) 36 | 37 | def test_loading_data_dict(self): 38 | dataset_mixer = { 39 | "HuggingFaceH4/testing_alpaca_small": 0.5, 40 | "HuggingFaceH4/testing_self_instruct_small": 0.3, 41 | "HuggingFaceH4/testing_codealpaca_small": 0.2, 42 | } 43 | datasets = get_datasets(dataset_mixer) 44 | self.assertEqual(len(datasets["train"]), 100) 45 | self.assertEqual(len(datasets["test"]), 300) 46 | 47 | def test_loading_with_unit_fractions(self): 48 | dataset_mixer = { 49 | "HuggingFaceH4/testing_alpaca_small": 1.0, 50 | "HuggingFaceH4/testing_self_instruct_small": 1.0, 51 | "HuggingFaceH4/testing_codealpaca_small": 1.0, 52 | } 53 | datasets = get_datasets(dataset_mixer) 54 | self.assertEqual(len(datasets["train"]), 300) 55 | self.assertEqual(len(datasets["test"]), 300) 56 | 57 | def test_loading_with_fractions_greater_than_unity(self): 58 | dataset_mixer = { 59 | "HuggingFaceH4/testing_alpaca_small": 0.7, 60 | "HuggingFaceH4/testing_self_instruct_small": 0.4, 61 | } 62 | datasets = get_datasets(dataset_mixer) 63 | self.assertEqual(len(datasets["train"]), 70 + 40) 64 | self.assertEqual(len(datasets["test"]), 200) 65 | 66 | def test_loading_fails_with_negative_fractions(self): 67 | dataset_mixer = { 68 | "HuggingFaceH4/testing_alpaca_small": 0.7, 69 | "HuggingFaceH4/testing_self_instruct_small": -0.3, 70 | } 71 | with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."): 72 | get_datasets(dataset_mixer) 73 | 74 | def test_loading_single_split_with_unit_fractions(self): 75 | dataset_mixer = { 76 | "HuggingFaceH4/testing_alpaca_small": 1.0, 77 | } 78 | datasets = get_datasets(dataset_mixer, splits=["test"]) 79 | self.assertEqual(len(datasets["test"]), 100) 80 | self.assertRaises(KeyError, lambda: datasets["train"]) 81 | 82 | 83 | class ApplyChatTemplateTest(unittest.TestCase): 84 | def setUp(self): 85 | model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha") 86 | data_args = DataArguments() 87 | self.tokenizer = get_tokenizer(model_args, data_args) 88 | self.dataset = Dataset.from_dict( 89 | { 90 | "prompt": ["Hello!"], 91 | "messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]], 92 | "chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]], 93 | "rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]], 94 | } 95 | ) 96 | 97 | def test_sft(self): 98 | dataset = self.dataset.map( 99 | apply_chat_template, 100 | fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"}, 101 | remove_columns=self.dataset.column_names, 102 | ) 103 | self.assertDictEqual( 104 | dataset[0], 105 | {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n"}, 106 | ) 107 | 108 | def test_generation(self): 109 | # Remove last turn from messages 
110 | dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]}) 111 | dataset = dataset.map( 112 | apply_chat_template, 113 | fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"}, 114 | remove_columns=self.dataset.column_names, 115 | ) 116 | self.assertDictEqual( 117 | dataset[0], 118 | {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n"}, 119 | ) 120 | 121 | def test_rm(self): 122 | dataset = self.dataset.map( 123 | apply_chat_template, 124 | fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"}, 125 | remove_columns=self.dataset.column_names, 126 | ) 127 | self.assertDictEqual( 128 | dataset[0], 129 | { 130 | "text_chosen": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n", 131 | "text_rejected": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nHola!\n", 132 | }, 133 | ) 134 | 135 | def test_dpo(self): 136 | dataset = self.dataset.map( 137 | apply_chat_template, 138 | fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"}, 139 | remove_columns=self.dataset.column_names, 140 | ) 141 | self.assertDictEqual( 142 | dataset[0], 143 | { 144 | "text_prompt": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n", 145 | "text_chosen": "Bonjour!\n", 146 | "text_rejected": "Hola!\n", 147 | }, 148 | ) 149 | -------------------------------------------------------------------------------- /validator_data/generate_fixed_sql_using_fewshot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sqlite3 3 | import os 4 | import multiprocessing.pool 5 | import functools 6 | from tqdm import tqdm 7 | import pandas as pd 8 | from utils import get_columns_in_select_clause 9 | 10 | def timeout(max_timeout): 11 | """Timeout decorator, parameter in seconds.""" 12 | def timeout_decorator(item): 13 | """Wrap the original function.""" 14 | @functools.wraps(item) 15 | def func_wrapper(*args, **kwargs): 16 | """Closure for function.""" 17 | pool = multiprocessing.pool.ThreadPool(processes=1) 18 | async_result = pool.apply_async(item, args, kwargs) 19 | # raises a TimeoutError if execution exceeds max_timeout 20 | return async_result.get(max_timeout) 21 | return func_wrapper 22 | return timeout_decorator 23 | 24 | @timeout(30) 25 | def _execute_sql_with_timeout(db_path, action): 26 | conn = sqlite3.connect(db_path) 27 | conn.text_factory = lambda b: b.decode(errors="ignore") 28 | actions = action.split(";") 29 | actions = [x for x in actions if len(x.strip()) > 0] 30 | if len(actions) == 0: 31 | return "no SQL query executed.", True 32 | cursor = conn.cursor() 33 | for action in actions: 34 | # action = action.lower() 35 | try: 36 | cursor.execute(action) 37 | response = cursor.fetchall() 38 | has_error = False 39 | except Exception as error: 40 | # If the SQL query is invalid, return error message from sqlite 41 | response = str(error) 42 | has_error = True 43 | cursor.close() 44 | break 45 | cursor.close() 46 | conn.close() 47 | return response, has_error 48 | 49 | def _execute_sql(db_path, sql_query): 50 | try: 51 | pred_result, has_error = _execute_sql_with_timeout(db_path, sql_query) 52 | except: 53 | pred_result = "The query takes too much time." 
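# NOTE: the bare `except` above also catches the TimeoutError raised by the
# @timeout decorator once _execute_sql_with_timeout runs past its 30-second limit.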
54 | has_error = True 55 | return pred_result, has_error 56 | 57 | def _make_str_response(response, has_error): 58 | if has_error: 59 | return str(response) 60 | else: 61 | df = pd.DataFrame(response) 62 | return str(df) 63 | 64 | # PROMPT = open('./few_shot_prompt_fix.txt').read() + """========= 65 | # {schema} 66 | 67 | # Matched contents are written in this format table.column (some values can be found in that column) 68 | # {matched_content} 69 | 70 | # Question: {question} 71 | 72 | # SQL query: {sql_query} 73 | 74 | # Execution response [written in pandas format]: 75 | # {execution_response} 76 | 77 | # Feedback:{feedback} 78 | 79 | # FIXED SQL:""" 80 | 81 | PROMPT = open('./few_shot_prompt_fix.txt').read().strip() + """ 82 | ========= 83 | {schema} 84 | 85 | Matched contents are written in this format table.column (some values can be found in that column) 86 | {matched_content} 87 | 88 | Question: {question} 89 | 90 | SQL query: {sql_query} 91 | 92 | Feedback:{feedback} 93 | 94 | FIXED SQL:""" 95 | 96 | 97 | from openai import OpenAI 98 | 99 | client = OpenAI( 100 | api_key='no-key', 101 | base_url='http://localhost:8000/v1' 102 | ) 103 | 104 | # def get_answer(messages): 105 | # response = client.chat.completions.create( 106 | # model='codeS', 107 | # messages=messages, 108 | # max_tokens=2048, 109 | # temperature=0.0, 110 | # # eos_token_id=self.tokenizer.convert_tokens_to_ids(['<|end|>']) 111 | # ) 112 | # response = response.choices[0].message.content.strip() 113 | # return response 114 | 115 | # def get_answer(messages): 116 | # response = client.completions.create( 117 | # model='meta-llama/Meta-Llama-3.1-8B-Instruct/', 118 | # prompt=messages[0]['content'], 119 | # max_tokens=256, 120 | # temperature=0.0, 121 | # stop=['========='] 122 | # # eos_token_id=self.tokenizer.convert_tokens_to_ids(['<|end|>']) 123 | # ) 124 | # response = response.choices[0].text 125 | # return response 126 | 127 | def get_answer(messages): 128 | import requests 129 | response = requests.post("http://localhost:8000/v1/completions", 130 | json={ 131 | "model": "meta-llama/Meta-Llama-3.1-8B-Instruct/", 132 | "prompt": messages[0]['content'], 133 | "max_tokens": 256, 134 | "use_beam_search": True, 135 | "n": 4, 136 | "temperature": 0, 137 | "stop": ["========="] 138 | }).json() 139 | return response["choices"][0]["text"] 140 | 141 | data = json.load(open('./bird_validator_select.json')) 142 | output_file = './bird_fixed_sql.json' 143 | 144 | # data = json.load(open('../temp/codes/temp/codes/eval_codes-1b.json')) 145 | # output_file = 'bird_dev_validator_select.json' 146 | 147 | for isample in tqdm(range(0, len(data)), total=len(data)): 148 | sample = data[isample] 149 | 150 | sql = sample['predict_sql'] 151 | is_correct = sample['is_correct'] 152 | if sample['validator_select'] is None or "Conclude: correct" in sample['validator_select']: 153 | continue 154 | 155 | prompt = PROMPT.format( 156 | schema=sample['schema_sequence'], 157 | matched_content=sample['content_sequence'], 158 | question=sample['text'], 159 | sql_query=sql, 160 | # execution_response=sample['pred_result'], 161 | feedback=sample['validator_select'] 162 | ) 163 | # print(prompt) 164 | answer = get_answer([{"role": "user", "content": prompt}]) 165 | 166 | execution_result = _execute_sql("../" + sample['db_path'], answer) 167 | 168 | print("-"*20) 169 | print(answer) 170 | # break 171 | sample['fixed_sql'] = answer 172 | sample['fixed_pred_result'] = _make_str_response(*execution_result) 173 | 174 | json.dump(data[:isample+1], 
open(output_file, 'w+'), ensure_ascii=False, indent=4) 175 | json.dump(data[:isample+1], open(output_file, 'w+'), ensure_ascii=False, indent=4) 176 | 177 | bird_results_dict = dict() 178 | for idx, sample in enumerate(data): 179 | if 'fixed_sql' in sample: 180 | predicted_sql = sample['fixed_sql'] 181 | else: 182 | predicted_sql = sample['predict_sql'] 183 | bird_results_dict[idx] = predicted_sql + "\t----- bird -----\t" + sample["db_id"] 184 | with open("predict_dev.json", "w", encoding = 'utf-8') as f: 185 | f.write(json.dumps(bird_results_dict, indent = 2, ensure_ascii = False)) 186 | --------------------------------------------------------------------------------