├── .env
├── alignment-handbook
│   ├── tests
│   │   ├── __init__.py
│   │   ├── fixtures
│   │   │   ├── config_dpo_full.yaml
│   │   │   └── config_sft_full.yaml
│   │   ├── test_configs.py
│   │   ├── test_model_utils.py
│   │   └── test_data.py
│   ├── README.md
│   ├── chapters
│   │   └── en
│   │       ├── chapter0
│   │       │   └── introduction.mdx
│   │       └── _toctree.yml
│   ├── assets
│   │   └── handbook.png
│   ├── scripts
│   │   ├── setup.sh
│   │   ├── test_gamma_bird.sh
│   │   ├── train_spider.sh
│   │   └── train_bird.sh
│   ├── src
│   │   └── alignment
│   │       ├── __init__.py
│   │       ├── release.py
│   │       └── model_utils.py
│   ├── recipes
│   │   ├── accelerate_configs
│   │   │   ├── multi_gpu.yaml
│   │   │   ├── zero2.yaml
│   │   │   ├── deepspeed_dpo.yaml
│   │   │   ├── deepspeed_zero3.yaml
│   │   │   ├── no_offload_optimizer.yaml
│   │   │   └── deepspeed_original.yaml
│   │   ├── llama-1b-bird
│   │   │   ├── fixed-fft.yaml
│   │   │   ├── selection-fft.yaml
│   │   │   ├── validator-fixer-fft.yaml
│   │   │   ├── orpo-fixed.yaml
│   │   │   ├── orpo-validator-fixed.yaml
│   │   │   ├── llama-3-bird-planner-fft.yaml
│   │   │   ├── validator-fft.yaml
│   │   │   ├── llama-3-bird-validator-fft.yaml
│   │   │   └── orpo-validator.yaml
│   │   ├── llama-3b-bird
│   │   │   ├── selection-fft.yaml
│   │   │   ├── validator-fixer-fft.yaml
│   │   │   ├── orpo-selection.yaml
│   │   │   ├── orpo-fixed.yaml
│   │   │   ├── orpo-planner-iter-2.yaml
│   │   │   ├── orpo-planner-iter-3.yaml
│   │   │   ├── orpo-fixed-iter-2.yaml
│   │   │   ├── orpo-planner.yaml
│   │   │   ├── planner-fft.yaml
│   │   │   ├── fixed-fft.yaml
│   │   │   ├── orpo-validator-iter-2.yaml
│   │   │   ├── llama-3-bird-validator-fft.yaml
│   │   │   ├── orpo-llama-3-validator.yaml
│   │   │   └── orpo-validator.yaml
│   │   ├── llama-1b-spider
│   │   │   ├── planner-fft.yaml
│   │   │   ├── orpo-fixed.yaml
│   │   │   ├── fixed-fft.yaml
│   │   │   ├── validator-fft.yaml
│   │   │   └── orpo-validator.yaml
│   │   └── llama-3b-spider
│   │       ├── planner-fft.yaml
│   │       ├── orpo-fixed.yaml
│   │       ├── llama-3-fixed-fft.yaml
│   │       ├── orpo-planner-iter-2.yaml
│   │       ├── orpo-planner.yaml
│   │       ├── orpo-planner-iter-3.yaml
│   │       ├── orpo-validator.yaml
│   │       └── fft-validator.yaml
│   ├── .github
│   │   └── workflows
│   │       ├── upload_pr_documentation.yml
│   │       ├── build_documentation.yml
│   │       ├── build_pr_documentation.yml
│   │       ├── quality.yml
│   │       └── tests.yml
│   ├── setup.cfg
│   ├── Makefile
│   ├── .gitignore
│   └── setup.py
├── validator_data
│   ├── .env
│   ├── generate_validator_order_using_fewshot.py
│   ├── generate_validator_select_using_fewshot.py
│   ├── generate_fixed_sql_using_fewshot_condition.py
│   ├── generate_validator_join_using_fewshot.py
│   ├── generate_fixed_sql_using_fewshot_join.py
│   ├── generate_validator_condition_using_fewshot.py
│   ├── utils.py
│   └── generate_fixed_sql_using_fewshot.py
├── requirements.txt
├── .gitignore
├── json2jsonl.py
├── jsonl2json.py
├── visualization
│   ├── temperature_sensitivity.py
│   ├── planner_prompt_comparison.py
│   ├── sql_characteristic_spider.py
│   ├── parameter_sensitivity_num_candidates.py
│   ├── sql_characteristic_bird.py
│   ├── domain_knowledge.py
│   ├── rlef_improvement.py
│   ├── .ipynb_checkpoints
│   │   └── rlef_improvement-checkpoint.py
│   └── lambda_sensitivity.py
├── llm_alignment
│   ├── merge_rl_data.py
│   └── build_rlef_selection_data.py
├── scripts
│   ├── evaluate_bird.sh
│   └── evaluate_dr_spider.sh
├── utils
│   ├── classifier_loss.py
│   ├── load_classifier_dataset.py
│   ├── load_pt_dataset.py
│   ├── lr_scheduler.py
│   └── load_sft_dataset.py
├── data_processing
│   ├── generate_validator_data.py
│   ├── prompts
│   │   └── zero_shot_prompt_planner.txt
│   ├── generate_planner_data.py
│   ├── merge_val_fix_data.ipynb
│   ├── generate_validator_fixer_data.py
│   ├── generate_sft_data_for_validator.py
│   ├── generate_sft_data_for_fix.py
│   └── generate_sft_data_for_planner.py
├── bird_evaluation
│   └── run_evaluation.sh
└── README.md
/.env:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=
2 |
--------------------------------------------------------------------------------
/alignment-handbook/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/validator_data/.env:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=
2 |
--------------------------------------------------------------------------------
/alignment-handbook/README.md:
--------------------------------------------------------------------------------
1 | To train SFT on TinyLlama, run:
2 | ```
3 | bash scripts/train_tinyllama.sh
4 | ```
5 |
6 |
--------------------------------------------------------------------------------
/alignment-handbook/chapters/en/chapter0/introduction.mdx:
--------------------------------------------------------------------------------
1 | # Welcome to the RLHF Handbook!
2 |
3 | Stay tuned for more details 🤗
--------------------------------------------------------------------------------
/alignment-handbook/assets/handbook.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thanhdath/mats-sql/HEAD/alignment-handbook/assets/handbook.png
--------------------------------------------------------------------------------
/alignment-handbook/chapters/en/_toctree.yml:
--------------------------------------------------------------------------------
1 | - title: Unit 0. Welcome to the RLHF Handbook!
2 | sections:
3 | - local: chapter0/introduction
4 | title: What is this about?
--------------------------------------------------------------------------------
/alignment-handbook/scripts/setup.sh:
--------------------------------------------------------------------------------
1 | apt-get install git-lfs
2 | python -m pip install .
3 | #export PATH=/usr/local/cuda/bin:$PATH
4 | #export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
5 | python -m pip install flash-attn --no-build-isolation
6 | pip install git+https://github.com/huggingface/trl.git
7 |
--------------------------------------------------------------------------------
/alignment-handbook/src/alignment/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.0.dev0"
2 |
3 | from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
4 | from .data import apply_chat_template, get_datasets
5 | from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model
6 |
--------------------------------------------------------------------------------
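The `dataset_mixer` blocks that appear throughout recipes/*/*.yaml feed the `get_datasets` export above. A minimal sketch of the mapping (assuming the upstream alignment-handbook signature for `get_datasets`, with `PYTHONPATH=src` as set by the Makefile and the training scripts; the dataset name below is just the one used in the test fixtures):

# sketch, illustrative only -- not a file in this repo
from alignment import get_datasets

# Keys are dataset names or local paths; values are sampling fractions
# (1.0 = use the whole split), exactly like `dataset_mixer:` in the recipes.
raw_datasets = get_datasets(
    {"HuggingFaceH4/ultrachat_200k": 1.0},
    splits=["train_sft", "test_sft"],
)
print(raw_datasets)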
/requirements.txt:
--------------------------------------------------------------------------------
1 | func_timeout==4.3.5
2 | nltk==3.7
3 | numpy==1.23.5
4 | pandas==2.0.1
5 | rapidfuzz==2.0.11
6 | tqdm==4.63.0
7 | transformers==4.28.1
8 | sqlparse==0.4.2
9 | accelerate==0.18.0
10 | bitsandbytes==0.41.1
11 | pyserini==0.21.0
12 | sql_metadata==2.8.0
13 | datasets==2.11.0
14 | faiss-cpu==1.7.4
15 | deepspeed==0.9.5
16 | tensorboard
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | pred_sqls.txt
3 | predict_dev.json
4 | output/
5 | test_suite_sql_eval
6 | seeklhy/
7 | sic_ckpts/
8 | data/
9 | log-tensorboard/
10 | *.json
11 | temp/
12 | *.log
13 | *.csv
14 | db_content_retrieval/volumes
15 | offload/
16 | *.xlsx
17 | alignment-handbook/thanhdath/
18 | progress.jsonl
19 | logs/
20 | temp.sh
21 | *.pkl
22 | *.jsonl
23 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/multi_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 2
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | offload_optimizer_device: cpu
4 | offload_param_device: none
5 | zero3_init_flag: true
6 | zero_stage: 2
7 | distributed_type: DEEPSPEED
8 | fsdp_config: {}
9 | machine_rank: 0
10 | main_process_ip: null
11 | main_process_port: null
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 2
16 | use_cpu: false
--------------------------------------------------------------------------------
/json2jsonl.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--json-file', type=str, required=True)
6 | args = parser.parse_args()
7 |
8 | file = args.json_file
9 |
10 | # Read the JSON file
11 | with open(file, 'r') as f:
12 | data = json.load(f)
13 |
14 | # Export to JSONL
15 | with open(file.replace('.json', '.jsonl'), 'w') as f:
16 | for record in data:
17 | f.write(json.dumps(record) + '\n')
18 |
--------------------------------------------------------------------------------
/jsonl2json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--jsonl-file', type=str, required=True)
6 | args = parser.parse_args()
7 |
8 | file = args.jsonl_file
9 | data = []
10 | with open(file, 'r') as f:
11 | for line in f:
12 | data.append(json.loads(line))
13 | # Export to JSON, replacing the '.jsonl' suffix with '.json'
14 | with open(file.replace('.jsonl', '.json'), 'w') as f:
15 | json.dump(data, f, indent=4)
16 |
--------------------------------------------------------------------------------
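Taken together, these two helpers round-trip the repo's data files between formats: with a hypothetical input dev.json, `python json2jsonl.py --json-file dev.json` writes dev.jsonl next to it, and `python jsonl2json.py --jsonl-file dev.jsonl` regenerates dev.json.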
/alignment-handbook/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: alignment-handbook
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/deepspeed_dpo.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: ./recipes/accelerate_configs/dpo.json
5 | zero3_init_flag: true
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | machine_rank: 0
9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 2
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: ./recipes/accelerate_configs/ds_7b.json
5 | zero3_init_flag: true
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | machine_rank: 0
9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 2
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/no_offload_optimizer.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: ./recipes/accelerate_configs/ds_7b.json
5 | zero3_init_flag: true
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | machine_rank: 0
9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 2
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/alignment-handbook/scripts/test_gamma_bird.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=src/
2 | export CUDA_VISIBLE_DEVICES=0,1
3 |
4 | for beta in 0.0 0.25 0.5 1.0 0.75
5 | do
6 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-3b-bird/orpo-planner.yaml --model_name_or_path=./output/reproduce/llama-3b-bird-planner-fft-no-filter --output_dir=output/param-sensitivity/orpo-llama-3b-bird-planner-beta-$beta --save_strategy=no
7 | done
8 |
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
11 | with:
12 | commit_sha: ${{ github.sha }}
13 | package: alignment-handbook
14 | path_to_docs: alignment-handbook/chapters/
15 | additional_args: --not_python_module
16 | languages: en
17 | secrets:
18 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
19 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/accelerate_configs/deepspeed_original.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 2
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
--------------------------------------------------------------------------------
/alignment-handbook/scripts/train_spider.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | export PYTHONPATH=src/
3 | #accelerate launch --main_process_port 29502 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_sft.py recipes/llama-3b-spider/planner-fft.yaml
4 | accelerate launch --main_process_port 29502 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 1 scripts/run_orpo.py recipes/llama-3b-spider/orpo-planner-iter-3.yaml
5 | # accelerate launch --main_process_port 29502 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_sft.py recipes/llama-3b-spider/sql-gt-fft.yaml
6 |
7 |
8 |
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 |
6 | concurrency:
7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8 | cancel-in-progress: true
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.event.pull_request.head.sha }}
15 | pr_number: ${{ github.event.number }}
16 | package: alignment-handbook
17 | path_to_docs: alignment-handbook/chapters/
18 | additional_args: --not_python_module
19 | languages: en
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/quality.yml:
--------------------------------------------------------------------------------
1 | name: Quality
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - v*-release
8 | pull_request:
9 | branches:
10 | - main
11 |
12 | jobs:
13 |
14 | check_code_quality:
15 | name: Check code quality
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v2
20 | - name: Setup Python environment
21 | uses: actions/setup-python@v2
22 | with:
23 | python-version: 3.10.10
24 | - name: Install dependencies
25 | run: |
26 | python -m pip install --upgrade pip
27 | python -m pip install ".[quality]"
28 | - name: Code quality
29 | run: |
30 | make quality
31 |
32 |
--------------------------------------------------------------------------------
/alignment-handbook/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - v*-release
8 | pull_request:
9 | branches:
10 | - main
11 |
12 | jobs:
13 |
14 | unit-tests:
15 | name: Run unit tests
16 | env:
17 | HF_TOKEN: ${{ secrets.HF_TOKEN }}
18 | runs-on: ubuntu-latest
19 | steps:
20 | - name: Checkout code
21 | uses: actions/checkout@v2
22 | - name: Setup Python environment
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: 3.10.10
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | python -m pip install ".[dev, torch]"
30 | - name: Run unit tests
31 | run: HF_TOKEN=$HF_TOKEN pytest -sv tests/
--------------------------------------------------------------------------------
/alignment-handbook/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = alignment
7 | known_third_party =
8 | transformers
9 | datasets
10 | fugashi
11 | git
12 | h5py
13 | matplotlib
14 | nltk
15 | numpy
16 | packaging
17 | pandas
18 | psutil
19 | pytest
20 | rouge_score
21 | sacrebleu
22 | seqeval
23 | sklearn
24 | streamlit
25 | torch
26 | tqdm
27 |
28 | line_length = 119
29 | lines_after_imports = 2
30 | multi_line_output = 3
31 | use_parentheses = True
32 |
33 | [flake8]
34 | ignore = E203, E501, E741, W503, W605
35 | max-line-length = 119
36 | per-file-ignores =
37 | # imported but unused
38 | __init__.py: F401
39 |
40 | [tool:pytest]
41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/visualization/temperature_sensitivity.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | # Given data
4 | x_values = [0, 0.3, 0.5, 1.0] # Temperature values
5 | y_values = [59.12, 63.58, 63.98, 64.13] # Mean values (A)
6 | y_errors = [0.09, 0.13, 0.23, 0.25] # Standard deviation (B)
7 |
8 | # Adjust figure size for 1-column fit in a 2-column research paper (~3.5 inches wide)
9 | plt.figure(figsize=(3.5, 2.5))
10 |
11 | # Create error bar plot
12 | plt.errorbar(x_values, y_values, yerr=y_errors, fmt='o-', capsize=3, capthick=1)
13 |
14 | # Labels with optimized font size for readability
15 | plt.xlabel("Temperature", fontsize=9)
16 | plt.ylabel("EX%", fontsize=9)
17 | plt.xticks(x_values, fontsize=8)
18 | plt.yticks(fontsize=8)
19 | plt.grid(True, linestyle="--", alpha=0.7)
20 |
21 | # Tight layout for better fit in a research paper column
22 | plt.tight_layout()
23 |
24 | # Show the plot
25 | plt.show()
26 |
--------------------------------------------------------------------------------
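Note that `plt.show()` only opens an interactive window; to produce the figure file for a paper, one would typically replace it with something like `plt.savefig('temperature_sensitivity.pdf', bbox_inches='tight')` (file name hypothetical).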
/alignment-handbook/tests/fixtures/config_dpo_full.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: alignment-handbook/zephyr-7b-sft-full
3 |
4 | # Data training arguments
5 | # For definitions, see: src/h4/training/config.py
6 | dataset_mixer:
7 | HuggingFaceH4/ultrafeedback_binarized: 1.0
8 | dataset_splits:
9 | - train_prefs
10 | - test_prefs
11 | preprocessing_num_workers: 12
12 |
13 | # DPOTrainer arguments
14 | bf16: true
15 | beta: 0.1
16 | do_eval: true
17 | evaluation_strategy: steps
18 | eval_steps: 100
19 | gradient_accumulation_steps: 1
20 | gradient_checkpointing: true
21 | hub_model_id: zephyr-7b-dpo-full
22 | learning_rate: 5.0e-7
23 | log_level: info
24 | logging_steps: 10
25 | lr_scheduler_type: linear
26 | max_length: 1024
27 | max_prompt_length: 512
28 | num_train_epochs: 3
29 | optim: rmsprop
30 | output_dir: data/zephyr-7b-dpo-full
31 | per_device_train_batch_size: 8
32 | per_device_eval_batch_size: 4
33 | push_to_hub: true
34 | save_strategy: "no"
35 | save_total_limit: null
36 | seed: 42
37 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/alignment-handbook/tests/fixtures/config_sft_full.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: mistralai/Mistral-7B-v0.1
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 |
7 | # Data training arguments
8 | dataset_mixer:
9 | HuggingFaceH4/ultrachat_200k: 1.0
10 | dataset_splits:
11 | - train_sft
12 | - test_sft
13 | preprocessing_num_workers: 12
14 |
15 | # SFT trainer config
16 | bf16: true
17 | do_eval: true
18 | evaluation_strategy: epoch
19 | gradient_accumulation_steps: 2
20 | gradient_checkpointing: true
21 | hub_model_id: zephyr-7b-sft-full
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 2048
29 | max_steps: -1
30 | num_train_epochs: 1
31 | output_dir: data/zephyr-7b-sft-full
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 16
34 | per_device_train_batch_size: 32
35 | push_to_hub: true
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: null
41 | seed: 42
--------------------------------------------------------------------------------
/alignment-handbook/scripts/train_bird.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=src/
2 | export CUDA_VISIBLE_DEVICES=0,1
3 | # ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 1 scripts/run_orpo.py recipes/llama-3b-bird/orpo-planner.yaml --model_name_or_path=./output/llama-3b-bird-planner-fft --output_dir=output/reproduce/orpo-llama-3b-bird-planner
4 |
5 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-bird/orpo-validator.yaml
6 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-bird/orpo-fixed.yaml
7 |
8 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-spider/orpo-validator.yaml
9 | ACCELERATE_LOG_LEVEL=info accelerate launch --main_process_port 29504 --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes 2 scripts/run_orpo.py recipes/llama-1b-spider/orpo-fixed.yaml
--------------------------------------------------------------------------------
/visualization/planner_prompt_comparison.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | # Data from the table
5 | strategies = ["No Thought", "Chain-of-Thought", "Few-shot Thoughts"]
6 | categories = ["simple", "moderate", "challenging", "overall"]
7 |
8 | # Execution accuracy scores
9 | scores = np.array([
10 | [59.46, 37.5, 30.34, 50.07],
11 | [60.11, 40.09, 37.93, 51.96],
12 | [63.03, 43.75, 38.62, 54.89]
13 | ])
14 |
15 | # Bar width and positions
16 | bar_width = 0.25
17 | x = np.arange(len(categories))
18 |
19 | # Creating figure with aspect ratio suitable for a research paper (single-column)
20 | fig, ax = plt.subplots(figsize=(4.5, 2.5))
21 |
22 | # Plot bars for each strategy
23 | for i, strategy in enumerate(strategies):
24 | ax.bar(x + i * bar_width, scores[i], width=bar_width, label=strategy)
25 |
26 | # Labels and formatting
27 | ax.set_xticks(x + bar_width)
28 | ax.set_xticklabels(categories, fontsize=9)
29 | ax.set_ylabel("EX%", fontsize=9)
30 |
31 | # Move legend to the top
32 | ax.legend(fontsize=8, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3, frameon=True)
33 |
34 | # Tight layout for better fit
35 | plt.tight_layout()
36 |
37 | # Show the plot
38 | plt.show()
39 |
--------------------------------------------------------------------------------
/alignment-handbook/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: style quality
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := src tests scripts
7 |
8 | style:
9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py
10 | isort $(check_dirs) setup.py
11 |
12 | quality:
13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14 | isort --check-only $(check_dirs) setup.py
15 | flake8 --max-line-length 119 $(check_dirs) setup.py
16 |
17 |
18 | # Release stuff
19 |
20 | pre-release:
21 | python src/alignment/release.py
22 |
23 | pre-patch:
24 | python src/alignment/release.py --patch
25 |
26 | post-release:
27 | python src/alignment/release.py --post_release
28 |
29 | post-patch:
30 | python src/alignment/release.py --post_release --patch
31 |
32 | wheels:
33 | python setup.py bdist_wheel && python setup.py sdist
34 |
35 | wheels_clean:
36 | rm -rf build && rm -rf dist
37 |
38 | pypi_upload:
39 | python -m pip install twine
40 | twine upload dist/* -r pypi
41 |
42 | pypi_test_upload:
43 | python -m pip install twine
44 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
45 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/fixed-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
6 |
7 | # Data training arguments
8 | dataset_mixer:
9 | /home/datht/codes/data/multi-agents/fixed/sft-fixed-bird_with_evidence: 1
10 | dataset_splits:
11 | - train
12 | - test
13 | preprocessing_num_workers: 24
14 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
15 |
16 | # SFT trainer config
17 | bf16: true
18 | do_eval: false
19 | evaluation_strategy: "no"
20 | gradient_accumulation_steps: 64
21 | gradient_checkpointing: true
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/llama-1b-bird-fixed-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: 1
41 | seed: 42
42 | tf32: true
43 |
--------------------------------------------------------------------------------
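The `chat_template` field above (reused across the recipe YAMLs) is a Jinja2 expression over a single {prompt, completion} record. A minimal rendering sketch, illustrative only, assuming the jinja2 package and a made-up record:

# sketch, illustrative only -- not a file in this repo
from jinja2 import Template

# Same template string as the `chat_template` field in the recipe above.
CHAT_TEMPLATE = (
    "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}"
    "{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
)

record = {
    "prompt": "How many singers do we have?",        # hypothetical example
    "completion": "SELECT COUNT(*) FROM singer;",
}
print(Template(CHAT_TEMPLATE).render(messages=record))

During SFT, the `response_template` string marks where the assistant turn begins, so that (as in TRL's completion-only collator) only the completion tokens contribute to the loss.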
/alignment-handbook/recipes/llama-1b-bird/selection-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | ../data/multi-agents/selection/sft_bird: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: false
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 64
22 | gradient_checkpointing: true
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 1
31 | output_dir: output/llama-1b-bird-selection-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "steps"
40 | save_steps: 200
41 | save_total_limit: 10
42 | seed: 42
43 | tf32: true
44 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/selection-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | ../data/multi-agents/selection/sft_ranking_bird: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: false
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 2
31 | output_dir: output/llama-3b-bird-selection-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 2
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "steps"
40 | save_steps: 100
41 | save_total_limit: 1
42 | seed: 42
43 | tf32: true
44 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/validator-fixer-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
6 |
7 | # Data training arguments
8 | dataset_mixer:
9 | ../data/multi-agents/fixed/sft-validator-fixer-bird_with_evidence/: 1
10 | dataset_splits:
11 | - train
12 | - test
13 | preprocessing_num_workers: 24
14 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
15 |
16 | # SFT trainer config
17 | bf16: true
18 | do_eval: false
19 | evaluation_strategy: "no"
20 | gradient_accumulation_steps: 64
21 | gradient_checkpointing: true
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/llama-1b-bird-validator-fixer-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: 1
41 | seed: 42
42 | tf32: true
43 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/validator-fixer-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
6 |
7 | # Data training arguments
8 | dataset_mixer:
9 | ../data/multi-agents/fixed/sft-validator-fixer-bird_with_evidence/: 1
10 | dataset_splits:
11 | - train
12 | - test
13 | preprocessing_num_workers: 24
14 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
15 |
16 | # SFT trainer config
17 | bf16: true
18 | do_eval: false
19 | evaluation_strategy: "no"
20 | gradient_accumulation_steps: 64
21 | gradient_checkpointing: true
22 | hub_strategy: every_save
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/llama-3b-bird-validator-fixer-fft
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 1
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "no"
40 | save_total_limit: 1
41 | seed: 42
42 | tf32: true
43 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/planner-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|reserved_special_token_247|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | ../data/multi-agents/planner/sft-gpt-4o-mini-planner_spider_train: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{messages['prompt'] + '<|reserved_special_token_247|>\n'}}{{messages['completion'] + '<|end_of_text|>'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 8
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-1b-spider-planner-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 2048
31 | max_steps: -1
32 | num_train_epochs: 1
33 | output_dir: output/llama-1b-spider-planner-fft-1epoch
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 8
36 | per_device_train_batch_size: 8
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "epoch"
42 | save_steps: 20
43 | save_total_limit: 1
44 | seed: 42
45 | tf32: true
46 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-selection.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-3b-bird-selection-cot-fft/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p1-selection: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 24
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-bird-selection
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 200
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/planner-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | ../data/multi-agents/planner/sft-gpt-4o-mini-planner_spider_train_no_filter/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | learning_rate: 5.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 2048
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/reproduce/llama-3b-spider-planner-fft-no-filter
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 4
34 | per_device_train_batch_size: 2
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "steps"
40 | save_steps: 20
41 | save_total_limit: 4
42 | seed: 42
43 | tf32: true
44 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/orpo-fixed.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-1b-bird-fixed-fft/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p1-fix/dpo-llama-3-end2end-bird_train_fixed_sql: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | # chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 3000
35 | num_train_epochs: 1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-1b-fixed-bird/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.01
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/orpo-fixed.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-1b-spider-fixed-fft/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/spider-p1-fix/dpo-llama-3-end2end-spider_train_dev_fixed_sql: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 5.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 2500
34 | max_prompt_length: 2000
35 | num_train_epochs: -1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/orpo-llama-1b-fixed-spider
39 | per_device_train_batch_size: 2
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.1
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-fixed.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-3b-bird-fixed-fft-follow-validation/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p1-fix/dpo-llama-3-end2end-bird_train_fixed_sql: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 3000
35 | num_train_epochs: 1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3-fixed-bird/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.01
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/fixed-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/fixed/sft-fixed-spider/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 64
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-1b-spider-fixed-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 4096
31 | max_steps: -1
32 | num_train_epochs: 2
33 | output_dir: output/llama-1b-spider-fixed-fft
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 2
36 | per_device_train_batch_size: 2
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
45 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-planner-iter-2.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/reproduce/orpo-llama-3b-bird-planner
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p2-planner/dpo-llama-3-end2end-bird_train_planner: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "no"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-iter-2-bird-planner
39 | per_device_train_batch_size: 2
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-fixed.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-3b-spider-fixed-fft-follow-validation/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/spider-p1-fix/dpo-llama-3-end2end-spider_train_dev_fixed_sql: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 0.25
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 5.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 2500
34 | max_prompt_length: 2000
35 | num_train_epochs: -1
36 | max_steps: 800
37 | optim: adamw_torch
38 | output_dir: output/orpo-llama-3-fixed-spider
39 | per_device_train_batch_size: 2
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.1
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-planner-iter-3.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/reproduce/orpo-llama-3b-iter-2-bird-planner
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p3-planner/dpo-llama-3-end2end-bird_train_planner: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "no"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-iter-3-bird-planner
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "no"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-fixed-iter-2.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: output/reproduce/orpo-llama-3-fixed-bird/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird/bird-p2-fix/dpo-llama-3-end2end-bird_train_dev_fixed_sql: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 0.25
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 3000
35 | num_train_epochs: -1
36 | max_steps: 500
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3-iter-2-fixed-bird/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.1
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-planner.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/reproduce/llama-3b-bird-planner-fft-no-filter
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p1-planner/dpo-llama-3-end2end-bird_train_planner: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 24
15 |
16 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 1.0
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 8
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: cosine
33 | max_length: 2300
34 | max_prompt_length: 1700
35 | num_train_epochs: 1
36 | max_steps: -1
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-3b-bird-planner-no-filter
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 1
45 | seed: 42
46 | warmup_ratio: 0.05
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/planner-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/planner/sft-gpt-4o-mini-planner_combine_with_true_sql_bird_with_evidence_train_no_filter/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | learning_rate: 2.0e-05
24 | log_level: info
25 | logging_steps: 5
26 | logging_strategy: steps
27 | lr_scheduler_type: cosine
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 4
31 | output_dir: output/reproduce/llama-3b-bird-planner-fft-no-filter
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 2
34 | per_device_train_batch_size: 2
35 | push_to_hub: false
36 | remove_unused_columns: true
37 | report_to:
38 | - tensorboard
39 | save_strategy: "epoch"
40 | save_steps: 40
41 | save_total_limit: 1
42 | seed: 42
43 | tf32: true
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/fixed-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/fixed/sft-fixed-bird_with_evidence: 1
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 64
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-3b-bird-fixed-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 4096
31 | max_steps: -1
32 | num_train_epochs: 2
33 | output_dir: output/llama-3b-bird-fixed-fft-follow-validation
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 2
36 | per_device_train_batch_size: 1
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
45 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/llama-3-fixed-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/fixed/sft-fixed-spider/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-3b-spider-fixed-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 4096
31 | max_steps: -1
32 | num_train_epochs: 2
33 | output_dir: output/llama-3b-spider-fixed-fft-follow-validation
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 2
36 | per_device_train_batch_size: 2
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
45 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/orpo-validator-fixed.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-1b-bird-validator-fixer-fft/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird-p1-validator-fixer/dpo-llama-3-end2end-bird_train_fixed_sql: 1.0
10 |
11 | dataset_splits:
12 | - train_dpo
13 | - test_dpo
14 | preprocessing_num_workers: 12
15 |
16 | # chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
17 | report_to: ["tensorboard"]
18 |
19 | # DPOTrainer arguments
20 | bf16: true
21 | beta: 100
22 | do_eval: true
23 | eval_strategy: "steps"
24 | eval_steps: 100
25 | gradient_accumulation_steps: 16
26 | gradient_checkpointing: true
27 | gradient_checkpointing_kwargs:
28 | use_reentrant: False
29 | learning_rate: 8.0e-6
30 | log_level: info
31 | logging_steps: 10
32 | lr_scheduler_type: inverse_sqrt
33 | max_length: 3300
34 | max_prompt_length: 2000
35 | num_train_epochs: 2
36 | max_steps: 100000
37 | optim: adamw_torch
38 | output_dir: output/reproduce/orpo-llama-1b-validator-fixer-bird-beta0.5/
39 | per_device_train_batch_size: 1
40 | per_device_eval_batch_size: 1
41 | push_to_hub: false
42 | save_strategy: "steps"
43 | save_steps: 100
44 | save_total_limit: 2
45 | seed: 42
46 | warmup_ratio: 0.01
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-planner-iter-2.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: output/reproduce/orpo-3b-spider-planner
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
6 |
7 | # Data training arguments
8 | # For definitions, see: src/h4/training/config.py
9 | dataset_mixer:
10 | ../data/llm_alignment/spider-p2-planner/dpo-llama-3-end2end-spider_train_planner: 1.0
11 |
12 | dataset_splits:
13 | - train_dpo
14 | - test_dpo
15 | preprocessing_num_workers: 24
16 |
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 |
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 1.0
23 | do_eval: true
24 | eval_strategy: "no"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 4
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 | use_reentrant: False
30 | learning_rate: 8.0e-6
31 | log_level: info
32 | logging_steps: 10
33 | lr_scheduler_type: cosine
34 | max_length: 1600
35 | max_prompt_length: 1200
36 | num_train_epochs: 1
37 | max_steps: -1
38 | optim: adamw_torch
39 | output_dir: output/reproduce/orpo-3b-spider-planner-iter-2
40 | per_device_train_batch_size: 2
41 | per_device_eval_batch_size: 2
42 | push_to_hub: false
43 | save_strategy: "epoch"
44 | save_steps: 100
45 | save_total_limit: 2
46 | seed: 42
47 | warmup_ratio: 0.05
48 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/llama-3-bird-planner-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/planner/sft-gpt-4o-mini-planner_combine_with_true_sql_bird_062024_with_evidence_train/: 1.0
11 | dataset_splits:
12 | - train
13 | - test
14 | preprocessing_num_workers: 24
15 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
16 |
17 | # SFT trainer config
18 | bf16: true
19 | do_eval: true
20 | evaluation_strategy: "no"
21 | gradient_accumulation_steps: 32
22 | gradient_checkpointing: true
23 | hub_model_id: griffith-bigdata/llama-1b-bird-planner-fft
24 | hub_strategy: every_save
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 8192
31 | max_steps: -1
32 | num_train_epochs: 4
33 | output_dir: output/llama-1b-bird-planner-fft
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 4
36 | per_device_train_batch_size: 2
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "no"
42 | save_total_limit: 1
43 | seed: 42
44 | tf32: true
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-planner.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: output/reproduce/llama-3b-spider-planner-fft-no-filter
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
6 |
7 | # Data training arguments
8 | # For definitions, see: src/h4/training/config.py
9 | dataset_mixer:
10 | ../data/llm_alignment/spider-p1-planner/dpo-llama-3-end2end-spider_train_dev_planner: 1.0
11 |
12 | dataset_splits:
13 | - train_dpo
14 | - test_dpo
15 | preprocessing_num_workers: 24
16 |
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 |
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 1.0
23 | do_eval: true
24 | eval_strategy: "no"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 4
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 | use_reentrant: False
30 | learning_rate: 8.0e-6
31 | log_level: info
32 | logging_steps: 10
33 | lr_scheduler_type: cosine
34 | max_length: 1600
35 | max_prompt_length: 1200
36 | num_train_epochs: 1
37 | max_steps: -1
38 | optim: adamw_torch
39 | output_dir: output/reproduce/orpo-3b-spider-planner
40 | per_device_train_batch_size: 2
41 | per_device_eval_batch_size: 2
42 | push_to_hub: false
43 | save_strategy: "epoch"
44 | save_steps: 100
45 | save_total_limit: 2
46 | seed: 42
47 | warmup_ratio: 0.05
48 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-planner-iter-3.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: output/reproduce/orpo-3b-spider-planner-iter-2
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
6 |
7 | # Data training arguments
8 | # For definitions, see: src/h4/training/config.py
9 | dataset_mixer:
10 | ../data/llm_alignment/spider-p3-planner/dpo-llama-3-end2end-spider_train_planner: 1.0
11 |
12 | dataset_splits:
13 | - train_dpo
14 | - test_dpo
15 | preprocessing_num_workers: 24
16 |
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 |
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 1.0
23 | do_eval: true
24 | eval_strategy: "no"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 8
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 | use_reentrant: False
30 | learning_rate: 8.0e-6
31 | log_level: info
32 | logging_steps: 10
33 | lr_scheduler_type: cosine
34 | max_length: 1600
35 | max_prompt_length: 1200
36 | num_train_epochs: 1
37 | max_steps: -1
38 | optim: adamw_torch
39 | output_dir: output/reproduce/orpo-3b-spider-planner-iter-3
40 | per_device_train_batch_size: 2
41 | per_device_eval_batch_size: 2
42 | push_to_hub: false
43 | save_strategy: "epoch"
44 | save_steps: 100
45 | save_total_limit: 2
46 | seed: 42
47 | warmup_ratio: 0.05
48 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/validator-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | ../data/multi-agents/validator/sft-validator_select_spider/: 1.0
11 | ../data/multi-agents/validator/sft-validator_join_spider/: 1.0
12 | ../data/multi-agents/validator/sft-validator_condition_spider/: 1.0
13 | dataset_splits:
14 | - train
15 | - test
16 | preprocessing_num_workers: 24
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 |
19 | # SFT trainer config
20 | bf16: true
21 | do_eval: true
22 | evaluation_strategy: "no"
23 | gradient_accumulation_steps: 64
24 | gradient_checkpointing: true
25 | learning_rate: 2.0e-05
26 | log_level: info
27 | logging_steps: 5
28 | logging_strategy: steps
29 | lr_scheduler_type: cosine
30 | max_seq_length: 2048
31 | max_steps: -1
32 | num_train_epochs: 4
33 | output_dir: output/llama-1b-spider-validator-fft
34 | overwrite_output_dir: true
35 | per_device_eval_batch_size: 4
36 | per_device_train_batch_size: 2
37 | push_to_hub: false
38 | remove_unused_columns: true
39 | report_to:
40 | - tensorboard
41 | save_strategy: "epoch"
42 | save_steps: 100
43 | save_total_limit: 1
44 | seed: 42
45 | tf32: true
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/orpo-validator.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-3b-spider-validator-fft/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_select/: 1.0
10 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_condition/: 1.0
11 |
12 | dataset_splits:
13 | - train_dpo
14 | - test_dpo
15 | preprocessing_num_workers: 12
16 |
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 |
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 0.25
23 | do_eval: true
24 | eval_strategy: "no"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 16
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 | use_reentrant: False
30 | hub_model_id: th-dpo
31 | learning_rate: 8.0e-6
32 | log_level: info
33 | logging_steps: 10
34 | lr_scheduler_type: cosine
35 | max_length: 1600
36 | max_prompt_length: 1200
37 | num_train_epochs: 1
38 | max_steps: -1
39 | optim: adamw_torch
40 | output_dir: output/reproduce/orpo-llama-3-validator-spider
41 | per_device_train_batch_size: 1
42 | per_device_eval_batch_size: 1
43 | push_to_hub: false
44 | save_strategy: "steps"
45 | save_steps: 100
46 | save_total_limit: 1
47 | seed: 42
48 | warmup_ratio: 0.05
49 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-spider/orpo-validator.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-1b-spider-validator-fft/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_select/: 1.0
10 | ../data/llm_alignment/spider-p1-validator/dpo-llama-3-end2end-spider_train_validator_condition/: 1.0
11 |
12 | dataset_splits:
13 | - train_dpo
14 | - test_dpo
15 | preprocessing_num_workers: 12
16 |
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 |
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 1.0
23 | do_eval: true
24 | eval_strategy: "no"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 8
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 | use_reentrant: False
30 | hub_model_id: th-dpo
31 | learning_rate: 8.0e-6
32 | log_level: info
33 | logging_steps: 10
34 | lr_scheduler_type: cosine
35 | max_length: 1600
36 | max_prompt_length: 1200
37 | # num_train_epochs: 1
38 | max_steps: 500
39 | optim: adamw_torch
40 | output_dir: output/reproduce/orpo-llama-1b-validator-spider
41 | per_device_train_batch_size: 1
42 | per_device_eval_batch_size: 1
43 | push_to_hub: false
44 | save_strategy: "steps"
45 | save_steps: 100
46 | save_total_limit: 1
47 | seed: 42
48 | warmup_ratio: 0.05
49 |
--------------------------------------------------------------------------------
/llm_alignment/merge_rl_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import numpy as np
6 | from datasets import Dataset, load_from_disk
7 |
8 | # Parse the data directory argument, e.g. data/llm_alignment/spider-p1
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--data_dir", type=str, help="path to the data directory")
11 | args = parser.parse_args()
12 |
13 | # Read all train datasets from the data directory, e.g.:
14 | # dpo-llama-3-end2end-spider_train_fixed_sql
15 | # dpo-llama-3-end2end-spider_train_planner
16 | # dpo-llama-3-end2end-spider_train_validator_condition
17 | # dpo-llama-3-end2end-spider_train_validator_join
18 | # dpo-llama-3-end2end-spider_train_validator_select
19 | # dpo-llama-3-end2end-spider_train_validator_order
20 | data_dirs = glob.glob(args.data_dir + "/*train*")
21 | data_dirs = [x for x in data_dirs if os.path.isdir(x)]
22 | print(data_dirs)
23 |
24 | for data_dir in data_dirs:
25 |     dataset_train = load_from_disk(data_dir)
26 |
27 |     # If a matching dev set exists, subsample 2,000 of its examples
28 |     # to serve as the test_dpo split of the merged dataset.
29 |     dev_dir = data_dir.replace("train", "dev")
30 |     if os.path.exists(dev_dir):
31 |         dataset_dev = load_from_disk(dev_dir)
32 |         dataset_dev = list(dataset_dev['train_dpo'])
33 |         dataset_dev = np.random.permutation(dataset_dev)[:2000].tolist()
34 |         dataset_train['test_dpo'] = Dataset.from_list(dataset_dev)
35 |
36 |     print(data_dir)
37 |     print(dataset_train)
38 |
39 |     # Save the merged dataset next to the original, with "train" -> "train_dev"
40 |     dataset_train.save_to_disk(data_dir.replace("train", "train_dev"))
41 |     print(data_dir.replace("train", "train_dev"))
42 |
--------------------------------------------------------------------------------
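A quick sanity check (a sketch; the path below is the one referenced by recipes/llama-3b-spider/orpo-planner.yaml) that a merged dataset written by merge_rl_data.py exposes the splits the ORPO recipes list under dataset_splits:

from datasets import load_from_disk

merged = load_from_disk("../data/llm_alignment/spider-p1-planner/dpo-llama-3-end2end-spider_train_dev_planner")
# The DPO/ORPO recipes above read exactly these two splits.
assert {"train_dpo", "test_dpo"} <= set(merged.keys())
print(merged)
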
/alignment-handbook/recipes/llama-3b-bird/orpo-validator-iter-2.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/reproduce/orpo-llama-3-validator-bird/
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird/bird-p2-validator/dpo-llama-3-end2end-bird_train_validator_select/: 1.0
10 | ../data/llm_alignment/bird/bird-p2-validator/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0
11 | ../data/llm_alignment/bird/bird-p2-validator/dpo-llama-3-end2end-bird_train_validator_join/: 1.0
12 |
13 | dataset_splits:
14 | - train_dpo
15 | - test_dpo
16 | preprocessing_num_workers: 12
17 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
18 | report_to: ["tensorboard"]
19 |
20 | # DPOTrainer arguments
21 | bf16: true
22 | beta: 0.25
23 | do_eval: true
24 | eval_strategy: "steps"
25 | eval_steps: 100
26 | gradient_accumulation_steps: 8
27 | gradient_checkpointing: true
28 | gradient_checkpointing_kwargs:
29 | use_reentrant: False
30 | learning_rate: 8.0e-6
31 | log_level: info
32 | logging_steps: 10
33 | lr_scheduler_type: inverse_sqrt
34 | max_length: 2600
35 | max_prompt_length: 2000
36 | num_train_epochs: -1
37 | max_steps: 500
38 | optim: adamw_torch
39 | output_dir: output/reproduce/orpo-iter-2-llama-3-validator-bird/
40 | per_device_train_batch_size: 1
41 | per_device_eval_batch_size: 1
42 | push_to_hub: false
43 | save_strategy: "steps"
44 | save_steps: 100
45 | save_total_limit: 1
46 | seed: 42
47 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-spider/fft-validator.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | ../data/multi-agents/validator/sft-validator_select_spider/: 1.0
11 | ../data/multi-agents/validator/sft-validator_join_spider/: 1.0
12 | ../data/multi-agents/validator/sft-validator_condition_spider/: 1.0
13 | ../data/multi-agents/validator/sft-validator_order_spider/: 1.0
14 | dataset_splits:
15 | - train
16 | - test
17 | preprocessing_num_workers: 24
18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
19 |
20 | # SFT trainer config
21 | bf16: true
22 | do_eval: true
23 | evaluation_strategy: "no"
24 | gradient_accumulation_steps: 32
25 | gradient_checkpointing: true
26 | hub_model_id: griffith-bigdata/llama-3b-spider-validator-fft
27 | hub_strategy: every_save
28 | learning_rate: 2.0e-05
29 | log_level: info
30 | logging_steps: 5
31 | logging_strategy: steps
32 | lr_scheduler_type: cosine
33 | max_seq_length: 2048
34 | max_steps: -1
35 | num_train_epochs: 4
36 | output_dir: output/llama-3b-spider-validator-fft
37 | overwrite_output_dir: true
38 | per_device_eval_batch_size: 4
39 | per_device_train_batch_size: 2
40 | push_to_hub: false
41 | remove_unused_columns: true
42 | report_to:
43 | - tensorboard
44 | save_strategy: "epoch"
45 | save_steps: 100
46 | save_total_limit: 3
47 | seed: 42
48 | tf32: true
--------------------------------------------------------------------------------
/visualization/sql_characteristic_spider.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 |
5 | # Sample dataset of EX (%) per SQL-characteristic subset (replace with actual data as needed)
6 | data = {
7 |     "Method": ["MATS (Ours)", "DAILSQL(SC)", "CodeS-15B", "CodeS-7B", "RESDSQL-3B\n+NatSQL", "RESDSQL-3B", "Graphix\n+PICARD"],
8 | "w/o JOIN": [92.49, 89.1, 90.6, 89.6, 89.0, 90.1, 88.3],
9 | "w/ JOIN": [79.9, 75.0, 76.2, 78.9, 76.7, 69.1, 69.6],
10 | "w/o Subquery": [88.85, 84.2, 86.0, 86.7, 85.8, 83.2, 82.1],
11 | "w/ Subquery": [72.29, 63.6, 51.5, 45.5, 33.3, 39.4, 45.5],
12 | "w/o Logical\nConnector": [88.98, 85.3, 86.4, 87.0, 85.8, 83.9, 83.1],
13 | "w/ Logical\nConnector": [72.22, 65.6, 68.9, 68.9, 66.7, 60.0, 58.9],
14 | "w/o ORDER-BY": [88.08, 84.3, 85.1, 86.3, 83.6, 81.7, 80.9],
15 | "w/ ORDER-BY": [85.65, 81.0, 84.4, 82.3, 86.1, 82.3, 81.0],
16 | "Overall": [87.1, 83.6, 84.9, 85.4, 84.1, 81.8, 80.9]
17 | }
18 |
19 | # Convert to DataFrame
20 | df = pd.DataFrame(data)
21 | df.set_index("Method", inplace=True)
22 |
23 | # Transpose DataFrame to swap axes
24 | df = df.T
25 |
26 | # Strip whitespace from subset names and drop any duplicated rows
27 | df.index = df.index.str.strip()
28 | df = df.loc[~df.index.duplicated(keep='first')]
29 |
30 | # Set up the figure size
31 | plt.figure(figsize=(4.5, 3.5))
32 |
33 | # Create the heatmap
34 | sns.heatmap(df, annot=True, cmap="YlGnBu", linewidths=0.5, fmt=".1f", cbar=False)
35 |
36 | # Labels
37 | plt.xlabel("", fontsize=8)
38 | plt.ylabel("Subset", fontsize=8)
39 |
40 | # Rotate x-axis labels for better readability
41 | plt.xticks(rotation=90, ha="right", fontsize=6)
42 | plt.yticks(fontsize=6)
43 |
44 | # Show the plot
45 | plt.tight_layout()
46 | plt.show()
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/validator-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/validator/sft-validator_select_bird_with_evidence/: 1
11 | /home/datht/codes/data/multi-agents/validator/sft-validator_condition_bird_with_evidence/: 1
12 | /home/datht/codes/data/multi-agents/validator/sft-validator_join_bird_with_evidence/: 1
13 | # /home/datht/codes/data/multi-agents/validator/sft-validator_order_bird_with_evidence/: 1
14 | dataset_splits:
15 | - train
16 | - test
17 | preprocessing_num_workers: 24
18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
19 |
20 | # SFT trainer config
21 | bf16: true
22 | do_eval: false
23 | evaluation_strategy: "no"
24 | gradient_accumulation_steps: 32
25 | gradient_checkpointing: true
26 | hub_model_id: griffith-bigdata/llama-1b-bird-validator-fft
27 | hub_strategy: every_save
28 | learning_rate: 2.0e-05
29 | log_level: info
30 | logging_steps: 5
31 | logging_strategy: steps
32 | lr_scheduler_type: cosine
33 | max_seq_length: 4096
34 | max_steps: -1
35 | num_train_epochs: 4
36 | output_dir: output/llama-1b-bird-validator-fft
37 | overwrite_output_dir: true
38 | per_device_eval_batch_size: 2
39 | per_device_train_batch_size: 4
40 | push_to_hub: false
41 | remove_unused_columns: true
42 | report_to:
43 | - tensorboard
44 | save_strategy: "epoch"
45 | save_total_limit: 1
46 | seed: 42
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/llama-3-bird-validator-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | # /home/datht/codes/data/multi-agents/validator/sft-validator_select_bird_with_evidence/: 1
11 | /home/datht/codes/data/multi-agents/validator/sft-validator_condition_bird_with_evidence/: 1
12 | # /home/datht/codes/data/multi-agents/validator/sft-validator_join_bird_with_evidence/: 1
13 | # /home/datht/codes/data/multi-agents/validator/sft-validator_order_bird_with_evidence/: 1
14 | dataset_splits:
15 | - train
16 | - test
17 | preprocessing_num_workers: 24
18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
19 |
20 | # SFT trainer config
21 | bf16: true
22 | do_eval: true
23 | evaluation_strategy: "no"
24 | gradient_accumulation_steps: 32
25 | gradient_checkpointing: true
26 | hub_model_id: griffith-bigdata/llama-3b-bird-validator-fft
27 | hub_strategy: every_save
28 | learning_rate: 2.0e-05
29 | log_level: info
30 | logging_steps: 5
31 | logging_strategy: steps
32 | lr_scheduler_type: cosine
33 | max_seq_length: 4096
34 | max_steps: -1
35 | num_train_epochs: 4
36 | output_dir: output/llama-3b-bird-validator-fft-2
37 | overwrite_output_dir: true
38 | per_device_eval_batch_size: 2
39 | per_device_train_batch_size: 2
40 | push_to_hub: false
41 | remove_unused_columns: true
42 | report_to:
43 | - tensorboard
44 | save_strategy: "no"
45 | save_total_limit: 1
46 | seed: 42
47 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/llama-3-bird-validator-fft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments:
2 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-1B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 | response_template: "<|start_header_id|>assistant<|end_header_id|>"
7 |
8 | # Data training arguments
9 | dataset_mixer:
10 | /home/datht/codes/data/multi-agents/validator/sft-validator_select_bird_with_evidence/: 1
11 | /home/datht/codes/data/multi-agents/validator/sft-validator_condition_bird_with_evidence/: 1
12 | /home/datht/codes/data/multi-agents/validator/sft-validator_join_bird_with_evidence/: 1
13 | /home/datht/codes/data/multi-agents/validator/sft-validator_order_bird_with_evidence/: 1
14 | dataset_splits:
15 | - train
16 | - test
17 | preprocessing_num_workers: 24
18 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
19 |
20 | # SFT trainer config
21 | bf16: true
22 | do_eval: true
23 | evaluation_strategy: "no"
24 | gradient_accumulation_steps: 32
25 | gradient_checkpointing: true
26 | hub_model_id: griffith-bigdata/llama-1b-bird-validator-fft
27 | hub_strategy: every_save
28 | learning_rate: 2.0e-05
29 | log_level: info
30 | logging_steps: 5
31 | logging_strategy: steps
32 | lr_scheduler_type: cosine
33 | max_seq_length: 8192
34 | max_steps: -1
35 | num_train_epochs: 4
36 | output_dir: output/llama-1b-bird-validator-fft
37 | overwrite_output_dir: true
38 | per_device_eval_batch_size: 2
39 | per_device_train_batch_size: 2
40 | push_to_hub: false
41 | remove_unused_columns: true
42 | report_to:
43 | - tensorboard
44 | save_strategy: "no"
45 | save_total_limit: 1
46 | seed: 42
47 | tf32: true
48 |
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-3b-bird/orpo-llama-3-validator.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | # model_name_or_path: ./output/llama-3b-bird-validator-fft-1epoch/
3 | model_name_or_path: /home/datht/huggingface/meta-llama/Llama-3.2-3B-Instruct
4 | torch_dtype: bfloat16
5 | use_flash_attention_2: true
6 |
7 | # Data training arguments
8 | # For definitions, see: src/h4/training/config.py
9 | dataset_mixer:
10 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_select/: 1.0
11 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0
12 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_join/: 1.0
13 | ../data/llm_alignment/bird-p1-validator/dpo-llama-3-end2end-bird_train_validator_order/: 1.0
14 |
15 |
16 | dataset_splits:
17 | - train_dpo
18 | - test_dpo
19 | preprocessing_num_workers: 12
20 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
21 | report_to: ["tensorboard"]
22 |
23 | # DPOTrainer arguments
24 | bf16: true
25 | beta: 0.25
26 | do_eval: true
27 | eval_strategy: "no"
28 | eval_steps: 40
29 | gradient_accumulation_steps: 32
30 | gradient_checkpointing: true
31 | gradient_checkpointing_kwargs:
32 | use_reentrant: False
33 | hub_model_id: th-dpo
34 | learning_rate: 5.0e-6
35 | log_level: info
36 | logging_steps: 10
37 | lr_scheduler_type: cosine
38 | max_length: 2600
39 | max_prompt_length: 2000
40 | num_train_epochs: 3
41 | optim: adamw_torch
42 | output_dir: output/orpo-llama-3-validator-bird
43 | per_device_train_batch_size: 1
44 | per_device_eval_batch_size: 1
45 | push_to_hub: false
46 | save_strategy: "steps"
47 | save_steps: 40
48 | save_total_limit: 1
49 | seed: 42
50 | warmup_ratio: 0.1
51 |
--------------------------------------------------------------------------------
/visualization/parameter_sensitivity_num_candidates.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | # Data for plotting
5 | candidates = np.array([1, 5, 10, 20, 30])
6 | achieved_accuracy = np.array([59.17, 63.75, 64.73, 62.91, 62.78])
7 | upper_bound = np.array([59.17, 69.6, 72.7, 76, 77.8])
8 | lower_bound = np.array([59.17, 59.17, 59.17, 59.17, 59.17])
9 |
10 | # Define figure size to fit a research paper column
11 | fig_width = 4 # Adjusted for single-column fit
12 | fig_height = fig_width * 0.75 # Maintain aspect ratio
13 |
14 | # Create the figure
15 | plt.figure(figsize=(fig_width, fig_height), dpi=300) # High DPI for publication quality
16 |
17 | # Plot the curves
18 | plt.plot(candidates, upper_bound, label="Upper Bound", linestyle="--", color="red")
19 | plt.plot(candidates, achieved_accuracy, label="MATS", marker="o", color="blue")
20 | plt.plot(candidates, lower_bound, label="Lower Bound", linestyle="--", color="green")
21 |
22 | # Fill between (shading) without adding legend
23 | plt.fill_between(candidates, achieved_accuracy, upper_bound, color="gray", alpha=0.3)
24 | plt.fill_between(candidates, lower_bound, achieved_accuracy, color="gray", alpha=0.3)
25 |
26 | # Labels and formatting
27 | plt.xlabel("Number of Candidates", fontsize=10)
28 | plt.ylabel("EX%", fontsize=10)  # execution accuracy (%); usetex is off, so plain "%" renders correctly
29 | plt.xticks(candidates, fontsize=9)
30 | plt.yticks(fontsize=9)
31 | plt.legend(fontsize=9, loc="upper left") # Move legend to top left
32 | plt.grid(True, linewidth=0.5)
33 |
34 | # Remove unnecessary borders for a cleaner publication look
35 | plt.gca().spines["top"].set_visible(False)
36 | plt.gca().spines["right"].set_visible(False)
37 |
38 | # Save the figure in a high-quality format for LaTeX insertion
39 | plt.savefig("accuracy_vs_bounds.pdf", bbox_inches="tight", format="pdf")
40 |
41 | # Show the figure
42 | plt.show()
43 |
--------------------------------------------------------------------------------
/visualization/sql_characteristic_bird.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 |
5 | # Define the dataset with all extracted EX (%) values including MATS (Ours)
6 | data = {
7 | "Subset": [
8 | "w/o JOIN", "w/ JOIN", "w/o Subquery", "w/ Subquery",
9 | "w/o Logical\nConnector", "w/ Logical\nConnector",
10 | "w/o ORDER-BY", "w/ ORDER-BY", "Overall"
11 | ],
12 | "MATS (Ours)": [68.02, 59.21, 63.12, 40.71, 65.33, 56.04, 63.67, 52.75, 64.74],
13 | "DAILSQL(SC)": [61.4, 53.9, 56.9, 37.9, 59.1, 51.3, 58.3, 46.3, 55.9],
14 | # "DAILSQL": [60.4, 52.2, 55.4, 35.6, 56.6, 51.0, 57.1, 43.0, 54.3],
15 | # "C3SQL": [55.3, 48.4, 51.2, 33.3, 54.5, 44.1, 53.5, 37.2, 50.2],
16 | "CodeS-15B": [63.5, 56.8, 59.8, 36.8, 62.2, 53.2, 61.1, 48.2, 58.5],
17 | "CodeS-7B": [63.2, 54.8, 58.5, 32.2, 60.6, 51.8, 59.6, 46.6, 57.0],
18 | # "SFT CodeS-3B": [61.2, 52.7, 56.3, 31.0, 59.8, 48.0, 57.9, 43.0, 54.9],
19 | # "SFT CodeS-1B": [57.1, 47.9, 51.6, 28.7, 55.3, 43.2, 53.4, 37.9, 50.3],
20 |     "RESDSQL-3B": [52.0, 41.1, 44.7, 31.0, 49.4, 36.3, 47.2, 31.1, 43.9],
21 |     "RESDSQL-Large": [45.9, 36.1, 39.6, 21.8, 45.7, 28.6, 41.9, 25.6, 38.6],
22 |     "RESDSQL-Base": [40.9, 30.4, 33.9, 19.5, 38.7, 25.3, 35.5, 23.6, 33.1],
23 |
24 | }
25 |
26 | # Convert to DataFrame and set index
27 | df = pd.DataFrame(data)
28 | df.set_index("Subset", inplace=True)
29 |
30 | # Create a figure with a single heatmap
31 | fig = plt.figure(figsize=(4.5, 3.5))
32 |
33 | # Create the heatmap
34 | sns.heatmap(df, annot=True, cmap="YlGnBu", linewidths=0.5, fmt=".1f", cbar=False)
35 |
36 | plt.xlabel("")
37 | plt.ylabel("Subset", fontsize=8)
38 | plt.xticks(rotation=90, ha="right", fontsize=6)
39 | plt.yticks(fontsize=6)
40 |
41 | # Adjust layout
42 | plt.tight_layout()
43 |
44 | # Show the plot
45 | plt.show()
46 |
--------------------------------------------------------------------------------
/alignment-handbook/tests/test_configs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import unittest
17 |
18 | from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig
19 |
20 |
21 | class H4ArgumentParserTest(unittest.TestCase):
22 | def setUp(self):
23 | self.parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
24 | self.yaml_file_path = "tests/fixtures/config_sft_full.yaml"
25 |
26 | def test_load_yaml(self):
27 | model_args, data_args, training_args = self.parser.parse_yaml_file(os.path.abspath(self.yaml_file_path))
28 | self.assertEqual(model_args.model_name_or_path, "mistralai/Mistral-7B-v0.1")
29 |
30 | def test_load_yaml_and_args(self):
31 | command_line_args = [
32 | "--model_name_or_path=test",
33 | "--use_peft=true",
34 | "--lora_r=16",
35 | "--lora_dropout=0.5",
36 | ]
37 | model_args, data_args, training_args = self.parser.parse_yaml_and_args(
38 | os.path.abspath(self.yaml_file_path), command_line_args
39 | )
40 | self.assertEqual(model_args.model_name_or_path, "test")
41 | self.assertEqual(model_args.use_peft, True)
42 | self.assertEqual(model_args.lora_r, 16)
43 | self.assertEqual(model_args.lora_dropout, 0.5)
44 |
--------------------------------------------------------------------------------
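The test above also documents how the recipe YAMLs in this repo are consumed. A minimal sketch (using only parse_yaml_file, which the test exercises; the recipe path is one of the files listed earlier):

import os

from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig

parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
model_args, data_args, training_args = parser.parse_yaml_file(
    os.path.abspath("recipes/llama-3b-bird/planner-fft.yaml")
)
print(model_args.model_name_or_path, training_args.learning_rate)
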
/alignment-handbook/recipes/llama-3b-bird/orpo-validator.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-3b-bird-validator-fft
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_select/: 1.0
10 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_condition/: 1.0
11 | # ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_join/: 1.0
12 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_select/: 1.0
13 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0
14 | # ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_join/: 1.0
15 |
16 | dataset_splits:
17 | - train_dpo
18 | - test_dpo
19 | preprocessing_num_workers: 12
20 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
21 | report_to: ["tensorboard"]
22 |
23 | # DPOTrainer arguments
24 | bf16: true
25 | beta: 1.0
26 | do_eval: true
27 | eval_strategy: "steps"
28 | eval_steps: 100
29 | gradient_accumulation_steps: 8
30 | gradient_checkpointing: true
31 | gradient_checkpointing_kwargs:
32 | use_reentrant: False
33 | learning_rate: 8.0e-6
34 | log_level: info
35 | logging_steps: 10
36 | lr_scheduler_type: inverse_sqrt
37 | max_length: 2600
38 | max_prompt_length: 2000
39 | num_train_epochs: 1
40 | max_steps: -1
41 | optim: adamw_torch
42 | output_dir: output/reproduce/orpo-llama-3-validator-bird/
43 | per_device_train_batch_size: 1
44 | per_device_eval_batch_size: 1
45 | push_to_hub: false
46 | save_strategy: "steps"
47 | save_steps: 100
48 | save_total_limit: 1
49 | seed: 42
50 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/alignment-handbook/recipes/llama-1b-bird/orpo-validator.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: ./output/llama-1b-bird-validator-fft
3 | torch_dtype: bfloat16
4 | use_flash_attention_2: true
5 |
6 | # Data training arguments
7 | # For definitions, see: src/h4/training/config.py
8 | dataset_mixer:
9 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_select/: 1.0
10 | ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_condition/: 1.0
11 | # ../data/llm_alignment/bird/bird-p1-validator/dpo-llama-3-end2end-bird_train_dev_validator_join/: 1.0
12 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_select/: 1.0
13 | ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_condition/: 1.0
14 | # ../data/llm_alignment/bird-p1-validator-modify/dpo-llama-3-end2end-bird_train_validator_join/: 1.0
15 |
16 | dataset_splits:
17 | - train_dpo
18 | - test_dpo
19 | preprocessing_num_workers: 12
20 | chat_template: "{{'<|start_header_id|>user<|end_header_id|>\n' + messages['prompt'] + '<|eot_id|>\n'}}{{'<|start_header_id|>assistant<|end_header_id|>\n' + messages['completion'] + '<|eot_id|>\n'}}"
21 | report_to: ["tensorboard"]
22 |
23 | # DPOTrainer arguments
24 | bf16: true
25 | beta: 1.0
26 | do_eval: true
27 | eval_strategy: "steps"
28 | eval_steps: 100
29 | gradient_accumulation_steps: 8
30 | gradient_checkpointing: true
31 | gradient_checkpointing_kwargs:
32 | use_reentrant: False
33 | learning_rate: 8.0e-6
34 | log_level: info
35 | logging_steps: 10
36 | lr_scheduler_type: inverse_sqrt
37 | max_length: 2600
38 | max_prompt_length: 2000
39 | num_train_epochs: -1
40 | max_steps: 600
41 | optim: adamw_torch
42 | output_dir: output/reproduce/orpo-llama-1b-validator-bird/
43 | per_device_train_batch_size: 1
44 | per_device_eval_batch_size: 1
45 | push_to_hub: false
46 | save_strategy: "steps"
47 | save_steps: 100
48 | save_total_limit: 1
49 | seed: 42
50 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/scripts/evaluate_bird.sh:
--------------------------------------------------------------------------------
1 |
2 | # CUDA_VISIBLE_DEVICES=0 vllm serve orpo-llama-3b-iter-2-bird-planner-no-filter/ --host 0.0.0.0 --port 8003 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name planner --gpu-memory-utilization 0.9 --enable-prefix-caching
3 |
4 | # CUDA_VISIBLE_DEVICES=1 vllm serve llama-1b-bird-validator-fft/ --host 0.0.0.0 --port 8004 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name validator --gpu-memory-utilization 0.5 --enable-prefix-caching
5 |
6 | # CUDA_VISIBLE_DEVICES=1 vllm serve llama-1b-bird-fixed-fft/ --host 0.0.0.0 --port 8005 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name fixed --gpu-memory-utilization 0.4 --enable-prefix-caching
7 |
8 |
9 | mkdir -p logs
10 | mkdir -p logs/bird-dev
11 | # Evaluate the planner
12 | n=10
13 | temperature=1.0
14 | seed=100
15 | eval_file=data/evaluate/orpo-llama-3-iter-2-planner-bird_dev-greedy-and-sampling-seed$seed.jsonl
16 | rm -f $eval_file
17 | PYTHONPATH=. python evaluate_end2end.py \
18 | --input_file data/full_value_matching_schema_insight_bird_062024_with_evidence_dev_text2sql.json \
19 | --output_file $eval_file \
20 |     --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 --only_planner --seed $seed --n_processes 32
21 |
22 | python jsonl2json.py --jsonl-file data/evaluate/orpo-llama-3-iter-2-planner-bird_dev-greedy-and-sampling-seed$seed.jsonl
23 |
24 | # Evaluate Selection
25 | eval_file=data/evaluate/orpo-llama-3-iter-2-selection-bird_dev-greedy-and-sampling-seed$seed.jsonl
26 | rm -f $eval_file
27 | PYTHONPATH=. python evaluate_end2end.py \
28 | --input_file data/evaluate/orpo-llama-3-iter-2-planner-bird_dev-greedy-and-sampling-seed$seed.json \
29 | --output_file $eval_file \
30 | --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 \
31 |     --skip_planner --seed $seed --n_processes 8 --skip_validator_join
32 |
33 | python compute_acc.py --pred_file $eval_file
34 |
35 |
--------------------------------------------------------------------------------
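evaluate_bird.sh converts the planner's JSONL output to JSON before the selection pass. A sketch of that conversion step (the repo ships jsonl2json.py; its flag and output naming are inferred from the invocations above):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--jsonl-file", type=str, required=True)
args = parser.parse_args()

# One JSON object per line in; a single JSON array out, written next to the input.
records = [json.loads(line) for line in open(args.jsonl_file)]
with open(args.jsonl_file.replace(".jsonl", ".json"), "w") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)
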
/utils/classifier_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | # CrossEntropyLoss = softmax + log + NLLLoss
6 |
7 | class FocalLoss(nn.Module):
8 |     def __init__(self, weight=None, gamma=0.5, reduction='mean'):  # F.nll_loss requires a string reduction mode
9 | super(FocalLoss, self).__init__()
10 |
11 | self.weight = weight
12 | self.gamma = gamma
13 | self.reduction = reduction
14 |
15 | def forward(self, input_tensor, target_tensor):
16 | assert input_tensor.shape[0] == target_tensor.shape[0]
17 |
18 | prob = F.softmax(input_tensor, dim = -1)
19 | log_prob = torch.log(prob + 1e-8)
20 |
21 | loss = F.nll_loss(
22 | ((1 - prob) ** self.gamma) * log_prob,
23 | target_tensor,
24 | weight=self.weight,
25 | reduction=self.reduction
26 | )
27 |
28 | return loss
29 |
30 | class ClassifierLoss():
31 | def __init__(self, alpha, gamma):
32 | weight = torch.FloatTensor([1-alpha, alpha])
33 | if torch.cuda.is_available():
34 | weight = weight.cuda()
35 |
36 | self.focal_loss = FocalLoss(
37 | weight = weight,
38 | gamma = gamma,
39 | reduction = 'mean'
40 | )
41 |
42 | # self.ce_loss = nn.CrossEntropyLoss(weight = weight, reduction = "mean")
43 |
44 | def compute_batch_loss(self, batch_logits, batch_labels, batch_size):
45 | loss = 0
46 | for logits, labels in zip(batch_logits, batch_labels):
47 | loss += self.focal_loss(logits, labels)
48 |
49 | return loss/batch_size
50 |
51 | def compute_loss(
52 | self,
53 | batch_table_name_cls_logits,
54 | batch_table_labels,
55 | batch_column_info_cls_logits,
56 | batch_column_labels
57 | ):
58 | batch_size = len(batch_table_labels)
59 |
60 | table_loss = self.compute_batch_loss(batch_table_name_cls_logits, batch_table_labels, batch_size)
61 | column_loss = self.compute_batch_loss(batch_column_info_cls_logits, batch_column_labels, batch_size)
62 |
63 | return table_loss + column_loss
--------------------------------------------------------------------------------
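The "CrossEntropyLoss = softmax + log + NLLLoss" identity above is easy to verify, and with gamma = 0 the (1 - p)**gamma modulating factor equals 1, so FocalLoss should reduce to plain cross-entropy. A small sanity check (a sketch; shapes and seed are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 2)         # (batch, num_classes)
labels = torch.randint(0, 2, (8,))

prob = F.softmax(logits, dim=-1)
log_prob = torch.log(prob + 1e-8)  # same stabilizer as FocalLoss.forward
focal_gamma0 = F.nll_loss(log_prob, labels, reduction="mean")
ce = F.cross_entropy(logits, labels)
print(torch.allclose(focal_gamma0, ce, atol=1e-6))  # True, up to the 1e-8 stabilizer
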
/data_processing/generate_validator_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sqlite3
3 | import multiprocessing.pool
4 | import functools
5 | from tqdm import tqdm
6 | import pandas as pd
7 | from validator import ValidatorJOIN, _execute_sql, _make_str_response, is_execution_correct
8 | import argparse
9 | import os
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json')
13 | parser.add_argument('--output_file', type=str, default='bird_validator_join.jsonl')
14 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai'])
15 | args = parser.parse_args()
16 |
17 | data = json.load(open(args.input_file))
18 |
19 | if os.path.exists(args.output_file):
20 |     # Resume: the output file is JSONL, so parse it line by line (json.load would fail here)
21 |     old_output = [json.loads(line) for line in open(args.output_file)]
22 | else:
23 | old_output = []
24 |
25 | # open jsonl file for append contents
26 | output_file = open(args.output_file, 'a+')
27 |
28 | validator = ValidatorJOIN(endpoint_type=args.endpoint_type)
29 |
30 | for isample in tqdm(range(len(old_output), len(data)), initial=len(old_output), total=len(data)):  # skip samples already processed
31 | sample = data[isample]
32 |
33 | true_execution_result = _execute_sql("../" + sample['db_path'], sample['sql'])
34 |
35 | sql = sample['predict_sql']
36 |
37 | answer, execution_result = validator.validate(sample)
38 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0])
39 |
40 | print("-"*20)
41 | print("Is correct: ", is_correct)
42 | print(answer)
43 |
44 | sample['is_correct'] = is_correct
45 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer
46 | sample['validator_join'] = answer
47 |
48 | sample['true_result'] = _make_str_response(*true_execution_result)
49 | sample['pred_result'] = _make_str_response(*execution_result)
50 |
51 | del sample['table_labels']
52 | del sample['column_labels']
53 | del sample['schema']
54 | del sample['matched_contents']
55 |
56 | # json.dump(data[:isample+1], open(args.output_file, 'w+'), ensure_ascii=False, indent=4)
57 | # write new sample in jsonl file
58 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
59 |
--------------------------------------------------------------------------------
/scripts/evaluate_dr_spider.sh:
--------------------------------------------------------------------------------
1 | # n=9
2 | # temperature=1.0
3 | # seed=100
4 | # mkdir logs/
5 | # mkdir logs/dr-spider
6 | # for test_set in DB_schema_synonym DB_DBcontent_equivalence DB_schema_abbreviation NLQ_column_value SQL_DB_text SQL_DB_number SQL_comparison NLQ_column_carrier NLQ_multitype NLQ_others NLQ_keyword_synonym SQL_NonDB_number NLQ_column_synonym NLQ_keyword_carrier NLQ_value_synonym NLQ_column_attribute SQL_sort_order
7 | # do
8 | # eval_file=data/evaluate/orpo-llama-3-iter-3-planner-dr_spider_${test_set}.jsonl
9 | # # rm $eval_file
10 | # PYTHONPATH=. python evaluate_end2end.py \
11 | # --input_file data/schema_insight_dr_spider_text2sql_${test_set}.json \
12 | # --output_file $eval_file \
13 | # --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 --only_planner --seed $seed --n_processes 32
14 |
15 | # rm progress_dr.pkl
16 | # python -u check_correct_recall.py --pred_file $eval_file --progress_file progress_dr.pkl > logs/dr-spider/log-orpo-planner-dr_spider-$test_set-n$n-temp$temperature-seed$seed.txt
17 | # done
18 |
19 | n=9
20 | temperature=1.0
21 | seed=100
22 | mkdir -p logs/
23 | mkdir -p logs/dr-spider
24 | for test_set in SQL_comparison NLQ_column_carrier NLQ_multitype NLQ_others NLQ_keyword_synonym SQL_NonDB_number NLQ_column_synonym NLQ_keyword_carrier NLQ_value_synonym NLQ_column_attribute SQL_sort_order
25 | do
26 | python jsonl2json.py --jsonl-file data/evaluate/orpo-llama-3-iter-3-planner-dr_spider_${test_set}.jsonl
27 |
28 | eval_file=data/evaluate/orpo-llama-3-iter-3-end2end-dr_spider_${test_set}.jsonl
29 | # rm $eval_file
30 | PYTHONPATH=. python evaluate_end2end.py \
31 | --input_file data/evaluate/orpo-llama-3-iter-3-planner-dr_spider_${test_set}.json \
32 | --output_file $eval_file \
33 | --model-name llama --mode test --n_return $n --temperature $temperature --api_host http://192.168.1.118:8003 --skip_planner --seed $seed --n_processes 32 --skip_validator_join --skip_selection
34 |
35 | rm -f progress_dr.pkl
36 | python -u check_correct_recall.py --pred_file $eval_file --progress_file progress_dr.pkl > logs/dr-spider/log-orpo-llama-3-iter-3-end2end-dr_spider-$test_set-n$n-temp$temperature-seed$seed.txt
37 | mv results_spider.pkl logs/results-dr-spider-end2end-$test_set.pkl
38 | done
39 |
40 |
--------------------------------------------------------------------------------
/visualization/domain_knowledge.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 |
5 | # Sample dataset (Ensure to replace with actual data)
6 | data = {
7 |     "Method": ["MATS (Ours)", "DAILSQL(SC)", "CodeS-15B", "CodeS-7B", "RESDSQL-3B\n+NatSQL", "RESDSQL-3B", "Graphix\n+PICARD"],
8 | "College": [84.0, 79.6, 82.4, 83.3, 80.6, 83.3, 78.7],
9 | "Competition": [92.0, 79.0, 85.5, 82.3, 80.6, 83.9, 82.3],
10 | "Geography": [71.0, 76.7, 75.0, 75.8, 52.5, 65.0, 64.2],
11 | "Social": [95.0, 83.9, 83.9, 82.1, 76.8, 80.4, 82.1],
12 | "Transportation": [97.0, 85.0, 88.8, 87.5, 86.3, 80.0, 98.8],
13 | "Overall": [87.1, 83.6, 84.9, 85.4, 84.1, 81.8, 80.9]
14 | }
15 |
16 | # DB Count Data (Ensure to replace with actual data)
17 | db_count = {"College": 10, "Competition": 5, "Geography": 3, "Social": 2, "Transportation": 12}
18 |
19 | # Convert to DataFrame
20 | df = pd.DataFrame(data)
21 | df.set_index("Method", inplace=True)
22 |
23 | # Transpose DataFrame to swap axes
24 | df = df.T
25 |
26 | # Create a figure with two side-by-side subplots, adjusting colors and layout
27 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 3.5), gridspec_kw={'width_ratios': [4, 0.8]})
28 |
29 | # Create the heatmap in the first subplot with a colormap similar to the reference image
30 | sns.heatmap(df, annot=True, cmap="YlGnBu", linewidths=0.5, fmt=".1f", cbar=False, ax=axes[0])
31 | axes[0].set_xlabel("", fontsize=8)
32 | axes[0].set_ylabel("DB Domain", fontsize=8)
33 | axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=90, ha="right", fontsize=6)
34 | axes[0].set_yticklabels(axes[0].get_yticklabels(), fontsize=6)
35 |
36 | # Create the DB count bar plot in the second subplot
37 | domains = list(db_count.keys())
38 | db_values = list(db_count.values())
39 | axes[1].barh(domains, db_values, color="#1f77b4", alpha=0.8) # Adjusted to a similar blue tone
40 | axes[1].set_xlabel("#DB Count", fontsize=6) # Reduce x-axis title size
41 | axes[1].set_yticklabels([]) # Remove y-axis tick labels
42 | axes[1].set_xticks(range(0, max(db_values) + 1, max(2, max(db_values) // 4))) # Keep sparse x-axis ticks
43 | axes[1].tick_params(axis='x', labelsize=6) # Reduce x-axis tick label size
44 |
45 | # Adjust layout for better fitting
46 | plt.tight_layout()
47 |
48 | # Show the plot
49 | plt.show()
--------------------------------------------------------------------------------
/visualization/rlef_improvement.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | # Data for Spider Dev and BIRD Dev
4 | steps = [0, 1, 2, 3] # RLEF iterations
5 | labels = ['SFT', 'Iter 1', 'Iter 2', 'Iter 3']
6 |
7 | # MATS 3B Planner Data
8 | spider_3b = [83.3, 84, 85.4, 85.2]
9 | bird_3b = [53.65, 56.32, 59.32, 58.6]
10 |
11 | # MATS Data
12 | spider_mats = [85.5, 86.3, 87.1, 87]
13 | bird_mats = [59.06, 60.82, 64.73, 62.58]
14 |
15 | # Reference model performance for horizontal lines
16 | ref_lines = {
17 | "GPT-4": {"color": "cyan", "style": "-.", "spider": 76.8, "bird": 49.15},
18 | "CodeS-15B": {"color": "red", "style": ":", "spider": 84.9, "bird": 58.47},
19 | "MAC-SQL + GPT-4": {"color": "purple", "style": "--", "spider": 86.75, "bird": 59.59},
20 | "DIN-SQL + GPT-4": {"color": "green", "style": "--", "spider": 82.8, "bird": 50.72}
21 | }
22 |
23 | # Adjust figure size for a single-column layout
24 | fig, axes = plt.subplots(2, 1, figsize=(3.5, 5), sharex=True)
25 |
26 | # Spider Dev subplot
27 | axes[0].plot(steps, spider_3b, marker='o', color='blue', label='MATS 3B Planner', linewidth=1.5)
28 | axes[0].plot(steps, spider_mats, marker='s', color='red', label='MATS', linewidth=1.5)
29 | for label, data in ref_lines.items():
30 | axes[0].axhline(y=data["spider"], color=data["color"], linestyle=data["style"], label=label, linewidth=1)
31 | axes[0].set_title('Spider Dev', fontsize=8, pad=8)
32 | axes[0].grid(alpha=0.3)
33 |
34 | # BIRD Dev subplot
35 | axes[1].plot(steps, bird_3b, marker='o', color='blue', linewidth=1.5)
36 | axes[1].plot(steps, bird_mats, marker='s', color='red', linewidth=1.5)
37 | for label, data in ref_lines.items():
38 | axes[1].axhline(y=data["bird"], color=data["color"], linestyle=data["style"], linewidth=1)
39 | axes[1].set_title('BIRD Dev', fontsize=8, pad=8)
40 | axes[1].set_xticks(steps)
41 | axes[1].set_xticklabels(labels, rotation=45, fontsize=8)
42 | axes[1].grid(alpha=0.3)
43 |
44 | # Remove axis titles (labels)
45 | axes[0].set_ylabel('')
46 | axes[1].set_ylabel('')
47 | axes[1].set_xlabel('')
48 |
49 | # Add a single legend outside the subplots
50 | handles, legend_labels = axes[0].get_legend_handles_labels()  # distinct name so the x-tick labels list is not shadowed
51 | fig.legend(handles, legend_labels, loc='lower center', fontsize=7, ncol=2, frameon=False, bbox_to_anchor=(0.5, -0.09))
52 |
53 | # Adjust layout for compactness
54 | plt.tight_layout(pad=1.0)
55 | plt.show()
56 |
--------------------------------------------------------------------------------
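
An export caveat for the figure above: plt.tight_layout() does not reserve space for a figure-level legend, and here the legend is anchored below the axes (bbox_to_anchor y = -0.09), so a plain savefig would clip it. A small self-contained sketch; the output filename is hypothetical:

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(3.5, 5))
    ax.plot([0, 1, 2, 3], [83.3, 84, 85.4, 85.2], marker="o", label="MATS 3B Planner")
    fig.legend(loc="lower center", fontsize=7, ncol=2, frameon=False, bbox_to_anchor=(0.5, -0.09))
    plt.tight_layout(pad=1.0)
    # bbox_inches="tight" grows the saved canvas to include the out-of-axes legend
    fig.savefig("rlef_improvement.pdf", bbox_inches="tight")
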
/validator_data/generate_validator_order_using_fewshot.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from tqdm import tqdm
4 | from validator_data.validator import ValidatorOrder, _execute_sql, _make_str_response, is_execution_correct
5 | import re
6 | import argparse
7 |
8 | # argument parser for the input data file (train, dev) and the output file
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json')
11 | parser.add_argument('--output_file', type=str, default='bird_validator_order.jsonl')
12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai'])
13 | args = parser.parse_args()
14 |
15 | data = []
16 | with open(args.input_file) as fp:
17 | for line in fp:
18 | data.append(json.loads(line))
19 |
20 | # load saved output file if exists
21 | if os.path.exists(args.output_file):
22 | old_output = []
23 | with open(args.output_file, 'r') as f:
24 | for line in f:
25 | old_output.append(json.loads(line))
26 | data[:len(old_output)] = old_output
27 | else:
28 | old_output = []
29 |
30 | # open the jsonl output file in append mode
31 | output_file = open(args.output_file, 'a+')
32 |
33 | validator = ValidatorOrder(endpoint_type=args.endpoint_type)
34 |
35 | for isample in tqdm(range(len(old_output), len(data)), total=len(data) - len(old_output)):
36 | sample = data[isample]
37 |
38 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql'])
39 |
40 | pred_sql_match = re.search(r"(?<=Final SQL query:).*", sample['planner'], re.DOTALL)
41 | if pred_sql_match is None: continue
42 |
43 | pred_sql = pred_sql_match.group().replace("sql", "").replace("```", "").strip()
44 | sample['predict_sql'] = pred_sql
45 |
46 | answer, execution_result = validator.validate(sample)
47 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0])
48 |
49 | print("-"*20)
50 | print("Is correct: ", is_correct)
51 | print(answer)
52 |
53 | sample['is_correct'] = is_correct
54 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer
55 | sample['validator_order'] = answer
56 |
57 | sample['true_result'] = _make_str_response(*true_execution_result)
58 | sample['pred_result'] = _make_str_response(*execution_result)
59 |
60 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
61 |
--------------------------------------------------------------------------------
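
The helpers imported above come from validator_data/validator.py, which is not included in this dump. The sketch below is speculative, inferred only from the call sites: _execute_sql(db_path, sql) returns a (rows, error) pair, _make_str_response renders that pair as text, and is_execution_correct compares two row sets order-insensitively (a common execution-accuracy criterion for text-to-SQL; the real implementation may differ):

    import sqlite3

    def _execute_sql(db_path, sql, timeout=30.0):
        """Run `sql` against the SQLite database at `db_path`; return (rows, error)."""
        try:
            conn = sqlite3.connect(db_path, timeout=timeout)
            rows = conn.execute(sql).fetchall()
            conn.close()
            return rows, None
        except Exception as exc:  # syntax errors, missing tables, timeouts, ...
            return None, str(exc)

    def _make_str_response(rows, error):
        """Render an execution result as text for prompts and logs."""
        if error is not None:
            return f"Error: {error}"
        return "\n".join(str(row) for row in rows)

    def is_execution_correct(true_rows, pred_rows):
        """Order-insensitive comparison of two result sets."""
        if true_rows is None or pred_rows is None:
            return False
        return set(map(tuple, true_rows)) == set(map(tuple, pred_rows))
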
/bird_evaluation/run_evaluation.sh:
--------------------------------------------------------------------------------
1 | db_root_path='data/bird-062024/dev/dev_databases/'
2 | data_mode='dev'
3 | diff_json_path='data/bird-062024/dev/dev.json'
4 | predicted_sql_path=$1
5 | ground_truth_path='data/bird-062024/dev/'
6 | num_cpus=16
7 | meta_time_out=30.0
8 | mode_gt='gt'
9 | mode_predict='gpt'
10 |
11 | # db_root_path='./data/sft_data_collections/bird/dev/dev_databases/'
12 | # data_mode='dev'
13 | # diff_json_path='./data/sft_data_collections/bird/dev/dev.json'
14 | # # predicted_sql_path=$1
15 | # ground_truth_path='./data/sft_data_collections/bird/dev/'
16 | # num_cpus=16
17 | # meta_time_out=30.0
18 | # mode_gt='gt'
19 | # mode_predict='gpt'
20 |
21 | echo '''starting to compare with knowledge for ex'''
22 | python3 -u ./bird_evaluation/evaluation.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \
23 | --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \
24 | --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out}
25 |
26 | echo '''starting to compare with knowledge for ves'''
27 | python3 -u ./bird_evaluation/evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out}
28 |
29 |
30 |
31 | # db_root_path='./data/dev/dev_databases/'
32 | # data_mode='dev'
33 | # diff_json_path='./data/dev/dev.json'
34 | # predicted_sql_path=$1
35 | # ground_truth_path='./data/dev/'
36 | # num_cpus=16
37 | # meta_time_out=30.0
38 | # mode_gt='gt'
39 | # mode_predict='gpt'
40 |
41 | # echo '''starting to compare with knowledge for ex'''
42 | # python3 -u ./bird_evaluation/evaluation_062024.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} --data_mode ${data_mode} \
43 | # --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} \
44 | # --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out}
45 |
46 | # echo '''starting to compare with knowledge for ves'''
47 | # python3 -u ./bird_evaluation/evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path $1 --data_mode ${data_mode} --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out}
48 |
--------------------------------------------------------------------------------
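
The predictions file passed as $1 above is produced elsewhere in this repo; the generate_fixed_sql_using_fewshot*.py scripts later in this dump serialize one SQL string per example with a literal "\t----- bird -----\t" separator before the database id. A minimal sketch of writing a file in that format, with a hypothetical samples list:

    import json

    samples = [  # hypothetical predictions
        {"predict_sql": "SELECT count(*) FROM pages", "db_id": "language_corpus"},
    ]

    predictions = {
        str(idx): sample["predict_sql"] + "\t----- bird -----\t" + sample["db_id"]
        for idx, sample in enumerate(samples)
    }
    with open("predict_dev.json", "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
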
/utils/load_classifier_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import itertools
3 | from torch.utils.data import Dataset
4 |
5 | class SchemaItemClassifierDataset(Dataset):
6 | def __init__(self, dataset_dir):
7 | super(SchemaItemClassifierDataset, self).__init__()
8 |
9 | self.texts: list[str] = []
10 | self.all_column_names: list[list[list[str]]] = []
11 | self.all_column_labels: list[list[int]] = []  # column labels are flattened per example in __init__
12 | self.all_table_names: list[list[str]] = []
13 | self.all_table_labels: list[list[int]] = []
14 |
15 | with open(dataset_dir) as f:
16 |     dataset = json.load(f)
17 | assert isinstance(dataset, list)
18 |
19 | for data in dataset:
20 | table_names_in_one_db = []
21 | column_names_in_one_db = []
22 |
23 | for table in data["schema"]["schema_items"]:
24 | # table_names_in_one_db.append(table["table_name"])
25 | # column_names_in_one_db.append(table["column_names"])
26 | table_names_in_one_db.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
27 | if table["table_comment"] != "" else table["table_name"])
28 | column_names_in_one_db.append([column_name + " ( " + column_comment + " ) " \
29 | if column_comment != "" else column_name \
30 | for column_name, column_comment in zip(table["column_names"], table["column_comments"])])
31 |
32 | self.texts.append(data["text"])
33 | self.all_table_names.append(table_names_in_one_db)
34 | self.all_column_names.append(column_names_in_one_db)
35 | self.all_table_labels.append(data["table_labels"])
36 | self.all_column_labels.append(list(itertools.chain(*data["column_labels"])))
37 |
38 | def __len__(self):
39 | return len(self.texts)
40 |
41 | def __getitem__(self, index):
42 | text = self.texts[index]
43 | table_names_in_one_db = self.all_table_names[index]
44 | table_labels_in_one_db = self.all_table_labels[index]
45 | column_infos_in_one_db = self.all_column_names[index]
46 | column_labels_in_one_db = self.all_column_labels[index]
47 |
48 | return {
49 | "text": text,
50 | "table_names_in_one_db": table_names_in_one_db,
51 | "table_labels_in_one_db": table_labels_in_one_db,
52 | "column_infos_in_one_db": column_infos_in_one_db,
53 | "column_labels_in_one_db": column_labels_in_one_db
54 | }
55 |
--------------------------------------------------------------------------------
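
Each item returned by __getitem__ above holds ragged nested lists (table and column counts differ per example), so PyTorch's default collate_fn cannot stack a batch of them. A sketch of iterating the dataset with an identity collate; the dataset path is hypothetical and utils/ is assumed to be importable:

    from torch.utils.data import DataLoader

    from load_classifier_dataset import SchemaItemClassifierDataset

    dataset = SchemaItemClassifierDataset("data/sft_bird_train_text2sql.json")
    # identity collate: a batch is simply a list of sample dicts
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda batch: batch)

    for batch in loader:
        for sample in batch:
            print(sample["text"], len(sample["table_names_in_one_db"]))
        break
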
/data_processing/prompts/zero_shot_prompt_planner.txt:
--------------------------------------------------------------------------------
1 | database schema :
2 | table pages , columns = [ pages.words ( integer | values : 1081 , 68 ) , pages.page ( integer | values : 1 , 2 ) , pages.pid ( integer | primary key | comment : page id | values : 1 , 2 ) , pages.title ( text | values : Àbac , Abadia ) , pages.lid ( integer | comment : language id | values : 1 ) , pages.revision ( integer | values : 28236978 , 24086480 ) ]
3 | table words , columns = [ words.word ( text | values : +,2 , +,33 ) , words.wid ( integer | primary key | comment : word id | values : 2148990 , 2506463 ) , words.occurrences ( integer | values : 242 , 16841 ) ]
4 | table langs , columns = [ langs.pages ( integer | values : 1129144 ) , langs.words ( integer | values : 2764996 ) , langs.lid ( integer | primary key | comment : language id | values : 1 ) , langs.lang ( text | comment : language | values : ca ) , langs.locale ( text | values : ca_ES ) ]
5 | table pages_words , columns = [ pages_words.pid ( integer | primary key | comment : page id | values : 1 , 2 ) , pages_words.wid ( integer | primary key | comment : word id | values : 1 , 2 ) , pages_words.occurrences ( integer | values : 30 , 8 ) ]
6 | table langs_words , columns = [ langs_words.wid ( integer | primary key | comment : word id | values : 1 , 2 ) , langs_words.occurrences ( integer | values : 242 , 16841 ) , langs_words.lid ( integer | primary key | comment : language id | values : 1 ) ]
7 | table biwords , columns = [ biwords.occurrences ( integer | values : 4 , 3 ) , biwords.lid ( integer | primary key | comment : language id | values : 1 ) , biwords.w1st ( integer | primary key | comment : word id of the first word | values : 1 , 2 ) , biwords.w2nd ( integer | primary key | comment : word id of the second word | values : 2 , 4 ) ]
8 | foreign keys :
9 | pages.lid = langs.lid
10 | langs_words.wid = words.wid
11 | langs_words.lid = langs.lid
12 | pages_words.wid = words.wid
13 | pages_words.pid = pages.pid
14 | biwords.w2nd = words.wid
15 | biwords.w1st = words.wid
16 | biwords.lid = langs.lid
17 |
18 | matched contents :
19 | pages.words ( 1500 )
20 | pages.page ( 1500 )
21 | pages.pid ( 1500 )
22 | pages.title ( Pages , 1500 )
23 | words.word ( pages , words , calculates , differents , divides , percentages , counts , page's , wordes )
24 | pages_words.occurrences ( 1500 )
25 | langs_words.wid ( 1500 )
26 | langs_words.occurrences ( 1500 )
27 | biwords.occurrences ( 1500 )
28 | biwords.w1st ( 1500 )
29 | biwords.w2nd ( 1500 )
30 |
31 | Question: DIVIDE(COUNT(pages WHERE words = 1500), COUNT(pages)) as percentage; Calculate the percentage of pages that have 1500 different words.
32 |
33 | Planning:
--------------------------------------------------------------------------------
/validator_data/generate_validator_select_using_fewshot.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from tqdm import tqdm
4 | from validator_data.validator import ValidatorSelect, _execute_sql, _make_str_response, is_execution_correct
5 | import argparse
6 | import re
7 |
8 | # argument parser for the input data file (train, dev) and the output file
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--input_file', type=str, default='./data/evaluate/phi-3-planner_combine_bird_062024_with_evidence_train.jsonl')
11 | parser.add_argument('--output_file', type=str, default='bird_validator_select.jsonl')
12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp',
13 | choices=['vllm', 'llamacpp', 'openai'])
14 | args = parser.parse_args()
15 |
16 | data = []
17 | with open(args.input_file) as fp:
18 | for line in fp:
19 | data.append(json.loads(line))
20 |
21 | # load saved output file if exists
22 | if os.path.exists(args.output_file):
23 | old_output = []
24 | with open(args.output_file, 'r') as f:
25 | for line in f:
26 | old_output.append(json.loads(line))
27 | data[:len(old_output)] = old_output
28 | else:
29 | old_output = []
30 |
31 | # open the jsonl output file in append mode
32 | output_file = open(args.output_file, 'a+')
33 |
34 | validator = ValidatorSelect(endpoint_type=args.endpoint_type)
35 |
36 | for isample in tqdm(range(len(old_output), len(data)), total=len(data)-len(old_output)):
37 | sample = data[isample]
38 |
39 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql'])
40 |
41 | pred_sql_match = re.search(r"(?<=Final SQL query:).*", sample['planner'], re.DOTALL)
42 | if pred_sql_match is None: continue
43 |
44 | pred_sql = pred_sql_match.group().replace("sql", "").replace("```", "").strip()
45 | sample['predict_sql'] = pred_sql
46 |
47 |
48 | answer, execution_result = validator.validate(sample)
49 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0])
50 |
51 | print("-"*20)
52 | print("Is correct: ", is_correct)
53 | print(answer)
54 |
55 | sample['is_correct'] = is_correct
56 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer
57 | sample['validator_select'] = answer
58 |
59 | sample['true_result'] = _make_str_response(*true_execution_result)
60 | sample['pred_result'] = _make_str_response(*execution_result)
61 |
62 | # json.dump(data[:isample+1], open(args.output_file, 'w+'), ensure_ascii=False, indent=4)
63 | # write new sample in jsonl file
64 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
65 |
--------------------------------------------------------------------------------
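
Both validator scripts above recover the predicted SQL by matching everything after "Final SQL query:" and then calling .replace("sql", "").replace("```", ""). The first replace also deletes any literal "sql" inside the query itself (for example in a reference to sqlite_sequence). A hypothetical, slightly more careful helper, not part of the repo, that only trims the fence markers around the query:

    import re

    def extract_final_sql(planner_text):
        """Return the SQL after 'Final SQL query:', stripped of ``` fences, or None."""
        match = re.search(r"Final SQL query:(.*)", planner_text, re.DOTALL)
        if match is None:
            return None
        sql = match.group(1).strip()
        sql = re.sub(r"^```(?:sql)?\s*", "", sql)  # opening fence with optional language tag
        sql = re.sub(r"\s*```\s*$", "", sql)       # closing fence
        return sql.strip()
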
/validator_data/generate_fixed_sql_using_fewshot_condition.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sqlite3
3 | import os
4 | import multiprocessing.pool
5 | import functools
6 | from tqdm import tqdm
7 | import pandas as pd
8 | from validator import FixAgent, _make_str_response, _execute_sql, is_execution_correct
9 | # NOTE: this condition-fix script reuses the JOIN few-shot prompt file below
10 | PROMPT = open('./few_shot_prompt_fix_join.txt').read().strip() + """
11 | =========
12 | {schema}
13 |
14 | Matched contents are written in this format table.column (some values can be found in that column)
15 | {matched_content}
16 |
17 | Question: {question}
18 |
19 | SQL query: {sql_query}
20 |
21 | Feedback:{feedback}
22 |
23 | FIXED SQL:"""
24 |
25 | data = []
26 | with open('../data/llm_alignment/validator_condition_bird_with_evidence_dev.jsonl') as fp:
27 | for line in fp:
28 | data.append(json.loads(line))
29 |
30 | output_file = './bird_fixed_sql_condition.jsonl'
31 |
32 | output_fp = open(output_file, 'w+')
33 |
34 | fix_agent = FixAgent(PROMPT, endpoint_type='vllm')
35 |
36 | for isample in tqdm(range(len(data)), total=len(data)):
37 | sample = data[isample]
38 |
39 | is_correct = sample['is_correct']
40 | if sample['validator_condition'] is None or "Conclude: correct" in sample['validator_condition']:
41 | output_fp.write(json.dumps(sample) + '\n')
42 | continue
43 |
44 | prompt = PROMPT.format(
45 | schema=sample['schema_sequence'],
46 | matched_content=sample['content_sequence'],
47 | question=sample['text'],
48 | sql_query=sample['predict_sql'],
49 | # execution_response=sample['pred_result'],
50 | feedback=sample['validator_condition']
51 | )
52 | # print(prompt)
53 | answer = fix_agent.get_answer([{"role": "user", "content": prompt}])
54 | execution_result = _execute_sql("../" + sample['db_path'], answer)
55 |
56 | print("-"*20)
57 | print(answer)
58 | # break
59 | sample['fixed_sql'] = answer
60 | sample['fixed_pred_result'] = _make_str_response(*execution_result)
61 |
62 | true_execution_result = _execute_sql("../" + sample['db_path'], sample['sql'])
63 | is_fixed_correct = is_execution_correct(true_execution_result[0], execution_result[0])
64 | sample['is_fixed_correct'] = is_fixed_correct
65 |
66 | output_fp.write(json.dumps(sample) + '\n')
67 |
68 | bird_results_dict = dict()
69 | for idx, sample in enumerate(data):
70 | if 'fixed_sql' in sample:
71 | predicted_sql = sample['fixed_sql']
72 | else:
73 | predicted_sql = sample['predict_sql']
74 | bird_results_dict[idx] = predicted_sql + "\t----- bird -----\t" + sample["db_id"]
75 | with open("predict_dev.json", "w", encoding = 'utf-8') as f:
76 | f.write(json.dumps(bird_results_dict, indent = 2, ensure_ascii = False))
77 | output_fp.close()
78 |
79 |
--------------------------------------------------------------------------------
/visualization/lambda_sensitivity.py:
--------------------------------------------------------------------------------
1 | # import matplotlib.pyplot as plt
2 |
3 | # # Data
4 | # gamma_values = [0, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 1]
5 | # bird_dev_values = [54.95, 55.54, 56, 55.54, 56.39, 56.19, 57.11, 57.24, 57.24, 56.52, 56.32]
6 | # count_values = [14, 14, 16, 15, 14, 12, 10, 5, 8, 6, 8]
7 |
8 | # # Adjust figure size for single-column fit in a 2-column research paper
9 | # fig_width = 3.5 # Typical width for a single column in inches
10 | # fig_height = 2.5 # Adjusted height for readability
11 |
12 | # # Create figure and axis objects with adjusted size
13 | # fig, ax1 = plt.subplots(figsize=(fig_width, fig_height))
14 |
15 | # # First Y-axis (EX%)
16 | # ax1.set_xlabel(r"$\lambda$", fontsize=10) # Using LaTeX formatting for lambda
17 | # ax1.set_ylabel("EX%", color="tab:blue", fontsize=10)
18 | # ax1.plot(gamma_values, bird_dev_values, marker="o", linestyle="-", color="tab:blue")
19 | # ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=8)
20 | # ax1.tick_params(axis="x", labelsize=8)
21 |
22 | # # Second Y-axis (No. Syntax Error)
23 | # ax2 = ax1.twinx()
24 | # ax2.set_ylabel("No. Syntax Error", color="tab:red", fontsize=10)
25 | # ax2.plot(gamma_values, count_values, marker="s", linestyle="--", color="tab:red")
26 | # ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=8)
27 |
28 | # # Grid
29 | # ax1.grid(True, linestyle="--", alpha=0.6)
30 |
31 | # # Adjust layout for better spacing
32 | # plt.tight_layout()
33 |
34 | # # Show plot
35 | # plt.show()
36 |
37 |
38 | import matplotlib.pyplot as plt
39 |
40 | # Data
41 | lambda_values = [0, 0.05, 0.1, 0.25, 0.5, 0.75, 1]  # λ values for the sensitivity sweep (x-axis)
42 | bird_dev_values = [54.95, 56.19, 57.11, 57.24, 57.24, 56.52, 56.32]
43 | count_values = [14, 12, 10, 5, 8, 6, 8]
44 |
45 | # Adjust figure size for single-column fit in a 2-column research paper
46 | fig_width = 3.5 # Typical width for a single column in inches
47 | fig_height = 2.5 # Adjusted height for readability
48 |
49 | # Create figure and axis objects with adjusted size
50 | fig, ax1 = plt.subplots(figsize=(fig_width, fig_height))
51 |
52 | # First Y-axis (EX%)
53 | ax1.set_xlabel(r"$\lambda$", fontsize=10) # Using LaTeX formatting for lambda
54 | ax1.set_ylabel("EX%", color="tab:blue", fontsize=10)
55 | ax1.plot(lambda_values, bird_dev_values, marker="o", linestyle="-", color="tab:blue")
56 | ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=8)
57 | ax1.tick_params(axis="x", labelsize=8)
58 |
59 | # Second Y-axis (No. Syntax Error)
60 | ax2 = ax1.twinx()
61 | ax2.set_ylabel("No. Syntax Error", color="tab:red", fontsize=10)
62 | ax2.plot(lambda_values, count_values, marker="s", linestyle="--", color="tab:red")
63 | ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=8)
64 |
65 | # Grid
66 | ax1.grid(True, linestyle="--", alpha=0.6)
67 |
68 | # Adjust layout for better spacing
69 | plt.tight_layout()
70 |
71 | # Show plot
72 | plt.show()
73 |
--------------------------------------------------------------------------------
/validator_data/generate_validator_join_using_fewshot.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from validator_data.validator import ValidatorJOIN, _execute_sql, _make_str_response, is_execution_correct, ValidatorJOINWithTrueSQL
4 | import argparse
5 | import os
6 | import re
7 | from multiprocessing import Pool
8 |
9 | def process_sample(sample):
10 | """Process a single sample."""
11 |
12 | validator = Validator(endpoint_type=args.endpoint_type)  # Validator and args are module-level names set in __main__; workers inherit them via fork
13 |
14 | try:
15 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql'])
16 |
17 | sample['predict_sql'] = sample['predict_sqls'][0]
18 | prompt, answer, execution_result = validator.validate(sample)
19 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0])
20 |
21 | sample['prompt_validator_join'] = prompt
22 | sample['is_correct'] = is_correct
23 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer
24 | sample['validator_join'] = answer
25 | sample['true_result'] = _make_str_response(*true_execution_result)
26 | sample['pred_result'] = _make_str_response(*execution_result)
27 |
28 | return sample
29 | except Exception as e:
30 | print(f"Error processing sample: {e}")
31 | return None
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json')
36 | parser.add_argument('--output_file', type=str, default='bird_validator_join.jsonl')
37 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai'])
38 | parser.add_argument('--use_hidden_sql', action='store_true')
39 | args = parser.parse_args()
40 |
41 | if args.use_hidden_sql:
42 | Validator = ValidatorJOINWithTrueSQL
43 | else:
44 | Validator = ValidatorJOIN
45 |
46 | data = []
47 | with open(args.input_file) as fp:
48 | for line in fp:
49 | data.append(json.loads(line))
50 |
51 | # Load saved output file if exists
52 | processed_keys = set()
53 | if os.path.exists(args.output_file):
54 | with open(args.output_file, 'r') as f:
55 | for line in f:
56 | sample = json.loads(line)
57 | processed_keys.add((sample['db_id'], sample['question']))
58 |
59 | # Determine samples to process
60 | samples_to_process = [sample for sample in data if (sample['db_id'], sample['question']) not in processed_keys]
61 |
62 | # Open output file in append mode
63 | with open(args.output_file, 'a') as output_file:
64 | with Pool(8) as pool:
65 | for result in tqdm(pool.imap(process_sample, samples_to_process), total=len(samples_to_process)):
66 | if result is not None:
67 | output_file.write(json.dumps(result, ensure_ascii=False) + '\n')
68 |
--------------------------------------------------------------------------------
/validator_data/generate_fixed_sql_using_fewshot_join.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | import pandas as pd
4 | import argparse
5 | from validator import FixAgent, _execute_sql, _make_str_response, is_execution_correct
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--input_file', type=str, default='../data/llm_alignment/validator_join_bird_with_evidence_train.jsonl')
9 | parser.add_argument('--output_file', type=str, default='../data/llm_alignment/fixed_join_bird_with_evidence_train.jsonl')
10 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp'])
11 | args = parser.parse_args()
12 |
13 | PROMPT = open('./few_shot_prompt_fix_join.txt').read().strip() + """
14 | =========
15 | {schema}
16 |
17 | Matched contents are written in this format table.column (some values can be found in that column)
18 | {matched_content}
19 |
20 | Question: {question}
21 |
22 | SQL query: {sql_query}
23 |
24 | Feedback:{feedback}
25 |
26 | FIXED SQL:"""
27 |
28 | # load data from jsonl
29 | data = []
30 | with open(args.input_file, 'r') as f:
31 | for line in f:
32 | data.append(json.loads(line))
33 |
34 | fix_agent = FixAgent(prompt_template=PROMPT, endpoint_type=args.endpoint_type)
35 |
36 | output_file = open(args.output_file, 'a+')  # note: unlike the validator scripts there is no resume logic, so rerunning appends duplicate lines
37 |
38 | for isample in tqdm(range(0, len(data)), total=len(data)):
39 | sample = data[isample]
40 |
41 | sql = sample['predict_sql']
42 | is_correct = sample['is_correct']
43 | if sample['validator_join'] is None or "Conclude: correct" in sample['validator_join']:
44 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
45 | continue
46 |
47 | prompt = PROMPT.format(
48 | schema=sample['schema_sequence'],
49 | matched_content=sample['content_sequence'],
50 | question=sample['text'],
51 | sql_query=sql,
52 | feedback=sample['validator_join']
53 | )
54 | # print(prompt)
55 | answer = fix_agent.get_answer([{"role": "user", "content": prompt}])
56 |
57 | execution_result = _execute_sql("../" + sample['db_path'], answer)
58 |
59 | print("-"*20)
60 | print(answer)
61 | # break
62 | sample['fixed_sql'] = answer
63 | sample['fixed_pred_result'] = _make_str_response(*execution_result)
64 |
65 | true_execution_result = _execute_sql("../" + sample['db_path'], sample['sql'])
66 | is_fixed_correct = is_execution_correct(true_execution_result[0], execution_result[0])
67 | sample['is_fixed_correct'] = is_fixed_correct
68 |
69 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
70 |
71 | bird_results_dict = dict()
72 | for idx, sample in enumerate(data):
73 | if 'fixed_sql' in sample:
74 | predicted_sql = sample['fixed_sql']
75 | else:
76 | predicted_sql = sample['predict_sql']
77 | bird_results_dict[idx] = predicted_sql + "\t----- bird -----\t" + sample["db_id"]
78 | with open("predict_dev.json", "w", encoding = 'utf-8') as f:
79 | f.write(json.dumps(bird_results_dict, indent = 2, ensure_ascii = False))
80 |
--------------------------------------------------------------------------------
/data_processing/generate_planner_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from tqdm import tqdm
4 | from planner import PlannerCombine, PlannerCombineWithTrueSQL
5 | import argparse
6 | from multiprocessing import Pool
7 |
8 | # argument parser for the input data file (train, dev) and the output file
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--input_file', type=str, default='../data/sft_bird_with_evidence_train_text2sql.json')
11 | parser.add_argument('--output_file', type=str, default='../data/planner/planner_select_bird_with_evidence_train.jsonl')
12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai'])
13 | parser.add_argument('--mode', type=str, choices=['select', 'condition', 'combine', 'combine_with_true_sql'], default='combine')
14 | parser.add_argument('--prompt', choices=['few-shot', 'cot'], default='few-shot')
15 | args = parser.parse_args()
16 |
17 | if args.input_file.endswith('.json'):
18 | data = json.load(open(args.input_file))
19 | elif args.input_file.endswith('.jsonl'):
20 | data = []
21 | with open(args.input_file, 'r') as f:
22 | for line in f:
23 | data.append(json.loads(line))
24 |
25 | # load saved output file if exists
26 | if os.path.exists(args.output_file):
27 | old_output = []
28 | with open(args.output_file, 'r') as f:
29 | for line in f:
30 | old_output.append(json.loads(line))
31 | data[:len(old_output)] = old_output
32 | else:
33 | old_output = []
34 |
35 | # make the output directory if it does not exist; the jsonl file itself
36 | # is opened for appending inside main() below
37 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
38 |
39 |
40 | if args.mode == 'combine':
41 | planner = PlannerCombine(endpoint_type=args.endpoint_type)
42 | elif args.mode == 'combine_with_true_sql':
43 | planner = PlannerCombineWithTrueSQL(endpoint_type=args.endpoint_type)
44 | else: raise NotImplementedError(f"--mode {args.mode} has no planner implementation in this script")
45 | if args.prompt == 'cot':
46 | planner.prompt_template = """{schema}
47 |
48 | Question: {question}
49 | External knowledge: {evidence}
50 |
51 | Use this hidden True SQL query to write correct analysis that derives to the correct answer. The True SQL query cannot be used in the analysis.
52 | Hidden True SQL query: {true_sql_query}
53 |
54 | Write your thought in short then write the final SQL query, answer in this format:
55 | [your short thought step-by-step]
56 | Final SQL query:
57 | ```
58 | [SQL query]
59 | ```
60 | """
61 |
62 | def process_sample(sample):
63 | answer = planner.generate(sample)
64 | sample[f'planner_{args.mode}'] = answer
65 | return sample
66 |
67 | def main():
68 | chunk_size = 4
69 | with open(args.output_file, 'a') as output_file:
70 | for i in tqdm(range(len(old_output), len(data), chunk_size), total=(len(data) - len(old_output))//chunk_size):
71 | chunk = data[i:i+chunk_size]
72 | pool = Pool(chunk_size)
73 | processed_samples = pool.map(process_sample, chunk)
74 | pool.close()
75 |
76 | if len(processed_samples) > 0:
77 | print(processed_samples[0][f'planner_{args.mode}'])
78 |
79 | for sample in processed_samples:
80 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
81 |
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
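
main() above builds and tears down a fresh 4-process Pool for every chunk. A sketch of the same resume-and-append loop with a single long-lived pool, assuming the module-level data, old_output, args, and process_sample defined in the file (imap preserves input order, so the output file stays aligned with data):

    import json
    from multiprocessing import Pool
    from tqdm import tqdm

    def main_single_pool():
        remaining = data[len(old_output):]
        with open(args.output_file, 'a') as out, Pool(4) as pool:
            for sample in tqdm(pool.imap(process_sample, remaining), total=len(remaining)):
                out.write(json.dumps(sample, ensure_ascii=False) + '\n')
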
/utils/load_pt_dataset.py:
--------------------------------------------------------------------------------
1 | # loads the tokenized SQL corpus (a flat binary file) for pre-training CodeGen. The following was helpful:
2 | # https://github.com/karpathy/nanoGPT/blob/master/data/openwebtext/prepare.py
3 |
4 | import numpy as np
5 | import torch
6 | import random
7 | from torch.utils.data import IterableDataset, Dataset, DataLoader
8 |
9 | # class PretrainDataset(IterableDataset):
10 | # def __init__(self, pt_data_dir, block_size, epochs):
11 | # super().__init__()
12 | # self.corpus = np.memmap(pt_data_dir, dtype = np.uint16, mode = 'r')
13 | # self.block_size = block_size
14 | # self.epochs = epochs
15 | # self.length = len(self.corpus) // self.block_size
16 |
17 | # # return a tokenized sequence
18 | # def __iter__(self):
19 | # for _ in range(self.epochs):
20 | # start_idx_list = list(range(0, len(self.corpus), self.block_size))
21 | # # for each epoch, shuffle the order of sequences
22 | # random.shuffle(start_idx_list)
23 |
24 | # for start_idx in start_idx_list:
25 | # input_ids = self.corpus[start_idx: start_idx + self.block_size]
26 | # # skip the sequence whose length is not equal to `block_size`
27 | # if len(input_ids) != self.block_size:
28 | # continue
29 |
30 | # input_ids = torch.from_numpy(input_ids.astype(np.int64))
31 | # attention_mask = torch.ones(len(input_ids))
32 |
33 | # yield {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
34 |
35 | # # # return a tokenized sequence
36 | # # def __iter__(self):
37 | # # for _ in range(self.dataset_length):
38 | # # # randomly select a sequence of token ids from the tokenized corpus
39 | # # idx = random.randint(0, len(self.corpus) - self.block_size)
40 | # # input_ids = self.corpus[idx: idx + self.block_size]
41 | # # input_ids = torch.from_numpy(input_ids.astype(np.int64))
42 | # # attention_mask = torch.ones(len(input_ids))
43 |
44 | # # yield {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
45 |
46 | # def __len__(self):
47 | # return self.length
48 |
49 | class PretrainDataset(Dataset):
50 | def __init__(self, pt_data_dir, block_size):
51 | super().__init__()
52 | self.corpus = np.memmap(pt_data_dir, dtype = np.uint16, mode = 'r')
53 | self.block_size = block_size
54 | self.length = len(self.corpus) // self.block_size
55 |
56 | # return a list of token ids in the corpus
57 | def __getitem__(self, index):
58 | input_ids = self.corpus[index * self.block_size : (index + 1) * self.block_size]
59 |
60 | input_ids = torch.from_numpy(input_ids.astype(np.int64))
61 | attention_mask = torch.ones(len(input_ids))
62 |
63 | return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
64 |
65 | def __len__(self):
66 | return self.length
67 |
68 | if __name__ == "__main__":
69 | dataset = PretrainDataset("./data/pt_corpus/starcoder_corpus.bin", 6144)
70 | dataloader = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=True)
71 | for batch in dataloader:
72 |     print(batch["input_ids"].shape)  # smoke test: expect torch.Size([4, 6144])
73 |     break  # one batch is enough to verify shapes
--------------------------------------------------------------------------------
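
PretrainDataset above memory-maps a flat uint16 token file. A sketch of producing such a file in the nanoGPT style referenced in the header comment; the token ids are hypothetical:

    import numpy as np

    token_ids = [1, 5, 42, 7]  # hypothetical tokenizer output for the SQL corpus
    arr = np.array(token_ids, dtype=np.uint16)  # uint16 assumes vocab ids < 65536
    arr.tofile("./data/pt_corpus/starcoder_corpus.bin")
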
/alignment-handbook/tests/test_model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import unittest
16 |
17 | import torch
18 |
19 | from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
20 | from alignment.data import DEFAULT_CHAT_TEMPLATE
21 |
22 |
23 | class GetQuantizationConfigTest(unittest.TestCase):
24 | def test_4bit(self):
25 | model_args = ModelArguments(load_in_4bit=True)
26 | quantization_config = get_quantization_config(model_args)
27 | self.assertTrue(quantization_config.load_in_4bit)
28 | self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
29 | self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
30 | self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
31 |
32 | def test_8bit(self):
33 | model_args = ModelArguments(load_in_8bit=True)
34 | quantization_config = get_quantization_config(model_args)
35 | self.assertTrue(quantization_config.load_in_8bit)
36 |
37 | def test_no_quantization(self):
38 | model_args = ModelArguments()
39 | quantization_config = get_quantization_config(model_args)
40 | self.assertIsNone(quantization_config)
41 |
42 |
43 | class GetTokenizerTest(unittest.TestCase):
44 | def setUp(self) -> None:
45 | self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
46 |
47 | def test_right_truncation_side(self):
48 | tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right"))
49 | self.assertEqual(tokenizer.truncation_side, "right")
50 |
51 | def test_left_truncation_side(self):
52 | tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left"))
53 | self.assertEqual(tokenizer.truncation_side, "left")
54 |
55 | def test_default_chat_template(self):
56 | tokenizer = get_tokenizer(self.model_args, DataArguments())
57 | self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
58 |
59 | def test_chatml_chat_template(self):
60 | chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
61 | tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))
62 | self.assertEqual(tokenizer.chat_template, chat_template)
63 |
64 |
65 | class GetPeftConfigTest(unittest.TestCase):
66 | def test_peft_config(self):
67 | model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99)
68 | peft_config = get_peft_config(model_args)
69 | self.assertEqual(peft_config.r, 42)
70 | self.assertEqual(peft_config.lora_alpha, 0.66)
71 | self.assertEqual(peft_config.lora_dropout, 0.99)
72 |
73 | def test_no_peft_config(self):
74 | model_args = ModelArguments(use_peft=False)
75 | peft_config = get_peft_config(model_args)
76 | self.assertIsNone(peft_config)
77 |
--------------------------------------------------------------------------------
/data_processing/merge_val_fix_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 10,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Updated fixed SQL data saved to ../data/multi-agents/fixed/gpt-4o-mini-validator-fixer-bird_with_evidence_train.jsonl\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "import json\n",
18 | "from copy import deepcopy\n",
19 | "\n",
20 | "# File paths\n",
21 | "fixed_sql_bird_file = '../data/multi-agents/fixed/gpt-4o-mini-fixed-bird_with_evidence_train.jsonl'\n",
22 | "validator_select_file = '../data/multi-agents/validator/gpt-4o-mini-validator_select_bird_with_evidence_train.jsonl'\n",
23 | "validator_condition_file = '../data/multi-agents/validator/gpt-4o-mini-validator_condition_bird_with_evidence_train.jsonl'\n",
24 | "validator_join_file = '../data/multi-agents/validator/gpt-4o-mini-validator_join_bird_with_evidence_train.jsonl'\n",
25 | "\n",
26 | "# Function to load JSONL files\n",
27 | "def load_jsonl(file_path):\n",
28 | " data = []\n",
29 | " with open(file_path, 'r', encoding='utf-8') as file:\n",
30 | " for line in file:\n",
31 | " data.append(json.loads(line))\n",
32 | " return data\n",
33 | "\n",
34 | "# Load all datasets\n",
35 | "fixed_sql_bird_data = load_jsonl(fixed_sql_bird_file)\n",
36 | "validator_select_data = load_jsonl(validator_select_file)\n",
37 | "validator_condition_data = load_jsonl(validator_condition_file)\n",
38 | "validator_join_data = load_jsonl(validator_join_file)\n",
39 | "\n",
40 | "# Process and add valid samples\n",
41 | "for sample_select, sample_condition, sample_join in zip(validator_select_data, validator_condition_data, validator_join_data):\n",
42 | "\n",
43 | " # Extract correctness feedback\n",
44 | " select_correct = sample_select.get('feedback_conclude')\n",
45 | " condition_correct = sample_condition.get('feedback_conclude')\n",
46 | " join_correct = sample_join.get('feedback_conclude')\n",
47 | "\n",
48 | " # If all are correct, add a new sample to fixed_sql_bird_data\n",
49 | " if select_correct and condition_correct and join_correct:\n",
50 | " new_sample = deepcopy(sample_select)\n",
51 | " new_sample = {\n",
52 | " \"validator_select\": sample_select,\n",
53 | " \"validator_condition\": sample_condition['validator_condition'],\n",
54 | " \"validator_join\": sample_join['validator_join'],\n",
55 | " \"fixed_sql\": [\"None\"] # Empty list as per instructions\n",
56 | " }\n",
57 | " fixed_sql_bird_data.append(new_sample)\n",
58 | "\n",
59 | "# Save the updated fixed SQL data\n",
60 | "output_file = '../data/multi-agents/fixed/gpt-4o-mini-validator-fixer-bird_with_evidence_train.jsonl'\n",
61 | "with open(output_file, 'w', encoding='utf-8') as file:\n",
62 | " for entry in fixed_sql_bird_data:\n",
63 | " file.write(json.dumps(entry, ensure_ascii=False) + '\\n')\n",
64 | "\n",
65 | "print(f\"Updated fixed SQL data saved to {output_file}\")"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": []
74 | }
75 | ],
76 | "metadata": {
77 | "kernelspec": {
78 | "display_name": "handbook",
79 | "language": "python",
80 | "name": "python3"
81 | },
82 | "language_info": {
83 | "codemirror_mode": {
84 | "name": "ipython",
85 | "version": 3
86 | },
87 | "file_extension": ".py",
88 | "mimetype": "text/x-python",
89 | "name": "python",
90 | "nbconvert_exporter": "python",
91 | "pygments_lexer": "ipython3",
92 | "version": "3.10.14"
93 | }
94 | },
95 | "nbformat": 4,
96 | "nbformat_minor": 2
97 | }
98 |
--------------------------------------------------------------------------------
/validator_data/generate_validator_condition_using_fewshot.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from validator import ValidatorCondition, _execute_sql, _make_str_response, is_execution_correct, ValidatorConditionWithTrueSQL
4 | import argparse
5 | import re
6 | import os
7 | from multiprocessing import Pool, Manager
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--input_file', type=str, default='../temp/codes/eval_codes-1b.json')
11 | parser.add_argument('--output_file', type=str, default='bird_validator_condition.jsonl')
12 | parser.add_argument('--endpoint_type', type=str, default='llamacpp', choices=['vllm', 'llamacpp', 'openai'])
13 | parser.add_argument('--use_hidden_sql', action='store_true')
14 | args = parser.parse_args()
15 |
16 | def process_sample(task):  # named `task`, not `args`, so the global argparse namespace is not shadowed
17 |     idx, sample, endpoint_type, output_file_lock, output_file_path = task
18 |
19 | try:
20 | validator = Validator(endpoint_type=endpoint_type)
21 |
22 | true_execution_result = _execute_sql("./" + sample['db_path'], sample['sql'])
23 |
24 | pred_sql_match = re.search(r"(?<=Final SQL query:).*", sample['planners'][0], re.DOTALL)
25 | if pred_sql_match is None:
26 | return None
27 |
28 | pred_sql = pred_sql_match.group().replace("sql", "").replace("```", "").strip()
29 | sample['predict_sql'] = pred_sql
30 |
31 | prompt, answer, execution_result = validator.validate(sample)
32 | answer = answer[0]
33 | is_correct = is_execution_correct(true_execution_result[0], execution_result[0])
34 |
35 | sample['is_correct'] = is_correct
36 | sample['feedback_conclude'] = answer is not None and 'Conclude: correct' in answer
37 | sample['validator_condition'] = answer
38 | sample['true_result'] = _make_str_response(*true_execution_result)
39 | sample['pred_result'] = _make_str_response(*execution_result)
40 |
41 | # Write the result to the file
42 | with output_file_lock:
43 | with open(output_file_path, 'a') as output_file:
44 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
45 |
46 | return idx
47 |
48 | except Exception as e:
49 | print(f"Error processing sample {idx}: {e}")
50 | return None
51 |
52 | if __name__ == "__main__":
53 | if args.use_hidden_sql:
54 | Validator = ValidatorConditionWithTrueSQL
55 | else:
56 | Validator = ValidatorCondition
57 |
58 | # Load data
59 | data = []
60 | with open(args.input_file) as fp:
61 | for line in fp:
62 | data.append(json.loads(line))
63 |
64 | # Load old results
65 | if os.path.exists(args.output_file):
66 | processed_indices = set()
67 | with open(args.output_file, 'r') as f:
68 | for line in f:
69 | result = json.loads(line)
70 | processed_indices.add(f"{result['db_id']} {result['question']}")
71 | print(f"Loaded {len(processed_indices)} previously processed samples.")
72 | else:
73 | processed_indices = set()
74 |
75 | # Filter data to process only unprocessed samples
76 | unprocessed_data = [
77 | (i, sample, args.endpoint_type, None, args.output_file)
78 | for i, sample in enumerate(data)
79 | if f"{sample['db_id']} {sample['question']}" not in processed_indices
80 | ]
81 |
82 | # Set up multiprocessing
83 | with Manager() as manager:
84 | output_file_lock = manager.Lock()
85 |
86 | # Add output_file_lock to each task
87 | tasks = [
88 | (idx, sample, args.endpoint_type, output_file_lock, args.output_file)
89 | for idx, sample, _, _, _ in unprocessed_data
90 | ]
91 |
92 | with Pool(processes=12) as pool:
93 | # Use tqdm to monitor progress
94 | for _ in tqdm(pool.imap_unordered(process_sample, tasks), total=len(tasks)):
95 | pass
96 |
97 | print("Processing completed.")
98 |
--------------------------------------------------------------------------------
/alignment-handbook/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # Temp folders
163 | data/
164 | wandb/
165 | output/
166 |
--------------------------------------------------------------------------------
/alignment-handbook/src/alignment/release.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import re
18 |
19 | import packaging.version
20 |
21 |
22 | REPLACE_PATTERNS = {
23 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
24 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
25 | }
26 | REPLACE_FILES = {
27 | "init": "src/alignment/__init__.py",
28 | "setup": "setup.py",
29 | }
30 | README_FILE = "README.md"
31 |
32 |
33 | def update_version_in_file(fname, version, pattern):
34 | """Update the version in one file using a specific pattern."""
35 | with open(fname, "r", encoding="utf-8", newline="\n") as f:
36 | code = f.read()
37 | re_pattern, replace = REPLACE_PATTERNS[pattern]
38 | replace = replace.replace("VERSION", version)
39 | code = re_pattern.sub(replace, code)
40 | with open(fname, "w", encoding="utf-8", newline="\n") as f:
41 | f.write(code)
42 |
43 |
44 | def global_version_update(version, patch=False):
45 | """Update the version in all needed files."""
46 | for pattern, fname in REPLACE_FILES.items():
47 | update_version_in_file(fname, version, pattern)
48 |
49 |
50 | def get_version():
51 | """Reads the current version in the __init__."""
52 | with open(REPLACE_FILES["init"], "r") as f:
53 | code = f.read()
54 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
55 | return packaging.version.parse(default_version)
56 |
57 |
58 | def pre_release_work(patch=False):
59 | """Do all the necessary pre-release steps."""
60 | # First let's get the default version: base version if we are in dev, bump minor otherwise.
61 | default_version = get_version()
62 | if patch and default_version.is_devrelease:
63 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
64 | if default_version.is_devrelease:
65 | default_version = default_version.base_version
66 | elif patch:
67 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
68 | else:
69 | default_version = f"{default_version.major}.{default_version.minor + 1}.0"
70 |
71 | # Now let's ask nicely if that's the right one.
72 | version = input(f"Which version are you releasing? [{default_version}]")
73 | if len(version) == 0:
74 | version = default_version
75 |
76 | print(f"Updating version to {version}.")
77 | global_version_update(version, patch=patch)
78 |
79 |
80 | def post_release_work():
81 | """Do all the necessary post-release steps."""
82 | # First let's get the current version
83 | current_version = get_version()
84 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
85 | current_version = current_version.base_version
86 |
87 | # Check with the user we got that right.
88 | version = input(f"Which version are we developing now? [{dev_version}]")
89 | if len(version) == 0:
90 | version = dev_version
91 |
92 | print(f"Updating version to {version}.")
93 | global_version_update(version)
94 |
95 |
96 | if __name__ == "__main__":
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
99 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
100 | args = parser.parse_args()
101 | if not args.post_release:
102 | pre_release_work(patch=args.patch)
103 | elif args.patch:
104 | print("Nothing to do after a patch :-)")
105 | else:
106 | post_release_work()
107 |
--------------------------------------------------------------------------------
/alignment-handbook/src/alignment/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | from typing import Dict
17 |
18 | import torch
19 | from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
20 |
21 | from accelerate import Accelerator
22 | from huggingface_hub import list_repo_files
23 | from huggingface_hub.utils._validators import HFValidationError
24 | from peft import LoraConfig, PeftConfig
25 |
26 | from .configs import DataArguments, ModelArguments
27 | from .data import DEFAULT_CHAT_TEMPLATE
28 |
29 |
30 | def get_current_device() -> int | str:
31 |     """Get the current device: the local process index on GPU (to enable multi-GPU training), or "cpu" otherwise."""
32 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
33 |
34 |
35 | def get_kbit_device_map() -> Dict[str, int] | None:
36 | """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
37 | return {"": get_current_device()} if torch.cuda.is_available() else None
38 |
39 |
40 | def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
41 | if model_args.load_in_4bit:
42 | quantization_config = BitsAndBytesConfig(
43 | load_in_4bit=True,
44 | bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
45 | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
46 | bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
47 | )
48 | elif model_args.load_in_8bit:
49 | quantization_config = BitsAndBytesConfig(
50 | load_in_8bit=True,
51 | )
52 | else:
53 | quantization_config = None
54 |
55 | return quantization_config
56 |
57 |
58 | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
59 | """Get the tokenizer for the model."""
60 | tokenizer = AutoTokenizer.from_pretrained(
61 | model_args.model_name_or_path,
62 | revision=model_args.model_revision,
63 | trust_remote_code=model_args.trust_remote_code
64 | )
65 | if tokenizer.pad_token_id is None:
66 | tokenizer.pad_token_id = tokenizer.eos_token_id
67 |
68 | if data_args.truncation_side is not None:
69 | tokenizer.truncation_side = data_args.truncation_side
70 |
71 | # Set reasonable default for models without max length
72 | # if tokenizer.model_max_length > 100_000:
73 | # tokenizer.model_max_length = 4096
74 |
75 | if data_args.chat_template is not None:
76 | tokenizer.chat_template = data_args.chat_template
77 | elif tokenizer.chat_template is None:
78 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
79 |
80 | return tokenizer
81 |
82 |
83 | def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
84 | if model_args.use_peft is False:
85 | return None
86 |
87 | peft_config = LoraConfig(
88 | r=model_args.lora_r,
89 | lora_alpha=model_args.lora_alpha,
90 | lora_dropout=model_args.lora_dropout,
91 | bias="none",
92 | task_type="CAUSAL_LM",
93 | target_modules=model_args.lora_target_modules,
94 | modules_to_save=model_args.lora_modules_to_save,
95 | )
96 |
97 | return peft_config
98 |
99 |
100 | def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
101 | try:
102 | # Try first if model on a Hub repo
103 | repo_files = list_repo_files(model_name_or_path, revision=revision)
104 | except HFValidationError:
105 | # If not, check local repo
106 | repo_files = os.listdir(model_name_or_path)
107 | return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
108 |
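109 | # Illustrative wiring of the helpers above (a sketch; `model_args` / `data_args` come from the
110 | # handbook's parsed arguments, and AutoModelForCausalLM would need to be imported):
111 | #   tokenizer = get_tokenizer(model_args, data_args)
112 | #   model = AutoModelForCausalLM.from_pretrained(
113 | #       model_args.model_name_or_path,
114 | #       quantization_config=get_quantization_config(model_args),
115 | #       device_map=get_kbit_device_map(),
116 | #   )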
--------------------------------------------------------------------------------
/validator_data/utils.py:
--------------------------------------------------------------------------------
1 | from sql_metadata import Parser
2 | import re
3 | import os
4 | import sqlparse
5 |
6 | def remove_table_alias(s):
7 | try:
8 | tables_aliases = Parser(s).tables_aliases
9 |     except Exception:
10 | return s
11 |
12 | new_tables_aliases = {}
13 | for i in range(1,11):
14 | if "t{}".format(i) in tables_aliases.keys():
15 | new_tables_aliases["t{}".format(i)] = tables_aliases["t{}".format(i)]
16 |
17 | tables_aliases = new_tables_aliases
18 | for k, v in tables_aliases.items():
19 | # remove AS clauses
20 | s = s.replace("AS " + k + " ", "")
21 |         # replace each table alias with its original name
22 | s = s.replace(k, v)
23 |
24 | return s
25 |
26 | def extract_select_clause(sql_query):
27 | # Define a regex pattern to match the SELECT clause up to the FROM keyword
28 | pattern = re.compile(r"SELECT\s.*?\s(?=FROM)", re.IGNORECASE | re.DOTALL)
29 |
30 | # Search for the pattern in the SQL query
31 | match = pattern.search(sql_query)
32 |
33 | if match:
34 | # Return the matched portion (SELECT clause)
35 | return match.group(0).strip()
36 | else:
37 | # Return None if no match is found
38 | return None
39 |
40 | def get_table_columns_list(schema):  # build every matchable column string: "table.column", "table.`column`", and bare "column"
41 | columns = []
42 | for table_data in schema['schema_items']:
43 | table_name = table_data['table_name'].lower()
44 | column_names = table_data['column_names']
45 | column_names = [x.lower() for x in column_names]
46 |
47 | for column in column_names:
48 | columns.append(f"{table_name}.{column}")
49 | columns.append(f"{table_name}.`{column}`")
50 | columns.append(column)
51 |
52 | columns = list(set(columns))
53 | return columns
54 |
55 | def get_columns_in_select_clause(sql_query, schema):
56 | column_list = get_table_columns_list(schema)
57 | select_clause = extract_select_clause(sql_query)
58 |
59 | if select_clause is None:
60 | return []
61 |
62 | select_clause = remove_table_alias(sqlparse.format(select_clause, keyword_case = "upper", identifier_case = "lower"))
63 | try:
64 | sql_tokens = [token.value for token in Parser(select_clause.lower()).tokens]
65 | except Exception as e:
66 | print(e)
67 | sql_tokens = sql_query.lower().split()
68 |
69 | select_columns = []
70 | for token in sql_tokens:
71 | if token in column_list:
72 | select_columns.append(token)
73 | return select_columns
74 |
75 |
76 | def get_equation_function_in_select_clause(sql_query):
77 | """
78 |     Equation functions include min, max, avg, sum, count, divide, +, /, and case when.
79 | """
80 | select_clause = extract_select_clause(sql_query)
81 | if select_clause is None:
82 | return []
83 | norm_select_clause = remove_table_alias(sqlparse.format(select_clause, keyword_case = "upper", identifier_case = "lower"))
84 |
85 | try:
86 | sql_tokens = [token.value for token in Parser(norm_select_clause.lower()).tokens]
87 | except Exception as e:
88 | sql_tokens = norm_select_clause.lower().split()
89 |
90 | equation_functions = []
91 | for token in sql_tokens:
92 | if token in ["min", "max", "avg", "sum", "count", "divide", "+", "/", "case", "when"]:
93 | equation_functions.append(token)
94 |
95 | return equation_functions
96 |
97 | def remove_tables_from_columns(columns):
98 | new_columns = []
99 | for col in columns:
100 | new_columns.append(col.split('.')[-1])
101 | return new_columns
102 |
103 | def check_columns_match(true_columns, pred_columns):
104 | true_columns = remove_tables_from_columns(true_columns)
105 | pred_columns = remove_tables_from_columns(pred_columns)
106 |
107 | # classify error types, unnecessary columns, missing columns, wrong order, return a string of error type
108 | if true_columns == pred_columns:
109 | return 'correct'
110 | else:
111 | if set(true_columns) == set(pred_columns):
112 | return 'incorrect: wrong order'
113 | elif set(true_columns) - set(pred_columns):
114 | return 'incorrect: missing columns ' + str(set(true_columns) - set(pred_columns))
115 | elif set(pred_columns) - set(true_columns):
116 | return 'incorrect: unnecessary columns ' + str(set(pred_columns) - set(true_columns))
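117 |
118 | # Illustrative usage (hypothetical schema dict, shaped as get_table_columns_list expects):
119 | #   schema = {"schema_items": [{"table_name": "singer", "column_names": ["Name", "Age"]}]}
120 | #   get_columns_in_select_clause("SELECT name FROM singer", schema)  # -> ["name"]
121 | #   check_columns_match(["name", "age"], ["name"])  # -> "incorrect: missing columns {'age'}"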
--------------------------------------------------------------------------------
/data_processing/generate_validator_fixer_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | from datasets import Dataset, DatasetDict
5 | from planner import get_answer_llamacpp, get_answer_vllm, get_answer_openai
6 | from openai import OpenAI
7 | from dotenv import load_dotenv
8 | from planner import _make_str_response, _execute_sql, is_execution_correct
9 | import re
10 | from utils import norm_sql_query
11 | from tqdm import tqdm
12 | from multiprocessing import Pool
13 |
14 | # Set up argument parser
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--input_file', type=str, default='./data/multi-agents/fixed/gpt-4o-mini-validator-fixer-bird_with_evidence_train.jsonl')
17 | parser.add_argument('--output_dir', type=str, default='./data/multi-agents/fixed/sft-gpt-4o-mini-validator-fixer-bird_with_evidence_train')
18 | parser.add_argument('--num_processes', type=int, default=16)
19 | args = parser.parse_args()
20 |
21 | # Define the prompt template
22 | PROMPT = """{schema}
23 |
24 | Question: {question}
25 | External knowledge: {evidence}
26 |
27 | Generated SQL query: {sql_query}
28 |
29 | Execution response:
30 | {execution_response}
31 |
32 | Feedback for the SQL query:
33 | """
34 |
35 | COMPLETION = """
38 |
39 |
40 | {feedback_condition}
41 |
42 |
43 | FIXED SQL: {fixed_sql}"""
44 |
45 | def norm_feedback(feedback, token):
46 | feedback = token + feedback.split(token)[-1]
47 | return feedback
48 |
49 | def extract_sql_in_code_block(pred_sql_text):
50 | sql_block_match = re.search(r"```(.+?)```", pred_sql_text, re.DOTALL)
51 | if sql_block_match:
52 | sql_query = sql_block_match.group(1).strip()
53 |         if sql_query.startswith("sql"):
54 |             sql_query = sql_query[3:].strip()  # drop only the leading "sql" language tag, not every occurrence
55 | return sql_query
56 | else:
57 | return pred_sql_text
58 |
59 | def process_sample(index_sample):
60 | index, sample = index_sample
61 | feedback_select = sample['validator_select'] or 'SELECT.\nNone'
62 | feedback_condition = sample['validator_condition'] or "CONDITION.\nNone"
63 | feedback_join = sample['validator_join'] or "JOIN.\nNone"
64 | feedback_join = "JOIN." + feedback_join.split("JOIN.")[-1]
65 |
66 | feedback_select = norm_feedback(feedback_select, "SELECT.")
67 | feedback_condition = norm_feedback(feedback_condition, "CONDITION.")
68 | feedback_join = norm_feedback(feedback_join, "JOIN.")
69 |
70 | prompt = PROMPT.format(
71 | schema=sample['schema_sequence'],
72 | question=sample['question'],
73 | evidence=sample['evidence'],
74 | sql_query=sample['predict_sql'],
75 | execution_response=sample['pred_result']
76 | )
77 |
78 | fixed_sql = sample['fixed_sql']
79 |     if isinstance(fixed_sql, list):
80 | fixed_sql = fixed_sql[0]
81 |
82 | fixed_sql = extract_sql_in_code_block(fixed_sql)
83 |
84 | if fixed_sql != "None":
85 | true_result, has_error = _execute_sql("./" + sample["db_path"], sample["sql"])
86 | pred_result, has_error = _execute_sql("./" + sample["db_path"], fixed_sql)
87 |
88 | if not is_execution_correct(true_result, pred_result):
89 | print("-"*20)
90 | print('True:', true_result)
91 | print('Pred:', pred_result)
92 | # completion = norm_sql_query(sample['sql'], sample['schema'])
93 | fixed_sql = sample['sql']
94 |
95 | completion = COMPLETION.format(
96 | feedback_select=feedback_select,
97 | feedback_condition=feedback_condition,
98 | # feedback_join=feedback_join,
99 | fixed_sql=fixed_sql
100 | )
101 |
102 | return {
103 | 'prompt_id': str(index),
104 | 'messages': {
105 | 'prompt': prompt,
106 | 'completion': completion
107 | }
108 | }
109 |
110 | def main():
111 | with open(args.input_file) as fp:
112 | data = [json.loads(line) for line in fp]
113 |
114 | with Pool(processes=args.num_processes) as pool:
115 | results = list(tqdm(pool.imap(process_sample, enumerate(data)), total=len(data)))
116 |
117 | sft_data = [result for result in results if result is not None]
118 |
119 | dataset = DatasetDict({
120 | 'train': Dataset.from_list(sft_data),
121 | 'test': Dataset.from_list(sft_data[:100]),
122 | })
123 | dataset.save_to_disk(args.output_dir)
124 | print(dataset)
125 |
126 | if __name__ == "__main__":
127 | main()
128 |
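129 | # Each dataset record produced above has the form (illustrative values):
130 | #   {"prompt_id": "0",
131 | #    "messages": {"prompt": "<schema>\n\nQuestion: ...\n...\nFeedback for the SQL query:\n",
132 | #                 "completion": "SELECT. ...\n\nCONDITION. ...\n\nFIXED SQL: SELECT ..."}}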
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MATS: A Multi-agent Text2SQL Framework using Small Language Models and Execution Feedback
2 |
3 | 
4 |
5 | MATS is a multi-agent framework for Text2SQL that uses small language models and execution feedback to improve query accuracy. It employs multiple specialized agents, including a schema insight agent, a planner, a validator, a fix agent, and a selection agent. Some components of this framework are adapted from [CodeS](https://github.com/RUCKBReasoning/codes) (for schema filtering) and [alignment-handbook](https://github.com/huggingface/alignment-handbook) (for supervised fine-tuning and ORPO training).
6 |
7 | **1. Set up the environment**
8 | ```
9 | conda env create -n mats -f environment.yml
10 | conda activate mats
11 | ```
12 |
13 | **2. Run evaluation on BIRD**:
14 |
15 | First, serve the models with vLLM.
16 | ```
17 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-bird-planner --host 0.0.0.0 --port 8003 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name planner --gpu-memory-utilization 0.3 --enable-prefix-caching
18 |
19 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-bird-validator --host 0.0.0.0 --port 8004 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name validator --gpu-memory-utilization 0.2 --enable-prefix-caching
20 |
21 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-bird-fixed --host 0.0.0.0 --port 8005 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name fixed --gpu-memory-utilization 0.2 --enable-prefix-caching
22 |
23 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-bird-selection --host 0.0.0.0 --port 8006 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name selection --gpu-memory-utilization 0.3 --enable-prefix-caching
24 | ```
25 |
26 | Run evaluation:
27 | ```
28 | eval_file=data/evaluate/orpo-llama-3-iter-2-end2end-bird_dev.jsonl
29 | rm -f $eval_file
30 | PYTHONPATH=. python evaluate_end2end.py \
31 | --input_file data/schema_insight_bird_with_evidence_dev_text2sql.json \
32 | --output_file $eval_file \
33 | --model-name llama --mode test --n_return 10 --temperature 1.0 --api_host http://localhost:8003 --n_processes 16
34 |
35 | python compute_acc.py --pred_file $eval_file
36 | ```
37 |
38 |
39 |
40 |
41 | **3. Run evaluation on Spider**:
42 |
43 | First, serve the models with vLLM.
44 | ```
45 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-spider-planner --host 0.0.0.0 --port 8003 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name planner --gpu-memory-utilization 0.3 --enable-prefix-caching
46 |
47 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-spider-validator --host 0.0.0.0 --port 8004 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name validator --gpu-memory-utilization 0.2 --enable-prefix-caching
48 |
49 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-1b-spider-fixed --host 0.0.0.0 --port 8005 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name fixed --gpu-memory-utilization 0.2 --enable-prefix-caching
50 |
51 | CUDA_VISIBLE_DEVICES=0 vllm serve thanhdathoang/llama-3b-spider-selection --host 0.0.0.0 --port 8006 --dtype bfloat16 --max-model-len 4096 --disable-log-requests --served-model-name selection --gpu-memory-utilization 0.3 --enable-prefix-caching
52 | ```
53 |
54 | Run evaluation:
55 | ```
56 | eval_file=data/evaluate/orpo-llama-3-iter-2-end2end-spider_dev.jsonl
57 | rm -f $eval_file
58 | PYTHONPATH=. python evaluate_end2end.py \
59 | --input_file data/schema_insight_spider_dev_text2sql.json \
60 | --output_file $eval_file \
61 | --model-name llama --mode test --n_return 10 --temperature 1.0 --api_host http://localhost:8003 --n_processes 16
62 |
63 | python compute_acc.py --pred_file $eval_file
64 | ```
65 |
66 |
67 | **4. Train the agents**
68 |
69 | The schema filtering component is inherited from [CodeS](https://github.com/RUCKBReasoning/codes).
70 |
71 | To train the other agents, see the code in ***alignment-handbook/***, where we modified the [alignment-handbook](https://github.com/huggingface/alignment-handbook) repository to apply supervised fine-tuning and ORPO to the completion part only. The config files can be found in **alignment-handbook/recipes/**.
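72 |
73 | For example, a typical launch (a sketch following the handbook's usual `accelerate launch` pattern; see **alignment-handbook/scripts/train_bird.sh** and **train_spider.sh** for the exact entry points used here) could look like:
74 | ```
75 | cd alignment-handbook
76 | ACCELERATE_LOG_LEVEL=info accelerate launch \
77 |     --config_file recipes/accelerate_configs/deepspeed_zero3.yaml \
78 |     scripts/run_sft.py recipes/llama-3b-bird/planner-fft.yaml
79 | ```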
80 |
81 | **Note**: Currently this work is under review. The models and training datasets will be made publicly available upon acceptance.
82 |
83 |
84 | -----------
85 | **Backup Statistics**
86 |
87 | 
88 |
--------------------------------------------------------------------------------
/data_processing/generate_sft_data_for_validator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from datasets import Dataset, DatasetDict
4 | import numpy as np
5 | from planner import _make_str_response, _execute_sql, is_execution_correct
6 | from multiprocessing import Pool, cpu_count
7 | from functools import partial
8 | from tqdm import tqdm
9 | from multiprocessing import Manager
10 |
11 | # Set seed
12 | np.random.seed(100)
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--input_file', type=str, default='../data/llm_alignment/validator_join_bird_with_evidence_train.jsonl')
16 | parser.add_argument('--output_dir', type=str, default='../data/llm_alignment/sft_data_for_validator_join_bird_with_evidence_train')
17 | args = parser.parse_args()
18 |
19 | def norm_completion(completion):
20 | """Normalize the completion by removing unwanted lines."""
21 | lines = completion.split('\n')
22 | filter_lines = [line for line in lines if "broaden the criteria" not in line]
23 | completion = "\n".join(filter_lines)
24 | return completion
25 |
26 | PROMPT = """Generate feedbacks to fix the following SQL query:
27 | {schema}
28 |
29 | Question: {question}
30 | External knowledge: {evidence}
31 |
32 | SQL query: {sql_query}
33 |
34 | Execution response:
35 | {execution_response}
36 |
37 | Feedback:"""
38 |
39 | def process_sample(sample, index, input_file):
40 | """Process a single sample from the dataset."""
41 | key, token = None, None
42 | if '_join' in input_file:
43 | key = 'validator_join'
44 | token = "JOIN."
45 | elif '_select' in input_file:
46 | key = 'validator_select'
47 | token = "SELECT."
48 | elif '_order' in input_file:
49 | key = 'validator_order'
50 | token = "ORDER BY."
51 | elif '_condition' in input_file:
52 | key = 'validator_condition'
53 | token = "CONDITION."
54 |
55 | feedback = sample.get(key)
56 | if not feedback:
57 | return None
58 |
59 |     if isinstance(feedback, list):
60 | feedback = feedback[0]
61 |
62 | if f"prompt_{key}" in sample:
63 | prompt = sample[f"prompt_{key}"]
64 | prompt_completion = prompt + feedback
65 | else:
66 | prompt_completion = "\n" + feedback
67 |
68 | feedback = token + prompt_completion.split(token)[-1]
69 | prompt = PROMPT.format(schema=sample['schema_sequence'],
70 | question=sample['question'],
71 | evidence=sample['evidence'],
72 | sql_query=sample['predict_sql'],
73 | execution_response=sample['pred_result'])
74 |
75 | completion = feedback
76 | if isinstance(completion, list):
77 | completion = completion[0]
78 |
79 | completion = norm_completion(completion)
80 | prompt_id = f"{index}"
81 |
82 | true_result, _ = _execute_sql("./" + sample["db_path"], sample["sql"])
83 | pred_result, _ = _execute_sql("./" + sample["db_path"], sample['predict_sql'])
84 |
85 | is_pred_sql_correct = is_execution_correct(true_result, pred_result)
86 | feedback_conclude_correct = completion is None or 'Conclude: correct' in completion
87 |
88 |     if is_pred_sql_correct and not feedback_conclude_correct:  # bad case: the feedback flags a correct SQL as wrong, so drop the sample
89 | return None
90 |
91 | return {
92 | 'prompt_id': prompt_id,
93 | 'messages': {
94 | 'prompt': prompt,
95 | 'completion': completion
96 | }
97 | }
98 |
99 | def main():
100 | # Load JSONL data
101 | with open(args.input_file, 'r') as f:
102 | data = [json.loads(line) for line in f]
103 |
104 | # Use multiprocessing to process samples with tqdm
105 | print('Start processing samples...')
106 | with Manager() as manager:
107 | results_list = manager.list() # Shared list for results
108 | total_samples = len(data)
109 |
110 | with tqdm(total=total_samples, desc="Processing Samples") as pbar:
111 | def update_progress(result):
112 | if result is not None:
113 | results_list.append(result)
114 | pbar.update()
115 |
116 | with Pool(processes=24) as pool:
117 | partial_process_sample = partial(process_sample, input_file=args.input_file)
118 | for idx, sample in enumerate(data):
119 | pool.apply_async(partial_process_sample, args=(sample, idx), callback=update_progress)
120 |
121 | pool.close()
122 | pool.join()
123 |
124 | # Convert results to a normal list
125 | sft_data = list(results_list)
126 |
127 | # Shuffle and create DatasetDict
128 | np.random.shuffle(sft_data)
129 |
130 | dataset = DatasetDict({
131 | 'train': Dataset.from_list(sft_data),
132 | 'test': Dataset.from_list(sft_data[:100]),
133 | })
134 |
135 | print(dataset)
136 | dataset.save_to_disk(args.output_dir)
137 |
138 | if __name__ == "__main__":
139 | main()
140 |
--------------------------------------------------------------------------------
/data_processing/generate_sft_data_for_fix.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | from datasets import Dataset, DatasetDict
5 | from planner import get_answer_llamacpp, get_answer_vllm, get_answer_openai
6 | from openai import OpenAI
7 | from dotenv import load_dotenv
8 | from planner import _make_str_response, _execute_sql, is_execution_correct
9 | import re
10 | from utils import norm_sql_query
11 | from tqdm import tqdm
12 | from multiprocessing import Pool
13 |
14 | # Set up argument parser
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--input_file', type=str, default='data/multi-agents/fixed/gpt-4o-mini-fixed-bird_with_evidence_train.jsonl')
17 | parser.add_argument('--output_dir', type=str, default='./data/multi-agents/fixed/sft-gpt-4o-mini-fixed-bird_with_evidence_train')
18 | parser.add_argument('--num_processes', type=int, default=16)
19 | args = parser.parse_args()
20 |
21 | # Define the prompt template
22 | PROMPT = """{schema}
23 |
24 | Question: {question}
25 | External knowledge: {evidence}
26 |
27 | Generated SQL query: {sql_query}
28 |
29 | Execution response:
30 | {execution_response}
31 |
32 | Feedback for the SQL query:
33 | {feedback_select}
34 |
35 | {feedback_condition}
36 |
37 | {feedback_join}
38 |
39 | FIXED SQL:"""
40 |
41 | def norm_feedback(feedback, token):
42 | feedback = token + feedback.split(token)[-1]
43 | return feedback
44 |
45 | def extract_sql_in_code_block(pred_sql_text):
46 | sql_block_match = re.search(r"```(.+?)```", pred_sql_text, re.DOTALL)
47 | if sql_block_match:
48 | sql_query = sql_block_match.group(1).strip()
49 |         if sql_query.startswith("sql"):
50 |             sql_query = sql_query[3:].strip()  # drop only the leading "sql" language tag, not every occurrence
51 | return sql_query
52 | else:
53 | return pred_sql_text
54 |
55 | def process_sample(index_sample):
56 | index, sample = index_sample
57 | feedback_select = sample['validator_select'] or 'SELECT.\nNone'
58 | feedback_condition = sample['validator_condition'] or "CONDITION.\nNone"
59 | feedback_join = sample['validator_join'] or "JOIN.\nNone"
60 | feedback_join = "JOIN." + feedback_join.split("JOIN.")[-1]
61 | feedback_order = sample['validator_order'] or "ORDER BY.\nNone"
62 |
63 | feedback_select = norm_feedback(feedback_select, "SELECT.")
64 | feedback_condition = norm_feedback(feedback_condition, "CONDITION.")
65 | feedback_join = norm_feedback(feedback_join, "JOIN.")
66 | feedback_order = norm_feedback(feedback_order, "ORDER BY.")
67 |
68 | select_correct = 'Conclude: correct' in feedback_select or feedback_select == 'SELECT.\nNone'
69 | condition_correct = 'Conclude: correct' in feedback_condition or feedback_condition == 'CONDITION.\nNone'
70 | join_correct = 'Conclude: correct' in feedback_join or feedback_join == 'JOIN.\nNone'
71 | order_correct = 'Conclude: correct' in feedback_order or feedback_order == 'ORDER BY.\nNone'
72 |
73 | if select_correct:
74 | feedback_select = ""
75 | if condition_correct:
76 | feedback_condition = ""
77 | if join_correct:
78 | feedback_join = ""
79 | if order_correct:
80 | feedback_order = ""
81 |
82 | # if select_correct and condition_correct and join_correct and order_correct:
83 | # return None
84 |
85 | prompt = PROMPT.format(
86 | schema=sample['schema_sequence'],
87 | question=sample['question'],
88 | evidence=sample['evidence'],
89 | sql_query=sample['predict_sql'],
90 | execution_response=sample['pred_result'],
91 | feedback_select=feedback_select,
92 | feedback_condition=feedback_condition,
93 | feedback_join=feedback_join,
94 | # feedback_order=feedback_order
95 | )
96 |
97 | completion = sample['fixed_sql']
98 |     if isinstance(completion, list):
99 | completion = completion[0]
100 |
101 | fixed_sql = extract_sql_in_code_block(completion)
102 | # completion = norm_sql_query(fixed_sql, sample['schema'])
103 | completion = fixed_sql
104 |
105 | true_result, has_error = _execute_sql("./" + sample["db_path"], sample["sql"])
106 | pred_result, has_error = _execute_sql("./" + sample["db_path"], fixed_sql)
107 |
108 | if not is_execution_correct(true_result, pred_result):
109 | print("-"*20)
110 | print('True:', true_result)
111 | print('Pred:', pred_result)
112 | # completion = norm_sql_query(sample['sql'], sample['schema'])
113 | completion = sample['sql']
114 | # return None
115 |
116 | return {
117 | 'prompt_id': str(index),
118 | 'messages': {
119 | 'prompt': prompt,
120 | 'completion': completion
121 | }
122 | }
123 |
124 | def main():
125 | with open(args.input_file) as fp:
126 | data = [json.loads(line) for line in fp]
127 |
128 | with Pool(processes=args.num_processes) as pool:
129 | results = list(tqdm(pool.imap(process_sample, enumerate(data)), total=len(data)))
130 |
131 | sft_data = [result for result in results if result is not None]
132 |
133 | dataset = DatasetDict({
134 | 'train': Dataset.from_list(sft_data),
135 | 'test': Dataset.from_list(sft_data[:100]),
136 | })
137 | dataset.save_to_disk(args.output_dir)
138 | print(dataset)
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
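143 | # norm_feedback keeps only the text from the last occurrence of the section token onward, e.g. (illustrative):
144 | #   norm_feedback("preamble ... SELECT.\nConclude: correct", "SELECT.")
145 | #   -> "SELECT.\nConclude: correct"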
--------------------------------------------------------------------------------
/alignment-handbook/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
16 |
17 |
18 | import re
19 | import shutil
20 | from pathlib import Path
21 |
22 | from setuptools import find_packages, setup
23 |
24 |
25 | # Remove stale alignment.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
26 | stale_egg_info = Path(__file__).parent / "alignment.egg-info"
27 | if stale_egg_info.exists():
28 | print(
29 | (
30 | "Warning: {} exists.\n\n"
31 | "If you recently updated alignment, this is expected,\n"
32 | "but it may prevent alignment from installing in editable mode.\n\n"
33 | "This directory is automatically generated by Python's packaging tools.\n"
34 | "I will remove it now.\n\n"
35 | "See https://github.com/pypa/pip/issues/5466 for details.\n"
36 | ).format(stale_egg_info)
37 | )
38 | shutil.rmtree(stale_egg_info)
39 |
40 |
41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
43 | _deps = [
44 | "accelerate==0.23.0",
45 | "bitsandbytes==0.41.2.post2",
46 | "black==23.1.0",
47 | "datasets==2.14.6",
48 | "deepspeed==0.12.2",
49 | "einops>=0.6.1",
50 | "evaluate==0.4.0",
51 | "flake8>=6.0.0",
52 | "hf-doc-builder>=0.4.0",
53 | "huggingface-hub>=0.14.1,<1.0",
54 | "isort>=5.12.0",
55 | "ninja>=1.11.1",
56 | "numpy>=1.24.2",
57 | "packaging>=23.0",
58 | "parameterized>=0.9.0",
59 | "peft==0.6.1",
60 | "protobuf<=3.20.2", # Needed to avoid conflicts with `transformers`
61 | "pytest",
62 | "safetensors>=0.3.3",
63 | "scipy",
64 | "tensorboard",
65 | "torch==2.1.0",
66 | "transformers==4.35.0",
67 | "trl==0.7.4",
68 | "jinja2>=3.0.0",
69 | "tqdm>=4.64.1",
70 | ]
71 |
72 | # this is a lookup table with items like:
73 | #
74 | # tokenizers: "tokenizers==0.9.4"
75 | # packaging: "packaging"
76 | #
77 | # some of the values are versioned whereas others aren't.
78 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
79 |
80 |
81 | def deps_list(*pkgs):
82 | return [deps[pkg] for pkg in pkgs]
83 |
84 |
85 | extras = {}
86 | extras["tests"] = deps_list("pytest", "parameterized")
87 | extras["torch"] = deps_list("torch")
88 | extras["quality"] = deps_list("black", "isort", "flake8")
89 | extras["docs"] = deps_list("hf-doc-builder")
90 | extras["dev"] = extras["docs"] + extras["quality"] + extras["tests"]
91 |
92 | # core dependencies shared across the whole project - keep this to a bare minimum :)
93 | install_requires = [
94 | deps["accelerate"],
95 | deps["bitsandbytes"],
96 | deps["einops"],
97 | deps["evaluate"],
98 | deps["datasets"],
99 | deps["deepspeed"],
100 | deps["huggingface-hub"],
101 | deps["jinja2"],
102 | deps["ninja"],
103 | deps["numpy"],
104 | deps["packaging"], # utilities from PyPA to e.g., compare versions
105 | deps["peft"],
106 | deps["protobuf"],
107 | deps["safetensors"],
108 | deps["scipy"],
109 | deps["tensorboard"],
110 | deps["tqdm"], # progress bars in model download and training scripts
111 | deps["transformers"],
112 | deps["trl"],
113 | ]
114 |
115 | setup(
116 | name="alignment-handbook",
117 | version="0.2.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
118 | author="The Hugging Face team (past and future)",
119 | author_email="lewis@huggingface.co",
120 | description="The Alignment Handbook",
121 | long_description=open("README.md", "r", encoding="utf-8").read(),
122 | long_description_content_type="text/markdown",
123 | keywords="nlp deep learning rlhf llm",
124 | license="Apache",
125 | url="https://github.com/huggingface/alignment-handbook",
126 | package_dir={"": "src"},
127 | packages=find_packages("src"),
128 | zip_safe=False,
129 | extras_require=extras,
130 | python_requires=">=3.10.9",
131 | install_requires=install_requires,
132 | classifiers=[
133 | "Development Status :: 3 - Alpha",
134 | "Intended Audience :: Developers",
135 | "Intended Audience :: Education",
136 | "Intended Audience :: Science/Research",
137 | "License :: OSI Approved :: Apache Software License",
138 | "Operating System :: OS Independent",
139 | "Programming Language :: Python :: 3",
140 | "Programming Language :: Python :: 3.10",
141 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
142 | ],
143 | )
144 |
--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from typing import List
4 |
5 | from torch.optim.lr_scheduler import _LRScheduler
6 | from torch.optim import Optimizer
7 |
8 | '''
9 | Copied from the pl_bolts source code.
10 | '''
11 |
12 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
13 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
14 | and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
15 |
16 | .. warning::
17 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
18 | after each iteration as calling it after each epoch will keep the starting lr at
19 | warmup_start_lr for the first epoch which is 0 in most cases.
20 |
21 | .. warning::
22 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
23 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
24 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
25 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
26 | train and validation methods.
27 |
28 | Example:
29 | >>> import torch.nn as nn
30 | >>> from torch.optim import Adam
31 | >>> #
32 | >>> layer = nn.Linear(10, 1)
33 | >>> optimizer = Adam(layer.parameters(), lr=0.02)
34 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
35 | >>> # the default case
36 | >>> for epoch in range(40):
37 | ... # train(...)
38 | ... # validate(...)
39 | ... scheduler.step()
40 | >>> # passing epoch param case
41 | >>> for epoch in range(40):
42 | ... scheduler.step(epoch)
43 | ... # train(...)
44 | ... # validate(...)
45 | """
46 |
47 | def __init__(
48 | self,
49 | optimizer: Optimizer,
50 | warmup_epochs: int,
51 | max_epochs: int,
52 | warmup_start_lr: float = 0.0,
53 | eta_min: float = 0.0,
54 | last_epoch: int = -1,
55 | ) -> None:
56 | """
57 | Args:
58 | optimizer (Optimizer): Wrapped optimizer.
59 | warmup_epochs (int): Maximum number of iterations for linear warmup
60 | max_epochs (int): Maximum number of iterations
61 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
62 | eta_min (float): Minimum learning rate. Default: 0.
63 | last_epoch (int): The index of last epoch. Default: -1.
64 | """
65 | self.warmup_epochs = warmup_epochs
66 | self.max_epochs = max_epochs
67 | self.warmup_start_lr = warmup_start_lr
68 | self.eta_min = eta_min
69 |
70 | super().__init__(optimizer, last_epoch)
71 |
72 | def get_lr(self) -> List[float]:
73 | """Compute learning rate using chainable form of the scheduler."""
74 | if not self._get_lr_called_within_step:
75 | warnings.warn(
76 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
77 | UserWarning,
78 | )
79 |
80 | if self.last_epoch == 0:
81 | return [self.warmup_start_lr] * len(self.base_lrs)
82 | if self.last_epoch < self.warmup_epochs:
83 | return [
84 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
85 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
86 | ]
87 | if self.last_epoch == self.warmup_epochs:
88 | return self.base_lrs
89 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
90 | return [
91 | group["lr"]
92 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
93 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
94 | ]
95 |
96 | return [
97 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
98 | / (
99 | 1
100 | + math.cos(
101 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
102 | )
103 | )
104 | * (group["lr"] - self.eta_min)
105 | + self.eta_min
106 | for group in self.optimizer.param_groups
107 | ]
108 |
109 | def _get_closed_form_lr(self) -> List[float]:
110 | """Called when epoch is passed as a param to the `step` function of the scheduler."""
111 | if self.last_epoch < self.warmup_epochs:
112 | return [
113 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
114 | for base_lr in self.base_lrs
115 | ]
116 |
117 | return [
118 | self.eta_min
119 | + 0.5
120 | * (base_lr - self.eta_min)
121 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
122 | for base_lr in self.base_lrs
123 | ]
124 |
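125 | # Closed form implemented in _get_closed_form_lr (per base_lr, epoch t):
126 | #   warmup (t < warmup_epochs):  lr_t = warmup_start_lr + t * (base_lr - warmup_start_lr) / (warmup_epochs - 1)
127 | #   anneal (t >= warmup_epochs): lr_t = eta_min + 0.5 * (base_lr - eta_min)
128 | #                                       * (1 + cos(pi * (t - warmup_epochs) / (max_epochs - warmup_epochs)))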
--------------------------------------------------------------------------------
/utils/load_sft_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import gc
4 |
5 |
6 | from torch.utils.data import Dataset
7 | from schema_item_filter import SchemaItemClassifierInference, filter_schema, filter_schema_purple
8 | from utils.db_utils import get_db_schema_sequence, get_matched_content_sequence
9 |
10 | def prepare_text2sql_prefix_sequence(data):
11 | prompt = f"""Convert the question to SQL query.
12 | {data["schema_sequence"]}
13 | {data["content_sequence"]}
14 | question: {data["text"]}"""
15 | return prompt
16 |
17 | def prepare_inputs_and_labels(prefix_seq, target_seq, tokenizer, max_tokens):
18 |     prefix_ids = [tokenizer.bos_token_id] + tokenizer(prefix_seq, truncation = False)["input_ids"]
19 | target_ids = tokenizer(target_seq, truncation = False)["input_ids"] + [tokenizer.eos_token_id]
20 |
21 | seq_length = len(prefix_ids) + len(target_ids)
22 | if seq_length <= max_tokens: # pad inputs with pad_token_id
23 | pad_length = max_tokens - seq_length
24 | input_ids = prefix_ids + target_ids + [tokenizer.pad_token_id] * pad_length
25 | # tell the model to ignore the padding tokens when performing (masked) self-attention
26 | attention_mask = [1] * seq_length + [0] * pad_length
27 | # only target_ids produces gradients
28 | labels = [-100] * len(prefix_ids) + target_ids + [-100] * pad_length
29 | else: # no padding
30 | print("the current input sequence exceeds the max_tokens, we will truncate it.")
31 | input_ids = prefix_ids + target_ids
32 | # pre-truncate input ids
33 | input_ids = [tokenizer.bos_token_id] + input_ids[-(max_tokens-1):]
34 | attention_mask = [1] * max_tokens
35 | # only target_ids produces gradients
36 | labels = [-100] * len(prefix_ids) + target_ids
37 | # pre-truncate labels
38 | labels = labels[-max_tokens:]
39 |
40 | return {
41 | "input_ids": torch.tensor(input_ids, dtype = torch.int64),
42 | "attention_mask": torch.tensor(attention_mask, dtype = torch.int64),
43 | "labels": torch.tensor(labels, dtype = torch.int64)
44 | }
45 |
46 | # def prepare_inputs(prefix_seq, tokenizer, max_prefix_length):
47 | # input_ids = [tokenizer.bos_token_id] + tokenizer(prefix_seq , truncation = False)["input_ids"]
48 |
49 | # if len(input_ids) > max_prefix_length:
50 | # print("the current input sequence exceeds the max_tokens, we will truncate it.")
51 | # input_ids = [tokenizer.bos_token_id] + input_ids[-(max_prefix_length-1):]
52 |
53 | # attention_mask = [1] * len(input_ids)
54 |
55 | # return {
56 | # "input_ids": torch.tensor(input_ids, dtype = torch.int64),
57 | # "attention_mask": torch.tensor(attention_mask, dtype = torch.int64)
58 | # }
59 |
60 | def prepare_inputs(prefix_seq, tokenizer, max_prefix_length):
61 | messages = [{
62 | 'role': 'user',
63 | 'content': prefix_seq
64 | }]
65 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
66 |
67 |     input_ids = tokenizer(prompt, truncation = False)["input_ids"]
68 |
69 | if len(input_ids) > max_prefix_length:
70 | print("the current input sequence exceeds the max_tokens, we will truncate it.")
71 | input_ids = input_ids[-(max_prefix_length-1):]
72 |
73 | attention_mask = [1] * len(input_ids)
74 |
75 | return {
76 | "input_ids": torch.tensor(input_ids, dtype = torch.int64),
77 | "attention_mask": torch.tensor(attention_mask, dtype = torch.int64)
78 | }
79 |
80 | class SFTSQLGenerationDataset(Dataset):
81 | def __init__(self, text2sql_data_dir, tokenizer, max_tokens, mode, table_num, column_num, threshold, sic_path, do_filter_schema=True):
82 | super().__init__()
83 | dataset = json.load(open(text2sql_data_dir))
84 |
85 | print("apply filtering strategies...")
86 | if do_filter_schema:
87 | if mode == "train":
88 | dataset = filter_schema(dataset, "train", None, table_num, column_num, threshold=threshold)
89 | elif mode == "eval":
90 | sic = SchemaItemClassifierInference(sic_path)
91 | dataset = filter_schema(dataset, "eval", sic, table_num, column_num, threshold=threshold)
92 | # dataset = filter_schema_purple(dataset, "eval", "/home/datht/llmsql/selector/spider-dev/spider-dev-selector-t0.02-value-samples.json")
93 | del sic
94 | torch.cuda.empty_cache()
95 |
96 | # prepare schema sequence and content sequence
97 | for data in dataset:
98 | data["schema_sequence"] = get_db_schema_sequence(data["schema"])
99 | # data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
100 |
101 | self.mode = mode
102 | self.dataset = dataset
103 | self.tokenizer = tokenizer
104 | self.max_tokens = max_tokens
105 |
106 | def __getitem__(self, index):
107 | data = self.dataset[index]
108 | prefix_seq = prepare_text2sql_prefix_sequence(data)
109 | if index < 2:
110 | print(prefix_seq)
111 |
112 | if self.mode == "train":
113 | target_seq = data["sql"]
114 | return prepare_inputs_and_labels(prefix_seq, target_seq, self.tokenizer, self.max_tokens)
115 | elif self.mode == "eval":
116 | return prepare_inputs(prefix_seq, self.tokenizer, self.max_tokens)
117 |
118 | def __len__(self):
119 | return len(self.dataset)
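120 |
121 | # Worked example of the label masking in prepare_inputs_and_labels (illustrative token ids):
122 | #   prefix_ids = [bos, p1, p2], target_ids = [t1, t2, eos], max_tokens = 8:
123 | #   input_ids      -> [bos,  p1,   p2,   t1, t2, eos, pad,  pad]
124 | #   attention_mask -> [1,    1,    1,    1,  1,  1,   0,    0]
125 | #   labels         -> [-100, -100, -100, t1, t2, eos, -100, -100]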
--------------------------------------------------------------------------------
/data_processing/generate_sft_data_for_planner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import re
5 | from datasets import Dataset, DatasetDict
6 | from tqdm import tqdm
7 | import sqlite3
8 | from func_timeout import func_timeout, FunctionTimedOut
9 | from planner import _make_str_response, _execute_sql, is_execution_correct
10 | from utils import norm_sql_query
11 | from multiprocessing import Pool
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--input_file', type=str, default='../data/multi-agents/planner/gpt-4o-mini-planner_combine_bird_with_evidence_train.jsonl')
15 | parser.add_argument('--raw_train_file', type=str, default='../data/multi-agents/planner/gpt-4o-mini-planner_combine_bird_with_evidence_train.jsonl')
16 | parser.add_argument('--output_dir', type=str, default='../data/multi-agents/planner/sft-gpt-4o-mini-planner_combine_bird_with_evidence_train/')
17 | parser.add_argument('--error_file', type=str, default='../data/multi-agents/planner/gpt-4o-mini-planner_combine_bird_with_evidence_train-error-turn-1.jsonl')
18 | parser.add_argument('--use_groundtruth', action='store_true')
19 | parser.add_argument('--no_filter', action='store_true')
20 | args = parser.parse_args()
21 |
22 | PROMPT = """{schema}
23 |
24 | Question: {question}
25 | External knowledge: {evidence}
26 |
27 | Planning:
28 | """
29 | # PROMPT = """{schema}
30 |
31 | # Question: {question}
32 | # """
33 |
34 | # Helper function for processing each sample
35 | def process_sample(sample_args):
36 |     isample, sample, raw_sample, use_groundtruth, no_filter = sample_args
37 | schema = raw_sample['schema_sequence']
38 | question = sample['question']
39 | evidence = sample['evidence']
40 |
41 | key = 'planner_combine_with_true_sql'
42 | feedback = sample[key]
43 | if feedback is None or len(feedback) == 0:
44 | return None, None # Indicate empty result
45 |
46 | if isinstance(feedback, list):
47 | feedback = feedback[0]
48 |
49 | prompt = PROMPT.format(schema=schema, question=question, evidence=evidence)
50 |
51 | if use_groundtruth:
52 | completion = sample['sql']
53 | # completion = norm_sql_query(sample['sql'], raw_sample['schema'])
54 | else:
55 | # Extract SQL query using regex
56 | pred_sql_match = re.search(r"(?<=Final SQL query:).*?```(.*?)```", feedback, re.DOTALL)
57 | if pred_sql_match is None:
58 | pred_sql = " "
59 | else:
60 | pred_sql = pred_sql_match.group(1).strip()
61 | if pred_sql.startswith("sql"):
62 | pred_sql = pred_sql[3:].strip()
63 |
64 | # norm_pred_sql = norm_sql_query(pred_sql, raw_sample['schema'])
65 | # feedback = feedback.replace(pred_sql, norm_pred_sql)
66 |
67 | if not no_filter:
68 | true_result, has_error_true = _execute_sql("./" + sample["db_path"], sample["sql"])
69 | pred_result, has_error_pred = _execute_sql("./" + sample["db_path"], pred_sql)
70 | # norm_pred_result, has_error_pred = _execute_sql("./" + sample["db_path"], norm_pred_sql)
71 |
72 | # if not is_execution_correct(pred_result, norm_pred_result):
73 | # # print to debug
74 | # print("-" * 20)
75 | # print("Norm SQL:", norm_pred_sql)
76 | # print("Pred SQL:", pred_sql)
77 | # print("Norm Result:", norm_pred_result)
78 | # print("Pred Result:", pred_result)
79 |
80 | if not is_execution_correct(true_result, pred_result):
81 | # sample['true_result'] = _make_str_response(true_result, has_error_true)
82 | # sample['pred_result'] = _make_str_response(pred_result, has_error_pred)
83 | return None, sample # Return sample with error
84 |
85 | completion = feedback if not isinstance(feedback, list) else feedback[0]
86 | prompt_id = f"{isample}"
87 |
88 | return {
89 | 'prompt_id': prompt_id,
90 | 'messages': {
91 | 'prompt': prompt,
92 | 'completion': completion
93 | }
94 | }, None # Indicate valid result
95 |
96 |
97 | if __name__ == "__main__":
98 | # Load data from input files
99 | data = []
100 | with open(args.input_file, 'r') as f:
101 | for line in f:
102 | data.append(json.loads(line))
103 |
104 | raw_data = json.load(open(args.raw_train_file))
105 |
106 | # Prepare arguments for each sample to process
107 | samples_args = [(i, data[i], raw_data[i], args.use_groundtruth, args.no_filter) for i in range(len(data))]
108 |
109 | # Run parallel processing with 24 processes
110 | sft_data = []
111 | error_data = []
112 | with Pool(24) as pool:
113 | for result, error in tqdm(pool.imap_unordered(process_sample, samples_args), total=len(data)):
114 | if result:
115 | sft_data.append(result)
116 | if error:
117 | error_data.append(error)
118 | # for sample_arg in tqdm(samples_args):
119 | # result, error = process_sample(sample_arg)
120 | # if result:
121 | # sft_data.append(result)
122 | # if error:
123 | # error_data.append(error)
124 |
125 | # Create datasets
126 | dataset = DatasetDict({
127 | 'train': Dataset.from_list(sft_data),
128 | 'test': Dataset.from_list(sft_data[:100]),
129 | })
130 | print(dataset)
131 |
132 | # Save the dataset
133 | dataset.save_to_disk(args.output_dir)
134 |
135 | # Write error data to JSONL file
136 | with open(args.error_file, 'w') as output_file:
137 | for sample in error_data:
138 | output_file.write(json.dumps(sample, ensure_ascii=False) + '\n')
139 |
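140 | # Illustrative planner completion parsed by the regex in process_sample:
141 | #   "...reasoning...\nFinal SQL query:\n```sql\nSELECT name FROM singer\n```"
142 | #   -> pred_sql == "SELECT name FROM singer"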
--------------------------------------------------------------------------------
/llm_alignment/build_rlef_selection_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle as pkl
3 | import argparse
4 | import numpy as np
5 | import requests
6 | from tqdm import tqdm
7 | from copy import deepcopy
8 | from multiprocessing import Pool, cpu_count
9 | from data_processing.planner import SelectionAgentWithSchema
10 | import os
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--pred_file", type=str, default='logs/results-orpo-iter-2-bird-train-top-20-temperature-1.0.pkl')
14 | parser.add_argument("--max_candidates", type=int, default=3)
15 | parser.add_argument("--progress_file", type=str, default='temp/bird_selection_dpo.jsonl')
16 | args = parser.parse_args()
17 |
18 | # Initialize Selection Agent
19 | selection_agent = SelectionAgentWithSchema()
20 |
21 | def get_answer_selection(messages):
22 | response = requests.post(
23 | "http://192.168.1.108:8006/v1/completions",
24 | json={
25 | "model": 'selection',
26 | "prompt": messages[0]['content'],
27 | "max_tokens": 512,
28 | "use_beam_search": False,
29 | "n": 20,
30 | "temperature": 1.0,
31 | "stop": ['<|eot_id|>', '<|end|>', '<|end_header_id|>', '<|end_of_text|>', '<|end▁of▁sentence|>']
32 | }
33 | ).json()
34 |
35 | try:
36 | return [x['text'] for x in response['choices']]
37 |     except (KeyError, TypeError):  # e.g. an error payload without 'choices'
38 | print(response)
39 | return []
40 |
41 | selection_agent.get_answer = get_answer_selection
42 |
43 | # Load predictions
44 | preds = pkl.load(open(args.pred_file, 'rb'))
45 |
46 | # Load progress from previous runs
47 | processed_keys = {}
48 | if os.path.exists(args.progress_file):
49 | with open(args.progress_file, 'r', encoding='utf-8') as f:
50 | for line in f:
51 | sample = json.loads(line.strip())
52 | key = (sample["db_id"], sample["question"])
53 | processed_keys[key] = processed_keys.get(key, 0) + 1
54 |
55 | # Expand preds 4 times and filter already processed ones
56 | all_preds = preds * 4
57 | filtered_preds = []
58 | for sample in all_preds:
59 | key = (sample["db_id"], sample["question"])
60 | if processed_keys.get(key, 0) < 4:
61 | filtered_preds.append(sample)
62 | processed_keys[key] = processed_keys.get(key, 0) + 1 # Track count
63 |
64 | def build_dpo_data(sample):
65 | """Process a single sample and return DPO data."""
66 | sample = deepcopy(sample)
67 |
68 | # Filter out samples with execution failures
69 | valid_sqls, valid_results, valid_corrects = [], [], []
70 | for i in range(min(len(sample['predict_sqls']), 20)):
71 | if 'Execution failed' not in sample['pred_results'][i] and 'too much time' not in sample['pred_results'][i]:
72 | valid_sqls.append(sample['predict_sqls'][i])
73 | valid_results.append(sample['pred_results'][i])
74 | valid_corrects.append(sample['is_execution_corrects'][i])
75 |
76 | sample['predict_sqls'] = valid_sqls
77 | sample['pred_results'] = valid_results
78 | sample['is_execution_corrects'] = valid_corrects
79 |
80 | # Shuffle valid results
81 | indices = np.random.permutation(len(sample['predict_sqls'])).tolist()
82 | sample['predict_sqls'] = [sample['predict_sqls'][i] for i in indices]
83 | sample['pred_results'] = [sample['pred_results'][i] for i in indices]
84 | sample['is_execution_corrects'] = [sample['is_execution_corrects'][i] for i in indices]
85 |
86 | # Select a random number of candidates
87 | n_candidates = np.random.randint(2, 6)
88 | sample['predict_sqls'] = sample['predict_sqls'][:n_candidates]
89 | sample['pred_results'] = sample['pred_results'][:n_candidates]
90 | sample['is_execution_corrects'] = sample['is_execution_corrects'][:n_candidates]
91 | sample['candidate_sqls'] = sample['predict_sqls']
92 | sample['candidate_pred_results'] = sample['pred_results']
93 |
94 | # Generate prompt and answers
95 | prompt, answers = selection_agent.generate(sample)
96 |
97 | dpo_data = {
98 | 'db_path': sample['db_path'],
99 | 'db_id': sample['db_id'],
100 | 'question': sample['question'],
101 | 'sql': sample['sql'],
102 | 'true_result': str(sample['true_result']).strip(),
103 | 'predict_sqls': sample['predict_sqls'],
104 | 'pred_results': [str(x).strip() for x in sample['pred_results']],
105 | 'is_execution_corrects': sample['is_execution_corrects'],
106 | 'reward_data': []
107 | }
108 |
109 |     for answer in answers:
110 |         answer_index = selection_agent.extract_answer_index(answer)  # 1-based index of the chosen candidate; -1 means "no candidate is correct"
111 |
112 |         if answer_index == -1 and sum(sample['is_execution_corrects']) > 0:
113 |             reward = 0  # claimed "none correct" although a correct candidate exists
114 |         elif answer_index == -1 and sum(sample['is_execution_corrects']) == 0:
115 |             reward = 1  # correctly concluded that no candidate is correct
116 |         elif answer_index > len(sample['is_execution_corrects']):
117 |             reward = 0  # chosen index is out of range
118 |         elif answer_index > 0:
119 |             reward = int(sample['is_execution_corrects'][answer_index - 1])  # reward equals the correctness of the chosen candidate
120 |         else:
121 |             reward = -2  # unparseable selection answer
122 |
123 | dpo_data['reward_data'].append({
124 | 'prompt': prompt,
125 | 'completion': answer,
126 | 'reward': reward
127 | })
128 |
129 | return dpo_data
130 |
131 | if __name__ == "__main__":
132 | num_processes = min(32, cpu_count()) # Use up to 32 processes
133 |
134 | # Track progress and write every 50 samples
135 | processed_count = 0
136 | with Pool(num_processes) as pool, open(args.progress_file, 'a', encoding='utf-8') as f:
137 | for dpo_data in tqdm(pool.imap_unordered(build_dpo_data, filtered_preds), total=len(filtered_preds)):
138 | f.write(json.dumps(dpo_data, ensure_ascii=False) + "\n")
139 | processed_count += 1
140 |
141 | # Save every 50 samples
142 | if processed_count % 50 == 0:
143 | f.flush()
144 |
--------------------------------------------------------------------------------
/alignment-handbook/tests/test_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import unittest
16 |
17 | import pytest
18 | from datasets import Dataset
19 |
20 | from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
21 |
22 |
23 | class GetDatasetsTest(unittest.TestCase):
24 | """Each of these test datasets has 100 examples"""
25 |
26 | def test_loading_data_args(self):
27 | dataset_mixer = {
28 | "HuggingFaceH4/testing_alpaca_small": 0.5,
29 | "HuggingFaceH4/testing_self_instruct_small": 0.3,
30 | "HuggingFaceH4/testing_codealpaca_small": 0.2,
31 | }
32 | data_args = DataArguments(dataset_mixer=dataset_mixer)
33 | datasets = get_datasets(data_args)
34 | self.assertEqual(len(datasets["train"]), 100)
35 | self.assertEqual(len(datasets["test"]), 300)
36 |
37 | def test_loading_data_dict(self):
38 | dataset_mixer = {
39 | "HuggingFaceH4/testing_alpaca_small": 0.5,
40 | "HuggingFaceH4/testing_self_instruct_small": 0.3,
41 | "HuggingFaceH4/testing_codealpaca_small": 0.2,
42 | }
43 | datasets = get_datasets(dataset_mixer)
44 | self.assertEqual(len(datasets["train"]), 100)
45 | self.assertEqual(len(datasets["test"]), 300)
46 |
47 | def test_loading_with_unit_fractions(self):
48 | dataset_mixer = {
49 | "HuggingFaceH4/testing_alpaca_small": 1.0,
50 | "HuggingFaceH4/testing_self_instruct_small": 1.0,
51 | "HuggingFaceH4/testing_codealpaca_small": 1.0,
52 | }
53 | datasets = get_datasets(dataset_mixer)
54 | self.assertEqual(len(datasets["train"]), 300)
55 | self.assertEqual(len(datasets["test"]), 300)
56 |
57 | def test_loading_with_fractions_greater_than_unity(self):
58 | dataset_mixer = {
59 | "HuggingFaceH4/testing_alpaca_small": 0.7,
60 | "HuggingFaceH4/testing_self_instruct_small": 0.4,
61 | }
62 | datasets = get_datasets(dataset_mixer)
63 | self.assertEqual(len(datasets["train"]), 70 + 40)
64 | self.assertEqual(len(datasets["test"]), 200)
65 |
66 | def test_loading_fails_with_negative_fractions(self):
67 | dataset_mixer = {
68 | "HuggingFaceH4/testing_alpaca_small": 0.7,
69 | "HuggingFaceH4/testing_self_instruct_small": -0.3,
70 | }
71 | with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
72 | get_datasets(dataset_mixer)
73 |
74 | def test_loading_single_split_with_unit_fractions(self):
75 | dataset_mixer = {
76 | "HuggingFaceH4/testing_alpaca_small": 1.0,
77 | }
78 | datasets = get_datasets(dataset_mixer, splits=["test"])
79 | self.assertEqual(len(datasets["test"]), 100)
80 | self.assertRaises(KeyError, lambda: datasets["train"])
81 |
82 |
83 | class ApplyChatTemplateTest(unittest.TestCase):
84 | def setUp(self):
85 | model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
86 | data_args = DataArguments()
87 | self.tokenizer = get_tokenizer(model_args, data_args)
88 | self.dataset = Dataset.from_dict(
89 | {
90 | "prompt": ["Hello!"],
91 | "messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
92 | "chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
93 | "rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]],
94 | }
95 | )
96 |
97 | def test_sft(self):
98 | dataset = self.dataset.map(
99 | apply_chat_template,
100 | fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
101 | remove_columns=self.dataset.column_names,
102 | )
103 | self.assertDictEqual(
104 | dataset[0],
105 | {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n"},
106 | )
107 |
108 | def test_generation(self):
109 | # Remove last turn from messages
110 | dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
111 | dataset = dataset.map(
112 | apply_chat_template,
113 | fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
114 | remove_columns=self.dataset.column_names,
115 | )
116 | self.assertDictEqual(
117 | dataset[0],
118 | {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n"},
119 | )
120 |
121 | def test_rm(self):
122 | dataset = self.dataset.map(
123 | apply_chat_template,
124 | fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
125 | remove_columns=self.dataset.column_names,
126 | )
127 | self.assertDictEqual(
128 | dataset[0],
129 | {
130 | "text_chosen": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n",
131 | "text_rejected": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nHola!\n",
132 | },
133 | )
134 |
135 | def test_dpo(self):
136 | dataset = self.dataset.map(
137 | apply_chat_template,
138 | fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
139 | remove_columns=self.dataset.column_names,
140 | )
141 | self.assertDictEqual(
142 | dataset[0],
143 | {
144 | "text_prompt": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n",
145 | "text_chosen": "Bonjour!\n",
146 | "text_rejected": "Hola!\n",
147 | },
148 | )
149 |
--------------------------------------------------------------------------------
/validator_data/generate_fixed_sql_using_fewshot.py:
--------------------------------------------------------------------------------
import functools
import json
import multiprocessing.pool
import sqlite3

import pandas as pd
import requests
from tqdm import tqdm

def timeout(max_timeout):
    """Timeout decorator, parameter in seconds."""
    def timeout_decorator(item):
        """Wrap the original function."""
        @functools.wraps(item)
        def func_wrapper(*args, **kwargs):
            """Closure for function."""
            pool = multiprocessing.pool.ThreadPool(processes=1)
            async_result = pool.apply_async(item, args, kwargs)
            # raises multiprocessing.TimeoutError if execution exceeds max_timeout
            return async_result.get(max_timeout)
        return func_wrapper
    return timeout_decorator
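
# Note: ThreadPool cannot kill the worker thread, so a timed-out query keeps
# running in the background; the decorator only stops waiting for it. A
# minimal sketch of the behaviour (the sleep duration is illustrative):
#
#     @timeout(1)
#     def slow():
#         time.sleep(5)   # requires `import time`
#
#     slow()  # raises multiprocessing.TimeoutError after ~1 second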

@timeout(30)
def _execute_sql_with_timeout(db_path, action):
    conn = sqlite3.connect(db_path)
    conn.text_factory = lambda b: b.decode(errors="ignore")
    statements = [x for x in action.split(";") if len(x.strip()) > 0]
    if len(statements) == 0:
        conn.close()
        return "no SQL query executed.", True
    cursor = conn.cursor()
    for statement in statements:
        try:
            cursor.execute(statement)
            response = cursor.fetchall()
            has_error = False
        except Exception as error:
            # If the SQL query is invalid, return the error message from sqlite
            response = str(error)
            has_error = True
            break
    cursor.close()
    conn.close()
    return response, has_error
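
# Illustrative behaviour (any SQLite database file works for these calls):
#
#     _execute_sql_with_timeout("dev.sqlite", "SELECT 1")
#     # -> ([(1,)], False)
#     _execute_sql_with_timeout("dev.sqlite", "SELEC 1")
#     # -> ('near "SELEC": syntax error', True)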

def _execute_sql(db_path, sql_query):
    try:
        pred_result, has_error = _execute_sql_with_timeout(db_path, sql_query)
    except Exception:
        # Covers multiprocessing.TimeoutError raised by the @timeout decorator.
        pred_result = "The query takes too much time."
        has_error = True
    return pred_result, has_error

def _make_str_response(response, has_error):
    if has_error:
        return str(response)
    else:
        df = pd.DataFrame(response)
        return str(df)
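
# Illustrative: a successful execution is rendered through pandas, e.g.
#
#     _make_str_response([(1, "a")], False)
#     # ->    0  1
#     #    0  1  a
#
# while an error message string is passed through unchanged.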

PROMPT = open('./few_shot_prompt_fix.txt').read().strip() + """
=========
{schema}

Matched contents are written in this format table.column (some values can be found in that column)
{matched_content}

Question: {question}

SQL query: {sql_query}

Feedback:{feedback}

FIXED SQL:"""
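
# The "=========" string separates few-shot examples; it doubles as the stop
# sequence in the completion request below, so generation halts before the
# model starts writing a new example.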

def get_answer(messages):
    # Query a locally served model through an OpenAI-compatible
    # /v1/completions endpoint (e.g. a vLLM server; `use_beam_search`
    # is a vLLM extension). The first choice is the highest-scoring of
    # the n=4 beams, and generation stops at the few-shot delimiter.
    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct/",
            "prompt": messages[0]["content"],
            "max_tokens": 256,
            "use_beam_search": True,
            "n": 4,
            "temperature": 0,
            "stop": ["========="],
        },
    ).json()
    return response["choices"][0]["text"]
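
# Illustrative usage: the chat-style message list is flattened, so only the
# first message's content is sent as the completion prompt:
#
#     answer = get_answer([{"role": "user", "content": prompt}])
#     # -> completion text of the top beam, cut at the "=========" stop string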

data = json.load(open('./bird_validator_select.json'))
output_file = './bird_fixed_sql.json'

for isample in tqdm(range(len(data)), total=len(data)):
    sample = data[isample]

    sql = sample['predict_sql']
    # Only attempt a fix when the validator produced feedback and did not
    # conclude that the query is already correct.
    if sample['validator_select'] is None or "Conclude: correct" in sample['validator_select']:
        continue

    prompt = PROMPT.format(
        schema=sample['schema_sequence'],
        matched_content=sample['content_sequence'],
        question=sample['text'],
        sql_query=sql,
        feedback=sample['validator_select']
    )
    answer = get_answer([{"role": "user", "content": prompt}])

    execution_result = _execute_sql("../" + sample['db_path'], answer)

    print("-" * 20)
    print(answer)
    sample['fixed_sql'] = answer
    sample['fixed_pred_result'] = _make_str_response(*execution_result)

    # Checkpoint the progress so far after every sample.
    json.dump(data[:isample + 1], open(output_file, 'w'), ensure_ascii=False, indent=4)

json.dump(data, open(output_file, 'w'), ensure_ascii=False, indent=4)

bird_results_dict = dict()
for idx, sample in enumerate(data):
    # Prefer the regenerated SQL when a fix exists for this sample.
    if 'fixed_sql' in sample:
        predicted_sql = sample['fixed_sql']
    else:
        predicted_sql = sample['predict_sql']
    bird_results_dict[idx] = predicted_sql + "\t----- bird -----\t" + sample["db_id"]
with open("predict_dev.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(bird_results_dict, indent=2, ensure_ascii=False))
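
# The resulting predict_dev.json maps sample indices to strings in the format
# the BIRD evaluation scripts expect; the SQL and db_id below are illustrative:
#
#     {
#         "0": "SELECT ... FROM ...\t----- bird -----\tcalifornia_schools"
#     }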
--------------------------------------------------------------------------------