├── README.md ├── scripts ├── eval_cosql_scprompt.sh ├── eval_geoquery_scprompt.sh ├── eval_spider_scprompt.sh ├── train_cosql_scprompt.sh ├── train_geoquery_scprompt.sh └── train_spider_scprompt.sh └── src ├── datasets ├── cosql │ └── cosql.py ├── geoquery │ └── geoquery.py └── spider │ └── spider.py ├── metrics ├── cosql │ ├── cosql.py │ ├── spider_exact_match.py │ └── spider_test_suite.py └── spider │ ├── spider.py │ ├── spider_exact_match.py │ └── spider_test_suite.py ├── run.py └── utils ├── PT_wrapper.py ├── __init__.py ├── args.py ├── bridge_content_encoder.py ├── cosql.py ├── dataset.py ├── dataset_loader.py ├── decode_wrapper.py ├── evaluation.py ├── geoquery.py ├── get_tables.py ├── process_sql.py ├── spider.py └── trainer.py /README.md: -------------------------------------------------------------------------------- 1 | # SC-prompt 2 | ## Introduction 3 | This repository contains the code for the paper "Few-shot Text-to-SQL Translation using Structure and Content Prompt Learning". In this paper, we propose SC-Prompt, a novel divide-and-conquer strategy for effectively supporting Text-to-SQL translation in the few-shot scenario. 4 | 5 | ## Setup 6 | ```sh 7 | git clone git@github.com:ruc-datalab/SC-prompt.git 8 | cd SC-prompt 9 | mkdir -p -m 777 experimental_outputs 10 | mkdir -p -m 777 transformers_cache 11 | cd experimental_outputs 12 | mkdir -p -m 777 spider 13 | mkdir -p -m 777 cosql 14 | mkdir -p -m 777 geoquery 15 | cd .. 16 | ``` 17 | 18 | ## Dataset Download 19 | 20 | - [Spider](https://drive.google.com/uc?export=download&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0): Put it under `src/datasets/spider`. 21 | - [CoSQL](https://drive.google.com/uc?export=download&id=14x6lsWqlu6gR-aYxa6cemslDN3qT3zxP): Put it under `src/datasets/cosql`. 22 | - [Geoquery](https://drive.google.com/file/d/1hP4gpExG1EJCN3a1vOyK4XR4mTSFi7Q1/view?usp=share_link): Put it under `src/datasets/geoquery`. 23 | 24 | ## Code Structure 25 | 26 | ```sh 27 | |-- experimental_outputs # saves the fine-tuned models and evaluation results 28 | |-- scripts # the training/inference scripts 29 | |-- src 30 | |-- datasets # the classes that preprocess the datasets 31 | |-- metrics # the classes that evaluate the prediction results 32 | |-- utils # main code 33 | |-- run.py # the entry point to train/evaluate the few-shot text-to-SQL model 34 | ``` 35 | 36 | ## Environment 37 | Our constrained decoding method is based on the parser provided by [Picard](https://arxiv.org/abs/2109.05093). Please use the Docker image provided by the official [repository](https://github.com/ServiceNow/picard) to build the container. 38 | 39 | ```sh 40 | docker run -itd --gpus '"device=<device_id>"' --rm --user 13011:13011 --mount type=bind,source=<repo_path>/transformers_cache,target=/transformers_cache --mount type=bind,source=<repo_path>/scripts,target=/app/scripts --mount type=bind,source=<repo_path>/experimental_outputs,target=/app/experimental_outputs --mount type=bind,source=<repo_path>/src,target=/app/src tscholak/text-to-sql-eval:6a252386bed6d4233f0f13f4562d8ae8608e7445 41 | ``` 42 | You should set `<device_id>` (the GPU to use) and `<repo_path>` (the absolute path to your local SC-prompt checkout). 43 | 44 | ## Quick Inference 45 | 46 | Download the fine-tuned models and put them under the corresponding folders listed below.
47 | 48 | | Dataset | Train ratio (#examples) | Model | Folder | 49 | |-------|--------|--------|---------| 50 | | Spider | 0.05 (350) | [link](https://drive.google.com/drive/folders/1b-16LFsnVMC5U2JxRew9nKtdOIhVr46j?usp=share_link) | experimental_outputs/spider/ | 51 | | Spider | 0.1 (700) | [link](https://drive.google.com/drive/folders/16qcI-zcahpB-Y6BUyizLmt3-EMP8_sM7?usp=share_link) | experimental_outputs/spider/ | 52 | | CoSQL | 0.05 (475) | [link](https://drive.google.com/drive/folders/1DxNdW5oBMQgYm7GE_VfvT9lFrJLcCpLs?usp=share_link) | experimental_outputs/cosql/ | 53 | | CoSQL | 0.1 (950) | [link](https://drive.google.com/drive/folders/1MhbsPsyhD0RTVYFJ7jiqy8zxxUo2_4kp?usp=share_link) | experimental_outputs/cosql/ | 54 | | Geoquery | 1.0 (536) | [link](https://drive.google.com/drive/folders/1Z-akKlTFhiNGdT23kmpU8VFQ3L5XvOgD?usp=share_link) | experimental_outputs/geoquery/ | 55 | 56 | Use the scripts below to run inference. 57 | ```sh 58 | # Inference on spider 59 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval_spider_scprompt.sh 0.1 60 | # Inference on cosql 61 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval_cosql_scprompt.sh 0.1 62 | # Inference on geoquery 63 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval_geoquery_scprompt.sh 1. 64 | ``` 65 | - The numeric argument (`$1`, e.g. `0.1`) specifies the proportion of the official training set to use. 66 | 67 | ## Train from scratch 68 | ```sh 69 | # Train on spider 70 | CUDA_VISIBLE_DEVICES=0 bash scripts/train_spider_scprompt.sh 0.1 71 | # Train on cosql 72 | CUDA_VISIBLE_DEVICES=0 bash scripts/train_cosql_scprompt.sh 0.1 73 | # Train on geoquery 74 | CUDA_VISIBLE_DEVICES=0 bash scripts/train_geoquery_scprompt.sh 1. 75 | ``` 76 | - The numeric argument (`$1`, e.g. `0.1`) specifies the proportion of the official training set to use. 77 | 78 | The best model will be automatically saved under `experimental_outputs/`. Please note that training does not use the fine-grained constrained decoding strategy, which is only necessary for evaluation. Please refer to `Quick Inference` to evaluate the fine-tuned model. 79 | 80 | -------------------------------------------------------------------------------- /scripts/eval_cosql_scprompt.sh: -------------------------------------------------------------------------------- 1 | # 1.
predict SQL structure 2 | python src/run.py \ 3 | --run_name t5-large \ 4 | --model_name_or_path t5-large \ 5 | --dataset cosql \ 6 | --source_prefix "question: " \ 7 | --schema_serialization_type verbose \ 8 | --schema_serialization_randomized false \ 9 | --schema_serialization_with_db_id true \ 10 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 11 | --schema_serialization_with_db_content true \ 12 | --normalize_query true \ 13 | --target_with_db_id false \ 14 | --metric_config both \ 15 | --output_dir experimental_outputs/cosql/ \ 16 | --cache_dir transformers_cache \ 17 | --do_train false \ 18 | --do_eval true \ 19 | --fp16 false \ 20 | --per_device_eval_batch_size 2 \ 21 | --label_smoothing_factor 0.0 \ 22 | --learning_rate 5e-5 \ 23 | --adafactor true \ 24 | --adam_eps 1e-6 \ 25 | --lr_scheduler_type constant \ 26 | --warmup_ratio 0.0 \ 27 | --warmup_steps 0 \ 28 | --seed 1 \ 29 | --logging_strategy steps \ 30 | --logging_steps 4 \ 31 | --metric_for_best_model exact_match \ 32 | --greater_is_better true \ 33 | --save_strategy steps \ 34 | --evaluation_strategy steps \ 35 | --predict_with_generate true \ 36 | --num_beams 8 \ 37 | --num_beam_groups 1 \ 38 | --diversity_penalty 0.0 \ 39 | --max_val_samples 1300 \ 40 | --use_constrained_decoding false \ 41 | --use_decomposition true \ 42 | --overwrite_output_dir true \ 43 | --stage structure \ 44 | --training_method PFT \ 45 | --overwrite_cache true \ 46 | --train_samples_ratio $1 47 | 48 | # 2. predict SQL content 49 | python src/run.py \ 50 | --run_name t5-large \ 51 | --model_name_or_path t5-large \ 52 | --dataset cosql \ 53 | --source_prefix "question: " \ 54 | --schema_serialization_type verbose \ 55 | --schema_serialization_randomized false \ 56 | --schema_serialization_with_db_id true \ 57 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 58 | --schema_serialization_with_db_content true \ 59 | --normalize_query true \ 60 | --target_with_db_id false \ 61 | --metric_config both \ 62 | --output_dir experimental_outputs/cosql/ \ 63 | --cache_dir transformers_cache \ 64 | --do_train false \ 65 | --do_eval true \ 66 | --fp16 false \ 67 | --per_device_eval_batch_size 2 \ 68 | --label_smoothing_factor 0.0 \ 69 | --learning_rate 5e-5 \ 70 | --adafactor true \ 71 | --adam_eps 1e-6 \ 72 | --lr_scheduler_type constant \ 73 | --warmup_ratio 0.0 \ 74 | --warmup_steps 0 \ 75 | --seed 1 \ 76 | --logging_strategy steps \ 77 | --logging_steps 4 \ 78 | --metric_for_best_model exact_match \ 79 | --greater_is_better true \ 80 | --save_strategy steps \ 81 | --evaluation_strategy steps \ 82 | --predict_with_generate true \ 83 | --num_beams 4 \ 84 | --num_beam_groups 1 \ 85 | --diversity_penalty 0.0 \ 86 | --max_val_samples 1300 \ 87 | --use_constrained_decoding false \ 88 | --use_decomposition true \ 89 | --overwrite_output_dir true \ 90 | --stage content \ 91 | --training_method PFT \ 92 | --overwrite_cache true \ 93 | --train_samples_ratio $1 -------------------------------------------------------------------------------- /scripts/eval_geoquery_scprompt.sh: -------------------------------------------------------------------------------- 1 | # 1. 
predict SQL structure 2 | python src/run.py \ 3 | --run_name t5-large \ 4 | --model_name_or_path t5-large \ 5 | --dataset geoquery \ 6 | --source_prefix "question: " \ 7 | --schema_serialization_type verbose \ 8 | --schema_serialization_randomized false \ 9 | --schema_serialization_with_db_id true \ 10 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 11 | --schema_serialization_with_db_content true \ 12 | --normalize_query true \ 13 | --target_with_db_id false \ 14 | --metric_config both \ 15 | --output_dir experimental_outputs/geoquery/ \ 16 | --cache_dir transformers_cache \ 17 | --do_train false \ 18 | --do_eval true \ 19 | --fp16 false \ 20 | --per_device_eval_batch_size 2 \ 21 | --label_smoothing_factor 0.0 \ 22 | --learning_rate 5e-5 \ 23 | --adafactor true \ 24 | --adam_eps 1e-6 \ 25 | --lr_scheduler_type constant \ 26 | --warmup_ratio 0.0 \ 27 | --warmup_steps 0 \ 28 | --seed 1 \ 29 | --logging_strategy steps \ 30 | --logging_steps 4 \ 31 | --metric_for_best_model exact_match \ 32 | --greater_is_better true \ 33 | --save_strategy steps \ 34 | --evaluation_strategy steps \ 35 | --predict_with_generate true \ 36 | --num_beams 8 \ 37 | --num_beam_groups 1 \ 38 | --diversity_penalty 0.0 \ 39 | --max_val_samples 182 \ 40 | --use_constrained_decoding false \ 41 | --use_decomposition true \ 42 | --overwrite_output_dir true \ 43 | --stage structure \ 44 | --training_method PFT \ 45 | --overwrite_cache true \ 46 | --train_samples_ratio $1 47 | 48 | # 2. predict SQL content 49 | python src/run.py \ 50 | --run_name t5-large \ 51 | --model_name_or_path t5-large \ 52 | --dataset geoquery \ 53 | --source_prefix "question: " \ 54 | --schema_serialization_type verbose \ 55 | --schema_serialization_randomized false \ 56 | --schema_serialization_with_db_id true \ 57 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 58 | --schema_serialization_with_db_content true \ 59 | --normalize_query true \ 60 | --target_with_db_id false \ 61 | --metric_config both \ 62 | --output_dir experimental_outputs/geoquery/ \ 63 | --cache_dir transformers_cache \ 64 | --do_train false \ 65 | --do_eval true \ 66 | --fp16 false \ 67 | --per_device_eval_batch_size 2 \ 68 | --label_smoothing_factor 0.0 \ 69 | --learning_rate 5e-5 \ 70 | --adafactor true \ 71 | --adam_eps 1e-6 \ 72 | --lr_scheduler_type constant \ 73 | --warmup_ratio 0.0 \ 74 | --warmup_steps 0 \ 75 | --seed 1 \ 76 | --logging_strategy steps \ 77 | --logging_steps 4 \ 78 | --metric_for_best_model exact_match \ 79 | --greater_is_better true \ 80 | --save_strategy steps \ 81 | --evaluation_strategy steps \ 82 | --predict_with_generate true \ 83 | --num_beams 4 \ 84 | --num_beam_groups 1 \ 85 | --diversity_penalty 0.0 \ 86 | --max_val_samples 182 \ 87 | --use_constrained_decoding false \ 88 | --use_decomposition true \ 89 | --overwrite_output_dir true \ 90 | --stage content \ 91 | --training_method PFT \ 92 | --overwrite_cache true \ 93 | --train_samples_ratio $1 -------------------------------------------------------------------------------- /scripts/eval_spider_scprompt.sh: -------------------------------------------------------------------------------- 1 | # 1. 
predict SQL structure 2 | python src/run.py \ 3 | --run_name t5-large \ 4 | --model_name_or_path t5-large \ 5 | --dataset spider \ 6 | --source_prefix "question: " \ 7 | --schema_serialization_type verbose \ 8 | --schema_serialization_randomized false \ 9 | --schema_serialization_with_db_id true \ 10 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 11 | --schema_serialization_with_db_content true \ 12 | --normalize_query true \ 13 | --target_with_db_id false \ 14 | --metric_config both \ 15 | --output_dir experimental_outputs/spider/ \ 16 | --cache_dir transformers_cache \ 17 | --do_train false \ 18 | --do_eval true \ 19 | --fp16 false \ 20 | --per_device_eval_batch_size 2 \ 21 | --label_smoothing_factor 0.0 \ 22 | --learning_rate 5e-5 \ 23 | --adafactor true \ 24 | --adam_eps 1e-6 \ 25 | --lr_scheduler_type constant \ 26 | --warmup_ratio 0.0 \ 27 | --warmup_steps 0 \ 28 | --seed 1 \ 29 | --logging_strategy steps \ 30 | --logging_steps 4 \ 31 | --metric_for_best_model exact_match \ 32 | --greater_is_better true \ 33 | --save_strategy steps \ 34 | --evaluation_strategy steps \ 35 | --predict_with_generate true \ 36 | --num_beams 8 \ 37 | --num_beam_groups 1 \ 38 | --diversity_penalty 0.0 \ 39 | --max_val_samples 1034 \ 40 | --use_constrained_decoding true \ 41 | --use_decomposition true \ 42 | --overwrite_output_dir true \ 43 | --stage structure \ 44 | --training_method PFT \ 45 | --overwrite_cache true \ 46 | --train_samples_ratio $1 47 | 48 | # 2. predict SQL content 49 | python src/run.py \ 50 | --run_name t5-large \ 51 | --model_name_or_path t5-large \ 52 | --dataset spider \ 53 | --source_prefix "question: " \ 54 | --schema_serialization_type verbose \ 55 | --schema_serialization_randomized false \ 56 | --schema_serialization_with_db_id true \ 57 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 58 | --schema_serialization_with_db_content true \ 59 | --normalize_query true \ 60 | --target_with_db_id false \ 61 | --metric_config both \ 62 | --output_dir experimental_outputs/spider/ \ 63 | --cache_dir transformers_cache \ 64 | --do_train false \ 65 | --do_eval true \ 66 | --fp16 false \ 67 | --per_device_eval_batch_size 2 \ 68 | --label_smoothing_factor 0.0 \ 69 | --learning_rate 5e-5 \ 70 | --adafactor true \ 71 | --adam_eps 1e-6 \ 72 | --lr_scheduler_type constant \ 73 | --warmup_ratio 0.0 \ 74 | --warmup_steps 0 \ 75 | --seed 1 \ 76 | --logging_strategy steps \ 77 | --logging_steps 4 \ 78 | --metric_for_best_model exact_match \ 79 | --greater_is_better true \ 80 | --save_strategy steps \ 81 | --evaluation_strategy steps \ 82 | --predict_with_generate true \ 83 | --num_beams 4 \ 84 | --num_beam_groups 1 \ 85 | --diversity_penalty 0.0 \ 86 | --max_val_samples 1034 \ 87 | --use_constrained_decoding true \ 88 | --use_decomposition true \ 89 | --overwrite_output_dir true \ 90 | --stage content \ 91 | --training_method PFT \ 92 | --overwrite_cache true \ 93 | --train_samples_ratio $1 94 | -------------------------------------------------------------------------------- /scripts/train_cosql_scprompt.sh: -------------------------------------------------------------------------------- 1 | # 1. structure stage 2 | # Note: To make the training process more stable, we first freeze the model to train learnable vectors. 
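# The first run (--training_method PT) trains only the learnable prompt vectors with the T5 backbone frozen, hence the large learning rate (0.1); the second run (--training_method PFT) presumably unfreezes the model and tunes it jointly with the prompts at 5e-5 (see src/utils/PT_wrapper.py and src/run.py for the exact semantics).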
3 | python src/run.py \ 4 | --run_name t5-large \ 5 | --model_name_or_path t5-large \ 6 | --dataset cosql \ 7 | --source_prefix "question: " \ 8 | --schema_serialization_type verbose \ 9 | --schema_serialization_randomized false \ 10 | --schema_serialization_with_db_id true \ 11 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 12 | --schema_serialization_with_db_content true \ 13 | --normalize_query true \ 14 | --target_with_db_id false \ 15 | --metric_config both \ 16 | --output_dir experimental_outputs/cosql/ \ 17 | --cache_dir transformers_cache \ 18 | --do_train true \ 19 | --do_eval false \ 20 | --fp16 false \ 21 | --num_train_epochs 100 \ 22 | --per_device_train_batch_size 2 \ 23 | --per_device_eval_batch_size 4 \ 24 | --gradient_accumulation_steps 16 \ 25 | --label_smoothing_factor 0.0 \ 26 | --learning_rate 0.1 \ 27 | --adafactor true \ 28 | --adam_eps 1e-6 \ 29 | --lr_scheduler_type constant \ 30 | --warmup_ratio 0.0 \ 31 | --warmup_steps 0 \ 32 | --seed 1 \ 33 | --logging_strategy steps \ 34 | --logging_steps 4 \ 35 | --metric_for_best_model exact_match \ 36 | --greater_is_better true \ 37 | --save_strategy steps \ 38 | --evaluation_strategy steps \ 39 | --predict_with_generate true \ 40 | --num_beams 1 \ 41 | --num_beam_groups 1 \ 42 | --use_constrained_decoding false \ 43 | --use_decomposition true \ 44 | --overwrite_output_dir true \ 45 | --stage structure \ 46 | --training_method PT \ 47 | --overwrite_cache true \ 48 | --train_samples_ratio $1 49 | 50 | python src/run.py \ 51 | --run_name t5-large \ 52 | --model_name_or_path t5-large \ 53 | --dataset cosql \ 54 | --source_prefix "question: " \ 55 | --schema_serialization_type verbose \ 56 | --schema_serialization_randomized false \ 57 | --schema_serialization_with_db_id true \ 58 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 59 | --schema_serialization_with_db_content true \ 60 | --normalize_query true \ 61 | --target_with_db_id false \ 62 | --metric_config both \ 63 | --output_dir experimental_outputs/cosql/ \ 64 | --cache_dir transformers_cache \ 65 | --do_train true \ 66 | --do_eval false \ 67 | --fp16 false \ 68 | --num_train_epochs 150 \ 69 | --per_device_train_batch_size 2 \ 70 | --per_device_eval_batch_size 2 \ 71 | --gradient_accumulation_steps 16 \ 72 | --label_smoothing_factor 0.0 \ 73 | --learning_rate 5e-5 \ 74 | --adafactor true \ 75 | --adam_eps 1e-6 \ 76 | --lr_scheduler_type constant \ 77 | --warmup_ratio 0.0 \ 78 | --warmup_steps 0 \ 79 | --seed 1 \ 80 | --logging_strategy steps \ 81 | --logging_steps 4 \ 82 | --metric_for_best_model exact_match \ 83 | --greater_is_better true \ 84 | --save_strategy steps \ 85 | --evaluation_strategy steps \ 86 | --predict_with_generate true \ 87 | --num_beams 1 \ 88 | --num_beam_groups 1 \ 89 | --use_constrained_decoding false \ 90 | --use_decomposition true \ 91 | --overwrite_output_dir true \ 92 | --stage structure \ 93 | --training_method PFT \ 94 | --overwrite_cache true \ 95 | --train_samples_ratio $1 96 | 97 | # content stage 98 | python src/run.py \ 99 | --run_name t5-large \ 100 | --model_name_or_path t5-large \ 101 | --dataset cosql \ 102 | --source_prefix "question: " \ 103 | --schema_serialization_type verbose \ 104 | --schema_serialization_randomized false \ 105 | --schema_serialization_with_db_id true \ 106 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 107 | --schema_serialization_with_db_content true \ 
108 | --normalize_query true \ 109 | --target_with_db_id false \ 110 | --metric_config both \ 111 | --output_dir experimental_outputs/cosql/ \ 112 | --cache_dir transformers_cache \ 113 | --do_train true \ 114 | --do_eval false \ 115 | --fp16 false \ 116 | --num_train_epochs 900 \ 117 | --per_device_train_batch_size 2 \ 118 | --per_device_eval_batch_size 4 \ 119 | --gradient_accumulation_steps 16 \ 120 | --label_smoothing_factor 0.0 \ 121 | --learning_rate 0.1 \ 122 | --adafactor true \ 123 | --adam_eps 1e-6 \ 124 | --lr_scheduler_type constant \ 125 | --warmup_ratio 0.0 \ 126 | --warmup_steps 0 \ 127 | --seed 1 \ 128 | --logging_strategy steps \ 129 | --logging_steps 4 \ 130 | --metric_for_best_model exact_match \ 131 | --greater_is_better true \ 132 | --save_strategy steps \ 133 | --evaluation_strategy steps \ 134 | --predict_with_generate true \ 135 | --num_beams 1 \ 136 | --num_beam_groups 1 \ 137 | --use_constrained_decoding false \ 138 | --use_decomposition true \ 139 | --overwrite_output_dir true \ 140 | --stage content \ 141 | --training_method PT \ 142 | --overwrite_cache true \ 143 | --train_samples_ratio $1 144 | 145 | python src/run.py \ 146 | --run_name t5-large \ 147 | --model_name_or_path t5-large \ 148 | --dataset cosql \ 149 | --source_prefix "question: " \ 150 | --schema_serialization_type verbose \ 151 | --schema_serialization_randomized false \ 152 | --schema_serialization_with_db_id true \ 153 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 154 | --schema_serialization_with_db_content true \ 155 | --normalize_query true \ 156 | --target_with_db_id false \ 157 | --metric_config both \ 158 | --output_dir experimental_outputs/cosql/ \ 159 | --cache_dir transformers_cache \ 160 | --do_train true \ 161 | --do_eval false \ 162 | --fp16 false \ 163 | --num_train_epochs 100 \ 164 | --per_device_train_batch_size 2 \ 165 | --per_device_eval_batch_size 4 \ 166 | --gradient_accumulation_steps 16 \ 167 | --label_smoothing_factor 0.0 \ 168 | --learning_rate 5e-5 \ 169 | --adafactor true \ 170 | --adam_eps 1e-6 \ 171 | --lr_scheduler_type constant \ 172 | --warmup_ratio 0.0 \ 173 | --warmup_steps 0 \ 174 | --seed 1 \ 175 | --logging_strategy steps \ 176 | --logging_steps 4 \ 177 | --metric_for_best_model exact_match \ 178 | --greater_is_better true \ 179 | --save_strategy steps \ 180 | --evaluation_strategy steps \ 181 | --predict_with_generate true \ 182 | --num_beams 1 \ 183 | --num_beam_groups 1 \ 184 | --use_constrained_decoding false \ 185 | --use_decomposition true \ 186 | --overwrite_output_dir true \ 187 | --stage content \ 188 | --training_method PFT \ 189 | --overwrite_cache true \ 190 | --train_samples_ratio $1 191 | 192 | -------------------------------------------------------------------------------- /scripts/train_geoquery_scprompt.sh: -------------------------------------------------------------------------------- 1 | # 1. structure stage 2 | # Note: To make the training process more stable, we first freeze the model to train learnable vectors. 
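# Per the README, this script is invoked as: bash scripts/train_geoquery_scprompt.sh 1. (ratio 1.0, i.e. the full 536-example Geoquery training set).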
3 | 4 | python src/run.py \ 5 | --run_name t5-large \ 6 | --model_name_or_path t5-large \ 7 | --dataset geoquery \ 8 | --source_prefix "question: " \ 9 | --schema_serialization_type verbose \ 10 | --schema_serialization_randomized false \ 11 | --schema_serialization_with_db_id true \ 12 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 13 | --schema_serialization_with_db_content true \ 14 | --normalize_query true \ 15 | --target_with_db_id false \ 16 | --metric_config both \ 17 | --output_dir experimental_outputs/geoquery/ \ 18 | --cache_dir transformers_cache \ 19 | --do_train true \ 20 | --do_eval false \ 21 | --fp16 false \ 22 | --num_train_epochs 10 \ 23 | --per_device_train_batch_size 2 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 16 \ 26 | --label_smoothing_factor 0.0 \ 27 | --learning_rate 0.1 \ 28 | --adafactor true \ 29 | --adam_eps 1e-6 \ 30 | --lr_scheduler_type constant \ 31 | --warmup_ratio 0.0 \ 32 | --warmup_steps 0 \ 33 | --seed 1 \ 34 | --logging_strategy steps \ 35 | --logging_steps 4 \ 36 | --metric_for_best_model exact_match \ 37 | --greater_is_better true \ 38 | --save_strategy steps \ 39 | --evaluation_strategy steps \ 40 | --predict_with_generate true \ 41 | --num_beams 1 \ 42 | --num_beam_groups 1 \ 43 | --use_constrained_decoding false \ 44 | --use_decomposition true \ 45 | --overwrite_output_dir true \ 46 | --stage structure \ 47 | --training_method PT \ 48 | --overwrite_cache true \ 49 | --train_samples_ratio $1 50 | 51 | python src/run.py \ 52 | --run_name t5-large \ 53 | --model_name_or_path t5-large \ 54 | --dataset geoquery \ 55 | --source_prefix "question: " \ 56 | --schema_serialization_type verbose \ 57 | --schema_serialization_randomized false \ 58 | --schema_serialization_with_db_id true \ 59 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 60 | --schema_serialization_with_db_content true \ 61 | --normalize_query true \ 62 | --target_with_db_id false \ 63 | --metric_config both \ 64 | --output_dir experimental_outputs/geoquery/ \ 65 | --cache_dir transformers_cache \ 66 | --do_train true \ 67 | --do_eval false \ 68 | --fp16 false \ 69 | --num_train_epochs 100 \ 70 | --per_device_train_batch_size 4 \ 71 | --per_device_eval_batch_size 4 \ 72 | --gradient_accumulation_steps 8 \ 73 | --label_smoothing_factor 0.0 \ 74 | --learning_rate 5e-5 \ 75 | --adafactor true \ 76 | --adam_eps 1e-6 \ 77 | --lr_scheduler_type constant \ 78 | --warmup_ratio 0.0 \ 79 | --warmup_steps 0 \ 80 | --seed 1 \ 81 | --logging_strategy steps \ 82 | --logging_steps 4 \ 83 | --metric_for_best_model exact_match \ 84 | --greater_is_better true \ 85 | --save_strategy steps \ 86 | --evaluation_strategy steps \ 87 | --predict_with_generate true \ 88 | --num_beams 1 \ 89 | --num_beam_groups 1 \ 90 | --use_constrained_decoding false \ 91 | --use_decomposition true \ 92 | --overwrite_output_dir true \ 93 | --stage structure \ 94 | --training_method PFT \ 95 | --overwrite_cache true \ 96 | --train_samples_ratio $1 97 | 98 | # 2. content stage 99 | # Note: To make the training process more stable, we first freeze the model to train learnable vectors. 
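# The prompt-only (PT) run below trains the content-stage vectors for much longer (900 epochs at learning rate 0.1) before the joint PFT run that follows.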
100 | python src/run.py \ 101 | --run_name t5-large \ 102 | --model_name_or_path t5-large \ 103 | --dataset geoquery \ 104 | --source_prefix "question: " \ 105 | --schema_serialization_type verbose \ 106 | --schema_serialization_randomized false \ 107 | --schema_serialization_with_db_id true \ 108 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 109 | --schema_serialization_with_db_content true \ 110 | --normalize_query true \ 111 | --target_with_db_id false \ 112 | --metric_config both \ 113 | --output_dir experimental_outputs/geoquery/ \ 114 | --cache_dir transformers_cache \ 115 | --do_train true \ 116 | --do_eval false \ 117 | --fp16 false \ 118 | --num_train_epochs 900 \ 119 | --per_device_train_batch_size 2 \ 120 | --per_device_eval_batch_size 4 \ 121 | --gradient_accumulation_steps 16 \ 122 | --label_smoothing_factor 0.0 \ 123 | --learning_rate 0.1 \ 124 | --adafactor true \ 125 | --adam_eps 1e-6 \ 126 | --lr_scheduler_type constant \ 127 | --warmup_ratio 0.0 \ 128 | --warmup_steps 0 \ 129 | --seed 1 \ 130 | --logging_strategy steps \ 131 | --logging_steps 4 \ 132 | --metric_for_best_model exact_match \ 133 | --greater_is_better true \ 134 | --save_strategy steps \ 135 | --evaluation_strategy steps \ 136 | --predict_with_generate true \ 137 | --num_beams 1 \ 138 | --num_beam_groups 1 \ 139 | --use_constrained_decoding false \ 140 | --use_decomposition true \ 141 | --overwrite_output_dir true \ 142 | --stage content \ 143 | --training_method PT \ 144 | --overwrite_cache true \ 145 | --train_samples_ratio $1 146 | 147 | python src/run.py \ 148 | --run_name t5-large \ 149 | --model_name_or_path t5-large \ 150 | --dataset geoquery \ 151 | --source_prefix "question: " \ 152 | --schema_serialization_type verbose \ 153 | --schema_serialization_randomized false \ 154 | --schema_serialization_with_db_id true \ 155 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 156 | --schema_serialization_with_db_content true \ 157 | --normalize_query true \ 158 | --target_with_db_id false \ 159 | --metric_config both \ 160 | --output_dir experimental_outputs/geoquery/ \ 161 | --cache_dir transformers_cache \ 162 | --do_train true \ 163 | --do_eval false \ 164 | --fp16 false \ 165 | --num_train_epochs 150 \ 166 | --per_device_train_batch_size 2 \ 167 | --per_device_eval_batch_size 4 \ 168 | --gradient_accumulation_steps 16 \ 169 | --label_smoothing_factor 0.0 \ 170 | --learning_rate 5e-5 \ 171 | --adafactor true \ 172 | --adam_eps 1e-6 \ 173 | --lr_scheduler_type constant \ 174 | --warmup_ratio 0.0 \ 175 | --warmup_steps 0 \ 176 | --seed 1 \ 177 | --logging_strategy steps \ 178 | --logging_steps 4 \ 179 | --metric_for_best_model exact_match \ 180 | --greater_is_better true \ 181 | --save_strategy steps \ 182 | --evaluation_strategy steps \ 183 | --predict_with_generate true \ 184 | --num_beams 1 \ 185 | --num_beam_groups 1 \ 186 | --use_constrained_decoding false \ 187 | --use_decomposition true \ 188 | --overwrite_output_dir true \ 189 | --stage content \ 190 | --training_method PFT \ 191 | --overwrite_cache true \ 192 | --train_samples_ratio $1 -------------------------------------------------------------------------------- /scripts/train_spider_scprompt.sh: -------------------------------------------------------------------------------- 1 | # 1. structure stage 2 | # Note: To make the training process more stable, we first freeze the model to train learnable vectors. 
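# The effective batch size is 32 in both runs below: per-device batch 2 with 16 gradient-accumulation steps for PT, and 4 with 8 steps for PFT.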
3 | python src/run.py \ 4 | --run_name t5-large \ 5 | --model_name_or_path t5-large \ 6 | --dataset spider \ 7 | --source_prefix "question: " \ 8 | --schema_serialization_type verbose \ 9 | --schema_serialization_randomized false \ 10 | --schema_serialization_with_db_id true \ 11 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 12 | --schema_serialization_with_db_content true \ 13 | --normalize_query true \ 14 | --target_with_db_id false \ 15 | --metric_config both \ 16 | --output_dir experimental_outputs/spider/ \ 17 | --cache_dir transformers_cache \ 18 | --do_train true \ 19 | --do_eval false \ 20 | --fp16 false \ 21 | --num_train_epochs 100 \ 22 | --per_device_train_batch_size 2 \ 23 | --per_device_eval_batch_size 2 \ 24 | --gradient_accumulation_steps 16 \ 25 | --label_smoothing_factor 0.0 \ 26 | --learning_rate 0.1 \ 27 | --adafactor true \ 28 | --adam_eps 1e-6 \ 29 | --lr_scheduler_type constant \ 30 | --warmup_ratio 0.0 \ 31 | --warmup_steps 0 \ 32 | --seed 1 \ 33 | --logging_strategy steps \ 34 | --logging_steps 4 \ 35 | --metric_for_best_model exact_match \ 36 | --greater_is_better true \ 37 | --save_strategy steps \ 38 | --evaluation_strategy steps \ 39 | --predict_with_generate true \ 40 | --num_beams 1 \ 41 | --num_beam_groups 1 \ 42 | --use_constrained_decoding false \ 43 | --use_decomposition true \ 44 | --overwrite_output_dir true \ 45 | --stage structure \ 46 | --training_method PT \ 47 | --overwrite_cache true \ 48 | --train_samples_ratio $1 49 | 50 | python src/run.py \ 51 | --run_name t5-large \ 52 | --model_name_or_path t5-large \ 53 | --dataset spider \ 54 | --source_prefix "question: " \ 55 | --schema_serialization_type verbose \ 56 | --schema_serialization_randomized false \ 57 | --schema_serialization_with_db_id true \ 58 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 59 | --schema_serialization_with_db_content true \ 60 | --normalize_query true \ 61 | --target_with_db_id false \ 62 | --metric_config both \ 63 | --output_dir experimental_outputs/spider/ \ 64 | --cache_dir transformers_cache \ 65 | --do_train true \ 66 | --do_eval false \ 67 | --fp16 false \ 68 | --num_train_epochs 100 \ 69 | --per_device_train_batch_size 4 \ 70 | --per_device_eval_batch_size 4 \ 71 | --gradient_accumulation_steps 8 \ 72 | --label_smoothing_factor 0.0 \ 73 | --learning_rate 5e-5 \ 74 | --adafactor true \ 75 | --adam_eps 1e-6 \ 76 | --lr_scheduler_type constant \ 77 | --warmup_ratio 0.0 \ 78 | --warmup_steps 0 \ 79 | --seed 1 \ 80 | --logging_strategy steps \ 81 | --logging_steps 4 \ 82 | --metric_for_best_model exact_match \ 83 | --greater_is_better true \ 84 | --save_strategy steps \ 85 | --evaluation_strategy steps \ 86 | --predict_with_generate true \ 87 | --num_beams 1 \ 88 | --num_beam_groups 1 \ 89 | --use_constrained_decoding false \ 90 | --use_decomposition true \ 91 | --overwrite_output_dir true \ 92 | --stage structure \ 93 | --training_method PFT \ 94 | --overwrite_cache true \ 95 | --train_samples_ratio $1 96 | 97 | # 2. content stage 98 | # Note: To make the training process more stable, we first freeze the model to train learnable vectors. 
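# Constrained decoding stays disabled during training; per the README it is only needed at evaluation time, where eval_spider_scprompt.sh enables it.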
99 | python src/run.py \ 100 | --run_name t5-large \ 101 | --model_name_or_path t5-large \ 102 | --dataset spider \ 103 | --source_prefix "question: " \ 104 | --schema_serialization_type verbose \ 105 | --schema_serialization_randomized false \ 106 | --schema_serialization_with_db_id true \ 107 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 108 | --schema_serialization_with_db_content true \ 109 | --normalize_query true \ 110 | --target_with_db_id false \ 111 | --metric_config both \ 112 | --output_dir experimental_outputs/spider/ \ 113 | --cache_dir transformers_cache \ 114 | --do_train true \ 115 | --do_eval false \ 116 | --fp16 false \ 117 | --num_train_epochs 900 \ 118 | --per_device_train_batch_size 2 \ 119 | --per_device_eval_batch_size 4 \ 120 | --gradient_accumulation_steps 16 \ 121 | --label_smoothing_factor 0.0 \ 122 | --learning_rate 0.1 \ 123 | --adafactor true \ 124 | --adam_eps 1e-6 \ 125 | --lr_scheduler_type constant \ 126 | --warmup_ratio 0.0 \ 127 | --warmup_steps 0 \ 128 | --seed 1 \ 129 | --logging_strategy steps \ 130 | --logging_steps 4 \ 131 | --metric_for_best_model exact_match \ 132 | --greater_is_better true \ 133 | --save_strategy steps \ 134 | --evaluation_strategy steps \ 135 | --predict_with_generate true \ 136 | --num_beams 1 \ 137 | --num_beam_groups 1 \ 138 | --use_constrained_decoding false \ 139 | --use_decomposition true \ 140 | --overwrite_output_dir true \ 141 | --stage content \ 142 | --training_method PT \ 143 | --overwrite_cache true \ 144 | --train_samples_ratio $1 145 | 146 | python src/run.py \ 147 | --run_name t5-large \ 148 | --model_name_or_path t5-large \ 149 | --dataset spider \ 150 | --source_prefix "question: " \ 151 | --schema_serialization_type verbose \ 152 | --schema_serialization_randomized false \ 153 | --schema_serialization_with_db_id true \ 154 | --schema_serialization_with_prompt "Translate the question into sql according to the database: " \ 155 | --schema_serialization_with_db_content true \ 156 | --normalize_query true \ 157 | --target_with_db_id false \ 158 | --metric_config both \ 159 | --output_dir experimental_outputs/spider/ \ 160 | --cache_dir transformers_cache \ 161 | --do_train true \ 162 | --do_eval false \ 163 | --fp16 false \ 164 | --num_train_epochs 100 \ 165 | --per_device_train_batch_size 4 \ 166 | --per_device_eval_batch_size 4 \ 167 | --gradient_accumulation_steps 8 \ 168 | --label_smoothing_factor 0.0 \ 169 | --learning_rate 5e-5 \ 170 | --adafactor true \ 171 | --adam_eps 1e-6 \ 172 | --lr_scheduler_type constant \ 173 | --warmup_ratio 0.0 \ 174 | --warmup_steps 0 \ 175 | --seed 1 \ 176 | --logging_strategy steps \ 177 | --logging_steps 4 \ 178 | --metric_for_best_model exact_match \ 179 | --greater_is_better true \ 180 | --save_strategy steps \ 181 | --evaluation_strategy steps \ 182 | --predict_with_generate true \ 183 | --num_beams 1 \ 184 | --num_beam_groups 1 \ 185 | --use_constrained_decoding false \ 186 | --use_decomposition true \ 187 | --overwrite_output_dir true \ 188 | --stage content \ 189 | --training_method PFT \ 190 | --overwrite_cache true \ 191 | --train_samples_ratio $1 -------------------------------------------------------------------------------- /src/datasets/cosql/cosql.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Datasets Authors and the current dataset script contributor. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """CoSQL: A Conversational Text-to-SQL Challenge Towards Cross-Domain Natural Language Interfaces to Databases""" 16 | 17 | 18 | import json 19 | from third_party.spider.preprocess.get_tables import dump_db_json_schema 20 | import datasets 21 | 22 | 23 | logger = datasets.logging.get_logger(__name__) 24 | 25 | 26 | _CITATION = """\ 27 | @inproceedings{yu-etal-2019-cosql, 28 | title = "{C}o{SQL}: A Conversational Text-to-{SQL} Challenge Towards Cross-Domain Natural Language Interfaces to Databases", 29 | author = "Yu, Tao and 30 | Zhang, Rui and 31 | Er, Heyang and 32 | Li, Suyi and 33 | Xue, Eric and 34 | Pang, Bo and 35 | Lin, Xi Victoria and 36 | Tan, Yi Chern and 37 | Shi, Tianze and 38 | Li, Zihan and 39 | Jiang, Youxuan and 40 | Yasunaga, Michihiro and 41 | Shim, Sungrok and 42 | Chen, Tao and 43 | Fabbri, Alexander and 44 | Li, Zifan and 45 | Chen, Luyao and 46 | Zhang, Yuwen and 47 | Dixit, Shreya and 48 | Zhang, Vincent and 49 | Xiong, Caiming and 50 | Socher, Richard and 51 | Lasecki, Walter and 52 | Radev, Dragomir", 53 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 54 | month = nov, 55 | year = "2019", 56 | address = "Hong Kong, China", 57 | publisher = "Association for Computational Linguistics", 58 | url = "https://www.aclweb.org/anthology/D19-1204", 59 | doi = "10.18653/v1/D19-1204", 60 | pages = "1962--1979", 61 | abstract = "We present CoSQL, a corpus for building cross-domain, general-purpose database (DB) querying dialogue systems. It consists of 30k+ turns plus 10k+ annotated SQL queries, obtained from a Wizard-of-Oz (WOZ) collection of 3k dialogues querying 200 complex DBs spanning 138 domains. Each dialogue simulates a real-world DB query scenario with a crowd worker as a user exploring the DB and a SQL expert retrieving answers with SQL, clarifying ambiguous questions, or otherwise informing of unanswerable questions. When user questions are answerable by SQL, the expert describes the SQL and execution results to the user, hence maintaining a natural interaction flow. CoSQL introduces new challenges compared to existing task-oriented dialogue datasets: (1) the dialogue states are grounded in SQL, a domain-independent executable representation, instead of domain-specific slot value pairs, and (2) because testing is done on unseen databases, success requires generalizing to new domains. CoSQL includes three tasks: SQL-grounded dialogue state tracking, response generation from query results, and user dialogue act prediction. We evaluate a set of strong baselines for each task and show that CoSQL presents significant challenges for future research. 
The dataset, baselines, and leaderboard will be released at https://yale-lily.github.io/cosql.", 62 | } 63 | """ 64 | 65 | _DESCRIPTION = """\ 66 | CoSQL is a large-scale dataset for training and testing task oriented dialog agents with SQL 67 | """ 68 | 69 | _HOMEPAGE = "https://yale-lily.github.io/cosql" 70 | 71 | _LICENSE = "CC BY-SA 4.0" 72 | 73 | #_URL = "https://drive.google.com/uc?export=download&id=14x6lsWqlu6gR-aYxa6cemslDN3qT3zxP" 74 | _URL = "cosql_dataset.zip" 75 | 76 | class CoSQL(datasets.GeneratorBasedBuilder): 77 | VERSION = datasets.Version("1.0.0") 78 | 79 | BUILDER_CONFIGS = [ 80 | datasets.BuilderConfig( 81 | name="cosql", 82 | version=VERSION, 83 | description="A Conversational Text-to-SQL Challenge Towards Cross-Domain Natural Language Interfaces to Databases", 84 | ), 85 | ] 86 | 87 | def __init__(self, *args, writer_batch_size=None, **kwargs): 88 | super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs) 89 | self.schema_cache = dict() 90 | 91 | def _info(self): 92 | features = datasets.Features( 93 | { 94 | "query": datasets.Value("string"), 95 | "utterances": datasets.features.Sequence(datasets.Value("string")), 96 | "turn_idx": datasets.Value("int32"), 97 | "db_id": datasets.Value("string"), 98 | "db_path": datasets.Value("string"), 99 | "db_table_names": datasets.features.Sequence(datasets.Value("string")), 100 | "db_column_names": datasets.features.Sequence( 101 | { 102 | "table_id": datasets.Value("int32"), 103 | "column_name": datasets.Value("string"), 104 | } 105 | ), 106 | "db_column_types": datasets.features.Sequence(datasets.Value("string")), 107 | "db_primary_keys": datasets.features.Sequence({"column_id": datasets.Value("int32")}), 108 | "db_foreign_keys": datasets.features.Sequence( 109 | { 110 | "column_id": datasets.Value("int32"), 111 | "other_column_id": datasets.Value("int32"), 112 | } 113 | ), 114 | } 115 | ) 116 | return datasets.DatasetInfo( 117 | description=_DESCRIPTION, 118 | features=features, 119 | supervised_keys=None, 120 | homepage=_HOMEPAGE, 121 | license=_LICENSE, 122 | citation=_CITATION, 123 | ) 124 | 125 | def _split_generators(self, dl_manager): 126 | downloaded_filepath = dl_manager.download_and_extract(_URL) 127 | 128 | return [ 129 | datasets.SplitGenerator( 130 | name=datasets.Split.TRAIN, 131 | gen_kwargs={ 132 | "data_filepath": downloaded_filepath + "/cosql_dataset/sql_state_tracking/cosql_train.json", 133 | "db_path": downloaded_filepath + "/cosql_dataset/database", 134 | }, 135 | ), 136 | datasets.SplitGenerator( 137 | name=datasets.Split.VALIDATION, 138 | gen_kwargs={ 139 | "data_filepath": downloaded_filepath + "/cosql_dataset/sql_state_tracking/cosql_dev.json", 140 | "db_path": downloaded_filepath + "/cosql_dataset/database", 141 | }, 142 | ), 143 | ] 144 | 145 | def _generate_examples(self, data_filepath, db_path): 146 | """This function returns the examples in the raw (text) form.""" 147 | logger.info("generating examples from = %s", data_filepath) 148 | idx = 0 # indexing each training instance 149 | with open(data_filepath, encoding="utf-8") as f: 150 | cosql = json.load(f) 151 | for sample in cosql: 152 | db_id = sample["database_id"] 153 | if db_id not in self.schema_cache: 154 | self.schema_cache[db_id] = dump_db_json_schema( 155 | db_path + "/" + db_id + "/" + db_id + ".sqlite", db_id 156 | ) 157 | schema = self.schema_cache[db_id] 158 | 159 | db_info = { 160 | "db_id": db_id, 161 | "db_path": db_path, 162 | "db_table_names": schema["table_names_original"], 163 | "db_column_names": [ 164 | 
{"table_id": table_id, "column_name": column_name} 165 | for table_id, column_name in schema["column_names_original"] 166 | ], 167 | "db_column_types": schema["column_types"], 168 | "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]], 169 | "db_foreign_keys": [ 170 | {"column_id": column_id, "other_column_id": other_column_id} 171 | for column_id, other_column_id in schema["foreign_keys"] 172 | ], 173 | } 174 | 175 | yield idx, { 176 | "utterances": [sample["final"]["utterance"]], 177 | "query": sample["final"]["query"], 178 | "turn_idx": -1, 179 | **db_info, 180 | } 181 | idx += 1 182 | utterances = [] 183 | for turn_idx, turn in enumerate(sample["interaction"]): 184 | utterances.extend((utterance.strip() for utterance in turn["utterance"].split(sep="|"))) 185 | yield idx, { 186 | "utterances": list(utterances), 187 | "query": turn["query"], 188 | "turn_idx": turn_idx, 189 | **db_info, 190 | } 191 | idx += 1 192 | -------------------------------------------------------------------------------- /src/datasets/geoquery/geoquery.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Geoquery: Learning to Parse Database Queries Using Inductive Logic Programming""" 16 | 17 | import json 18 | import os 19 | from typing import List, Generator, Any, Dict, Tuple 20 | from third_party.spider.preprocess.get_tables import dump_db_json_schema 21 | import datasets 22 | 23 | 24 | logger = datasets.logging.get_logger(__name__) 25 | 26 | 27 | _CITATION = """\ 28 | @inproceedings{data-geography-original 29 | dataset = {Geography, original}, 30 | author = {John M. Zelle and Raymond J. Mooney}, 31 | title = {Learning to Parse Database Queries Using Inductive Logic Programming}, 32 | booktitle = {Proceedings of the Thirteenth National Conference on Artificial Intelligence - Volume 2}, 33 | year = {1996}, 34 | pages = {1050--1055}, 35 | location = {Portland, Oregon}, 36 | url = {http://dl.acm.org/citation.cfm?id=1864519.1864543}, 37 | } 38 | """ 39 | 40 | _DESCRIPTION = """\ 41 | Geoquery contains 880 queries and a database of U.S. geography. 
42 | """ 43 | 44 | _HOMEPAGE = "" 45 | 46 | _LICENSE = "CC BY-SA 4.0" 47 | 48 | _URL = "geoquery.zip" 49 | 50 | def normalize_alias( 51 | sql: str, 52 | table_names: List[str], 53 | ) -> str: 54 | alias_format = 'T{count}' 55 | count = 1 56 | for tab in table_names+['DERIVED_FIELD', 'DERIVED_TABLE']: 57 | tab = tab.upper() 58 | for idx in ['0','1','2','3','4','5','6']: 59 | old_alias = tab+'alias'+idx 60 | if old_alias in sql: 61 | new_alias = alias_format.format(count=count) 62 | sql = sql.replace(old_alias, new_alias) 63 | count += 1 64 | return sql 65 | 66 | class Spider(datasets.GeneratorBasedBuilder): 67 | VERSION = datasets.Version("1.0.0") 68 | 69 | BUILDER_CONFIGS = [ 70 | datasets.BuilderConfig( 71 | name="Geoquery", 72 | version=VERSION, 73 | description="880 queries and a database of U.S. geography", 74 | ), 75 | ] 76 | 77 | def __init__(self, *args, writer_batch_size=None, **kwargs) -> None: 78 | super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs) 79 | self.schema_cache = dict() 80 | self.include_train_others: bool = kwargs.pop("include_train_others", False) 81 | 82 | def _info(self) -> datasets.DatasetInfo: 83 | features = datasets.Features( 84 | { 85 | "query": datasets.Value("string"), 86 | "query_toks": datasets.features.Sequence(datasets.Value("string")), 87 | "question": datasets.Value("string"), 88 | "db_id": datasets.Value("string"), 89 | "db_path": datasets.Value("string"), 90 | "db_table_names": datasets.features.Sequence(datasets.Value("string")), 91 | "db_column_names": datasets.features.Sequence( 92 | { 93 | "table_id": datasets.Value("int32"), 94 | "column_name": datasets.Value("string"), 95 | } 96 | ), 97 | "db_column_types": datasets.features.Sequence(datasets.Value("string")), 98 | "db_primary_keys": datasets.features.Sequence({"column_id": datasets.Value("int32")}), 99 | "db_foreign_keys": datasets.features.Sequence( 100 | { 101 | "column_id": datasets.Value("int32"), 102 | "other_column_id": datasets.Value("int32"), 103 | } 104 | ), 105 | } 106 | ) 107 | return datasets.DatasetInfo( 108 | description=_DESCRIPTION, 109 | features=features, 110 | supervised_keys=None, 111 | homepage=_HOMEPAGE, 112 | license=_LICENSE, 113 | citation=_CITATION, 114 | ) 115 | 116 | def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]: 117 | downloaded_filepath = dl_manager.download_and_extract(url_or_urls=_URL) 118 | 119 | return [ 120 | datasets.SplitGenerator( 121 | name=datasets.Split.TRAIN, 122 | gen_kwargs={ 123 | "data_filepaths": [os.path.join(downloaded_filepath, "geoquery/train.json")], 124 | "db_path": os.path.join(downloaded_filepath, "geoquery/database"), 125 | }, 126 | ), 127 | datasets.SplitGenerator( 128 | name=datasets.Split.VALIDATION, 129 | gen_kwargs={ 130 | "data_filepaths": [os.path.join(downloaded_filepath, "geoquery/test.json")], 131 | "db_path": os.path.join(downloaded_filepath, "geoquery/database"), 132 | }, 133 | ), 134 | ] 135 | 136 | def _generate_examples( 137 | self, data_filepaths: List[str], db_path: str 138 | ) -> Generator[Tuple[int, Dict[str, Any]], None, None]: 139 | """This function returns the examples in the raw (text) form.""" 140 | print(f'db_path={db_path}') 141 | for data_filepath in data_filepaths: 142 | logger.info("generating examples from = %s", data_filepath) 143 | with open(data_filepath, encoding="utf-8") as f: 144 | geoquery = json.load(f) 145 | for idx, sample in enumerate(geoquery): 146 | db_id = 'geo' 147 | if db_id not in self.schema_cache: 148 | 
self.schema_cache[db_id] = dump_db_json_schema( 149 | db=os.path.join(db_path, db_id, f"{db_id}.sqlite"), f=db_id 150 | ) 151 | schema = self.schema_cache[db_id] 152 | sample['sql'] = normalize_alias(sample["sql"], schema["table_names_original"]) 153 | yield idx, { 154 | "query": sample['sql'], 155 | "query_toks": sample["sql"].split(), 156 | "question": sample["query"], 157 | "db_id": db_id, 158 | "db_path": db_path, 159 | "db_table_names": schema["table_names_original"], 160 | "db_column_names": [ 161 | {"table_id": table_id, "column_name": column_name} 162 | for table_id, column_name in schema["column_names_original"] 163 | ], 164 | "db_column_types": schema["column_types"], 165 | "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]], 166 | "db_foreign_keys": [ 167 | {"column_id": column_id, "other_column_id": other_column_id} 168 | for column_id, other_column_id in schema["foreign_keys"] 169 | ], 170 | } 171 | 172 | -------------------------------------------------------------------------------- /src/datasets/spider/spider.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Spider: A Large-Scale Human-Labeled Dataset for Text-to-SQL Tasks""" 16 | 17 | import json 18 | import os 19 | from typing import List, Generator, Any, Dict, Tuple 20 | from third_party.spider.preprocess.get_tables import dump_db_json_schema 21 | import datasets 22 | 23 | 24 | logger = datasets.logging.get_logger(__name__) 25 | 26 | 27 | _CITATION = """\ 28 | @article{yu2018spider, 29 | title={Spider: A large-scale human-labeled dataset for complex and cross-domain semantic parsing and text-to-sql task}, 30 | author={Yu, Tao and Zhang, Rui and Yang, Kai and Yasunaga, Michihiro and Wang, Dongxu and Li, Zifan and Ma, James and Li, Irene and Yao, Qingning and Roman, Shanelle and others}, 31 | journal={arXiv preprint arXiv:1809.08887}, 32 | year={2018} 33 | } 34 | """ 35 | 36 | _DESCRIPTION = """\ 37 | Spider is a large-scale complex and cross-domain semantic parsing and text-toSQL dataset annotated by 11 college students 38 | """ 39 | 40 | _HOMEPAGE = "https://yale-lily.github.io/spider" 41 | 42 | _LICENSE = "CC BY-SA 4.0" 43 | 44 | #_URL = "https://drive.google.com/uc?export=download&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0" 45 | _URL = "spider.zip" 46 | 47 | 48 | class Spider(datasets.GeneratorBasedBuilder): 49 | VERSION = datasets.Version("1.0.0") 50 | 51 | BUILDER_CONFIGS = [ 52 | datasets.BuilderConfig( 53 | name="spider", 54 | version=VERSION, 55 | description="Spider: A Large-Scale Human-Labeled Dataset for Text-to-SQL Tasks", 56 | ), 57 | ] 58 | 59 | def __init__(self, *args, writer_batch_size=None, **kwargs) -> None: 60 | super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs) 61 | self.schema_cache = dict() 62 | self.include_train_others: bool = kwargs.pop("include_train_others", False) 63 | 64 | def _info(self) -> datasets.DatasetInfo: 65 | features = datasets.Features( 66 | { 67 | "query": datasets.Value("string"), 68 | "query_toks": datasets.features.Sequence(datasets.Value("string")), 69 | "question": datasets.Value("string"), 70 | "db_id": datasets.Value("string"), 71 | "db_path": datasets.Value("string"), 72 | "db_table_names": datasets.features.Sequence(datasets.Value("string")), 73 | "db_column_names": datasets.features.Sequence( 74 | { 75 | "table_id": datasets.Value("int32"), 76 | "column_name": datasets.Value("string"), 77 | } 78 | ), 79 | "db_column_types": datasets.features.Sequence(datasets.Value("string")), 80 | "db_primary_keys": datasets.features.Sequence({"column_id": datasets.Value("int32")}), 81 | "db_foreign_keys": datasets.features.Sequence( 82 | { 83 | "column_id": datasets.Value("int32"), 84 | "other_column_id": datasets.Value("int32"), 85 | } 86 | ), 87 | } 88 | ) 89 | return datasets.DatasetInfo( 90 | description=_DESCRIPTION, 91 | features=features, 92 | supervised_keys=None, 93 | homepage=_HOMEPAGE, 94 | license=_LICENSE, 95 | citation=_CITATION, 96 | ) 97 | 98 | def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]: 99 | downloaded_filepath = dl_manager.download_and_extract(url_or_urls=_URL) 100 | 101 | return [ 102 | datasets.SplitGenerator( 103 | name=datasets.Split.TRAIN, 104 | gen_kwargs={ 105 | "data_filepaths": [ 106 | os.path.join(downloaded_filepath, "spider/train_spider.json"), 107 | os.path.join(downloaded_filepath, "spider/train_others.json"), 108 | ] 109 | if self.include_train_others 110 | else [os.path.join(downloaded_filepath, "spider/train_spider.json")], 111 | "db_path": os.path.join(downloaded_filepath, "spider/database"), 112 | }, 113 | ), 114 | datasets.SplitGenerator( 
115 | name=datasets.Split.VALIDATION, 116 | gen_kwargs={ 117 | "data_filepaths": [os.path.join(downloaded_filepath, "spider/dev.json")], 118 | "db_path": os.path.join(downloaded_filepath, "spider/database"), 119 | }, 120 | ), 121 | ] 122 | 123 | def _generate_examples( 124 | self, data_filepaths: List[str], db_path: str 125 | ) -> Generator[Tuple[int, Dict[str, Any]], None, None]: 126 | """This function returns the examples in the raw (text) form.""" 127 | for data_filepath in data_filepaths: 128 | logger.info("generating examples from = %s", data_filepath) 129 | with open(data_filepath, encoding="utf-8") as f: 130 | spider = json.load(f) 131 | for idx, sample in enumerate(spider): 132 | db_id = sample["db_id"] 133 | if db_id not in self.schema_cache: 134 | self.schema_cache[db_id] = dump_db_json_schema( 135 | db=os.path.join(db_path, db_id, f"{db_id}.sqlite"), f=db_id 136 | ) 137 | schema = self.schema_cache[db_id] 138 | yield idx, { 139 | "query": sample["query"], 140 | "query_toks": sample["query_toks"], 141 | "question": sample["question"], 142 | "db_id": db_id, 143 | "db_path": db_path, 144 | "db_table_names": schema["table_names_original"], 145 | "db_column_names": [ 146 | {"table_id": table_id, "column_name": column_name} 147 | for table_id, column_name in schema["column_names_original"] 148 | ], 149 | "db_column_types": schema["column_types"], 150 | "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]], 151 | "db_foreign_keys": [ 152 | {"column_id": column_id, "other_column_id": other_column_id} 153 | for column_id, other_column_id in schema["foreign_keys"] 154 | ], 155 | } 156 | -------------------------------------------------------------------------------- /src/metrics/cosql/cosql.py: -------------------------------------------------------------------------------- 1 | """CoSQL metrics.""" 2 | 3 | from typing import Optional, Union 4 | from .spider_test_suite import compute_test_suite_metric 5 | from .spider_exact_match import compute_exact_match_metric 6 | import datasets 7 | 8 | 9 | _DESCRIPTION = """ 10 | CoSQL metrics.
11 | """ 12 | 13 | _KWARGS_DESCRIPTION = """ 14 | """ 15 | 16 | _CITATION = """\ 17 | @article{yu2018spider, 18 | title={Spider: A large-scale human-labeled dataset for complex and cross-domain semantic parsing and text-to-sql task}, 19 | author={Yu, Tao and Zhang, Rui and Yang, Kai and Yasunaga, Michihiro and Wang, Dongxu and Li, Zifan and Ma, James and Li, Irene and Yao, Qingning and Roman, Shanelle and others}, 20 | journal={arXiv preprint arXiv:1809.08887}, 21 | year={2018} 22 | } 23 | @misc{zhong2020semantic, 24 | title={Semantic Evaluation for Text-to-SQL with Distilled Test Suites}, 25 | author={Ruiqi Zhong and Tao Yu and Dan Klein}, 26 | year={2020}, 27 | eprint={2010.02840}, 28 | archivePrefix={arXiv}, 29 | primaryClass={cs.CL} 30 | } 31 | """ 32 | 33 | _URL = "https://drive.google.com/uc?export=download&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0" 34 | 35 | 36 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 37 | class CoSQL(datasets.Metric): 38 | def __init__( 39 | self, 40 | config_name: Optional[str] = None, 41 | keep_in_memory: bool = False, 42 | cache_dir: Optional[str] = None, 43 | num_process: int = 1, 44 | process_id: int = 0, 45 | seed: Optional[int] = None, 46 | experiment_id: Optional[str] = None, 47 | max_concurrent_cache_files: int = 10000, 48 | timeout: Union[int, float] = 100, 49 | **kwargs 50 | ): 51 | super().__init__( 52 | config_name=config_name, 53 | keep_in_memory=keep_in_memory, 54 | cache_dir=cache_dir, 55 | num_process=num_process, 56 | process_id=process_id, 57 | seed=seed, 58 | experiment_id=experiment_id, 59 | max_concurrent_cache_files=max_concurrent_cache_files, 60 | timeout=timeout, 61 | **kwargs 62 | ) 63 | self.test_suite_db_dir: Optional[str] = kwargs.pop("test_suite_db_dir", None) 64 | 65 | def _info(self): 66 | if self.config_name not in [ 67 | "exact_match", 68 | "test_suite", 69 | "both", 70 | ]: 71 | raise KeyError( 72 | "You should supply a configuration name selected in " '["exact_match", "test_suite", "both"]' 73 | ) 74 | return datasets.MetricInfo( 75 | description=_DESCRIPTION, 76 | citation=_CITATION, 77 | inputs_description=_KWARGS_DESCRIPTION, 78 | features=datasets.Features( 79 | { 80 | "predictions": datasets.Value("string"), 81 | "references": { 82 | "query": datasets.Value("string"), 83 | "utterances": datasets.features.Sequence(datasets.Value("string")), 84 | "turn_idx": datasets.Value("int32"), 85 | "context": datasets.Value("string"), 86 | "label": datasets.Value("string"), 87 | "db_id": datasets.Value("string"), 88 | "db_path": datasets.Value("string"), 89 | "db_table_names": datasets.features.Sequence(datasets.Value("string")), 90 | "db_column_names": datasets.features.Sequence( 91 | { 92 | "table_id": datasets.Value("int32"), 93 | "column_name": datasets.Value("string"), 94 | } 95 | ), 96 | "db_foreign_keys": datasets.features.Sequence( 97 | { 98 | "column_id": datasets.Value("int32"), 99 | "other_column_id": datasets.Value("int32"), 100 | } 101 | ), 102 | }, 103 | } 104 | ), 105 | reference_urls=[_URL], 106 | ) 107 | 108 | def _compute(self, predictions, references): 109 | if self.config_name == "exact_match" or self.config_name == "both": 110 | exact_match = compute_exact_match_metric(predictions, references) 111 | else: 112 | exact_match = dict() 113 | 114 | if self.config_name == "test_suite" or self.config_name == "both": 115 | test_suite = compute_test_suite_metric(predictions, references, db_dir=self.test_suite_db_dir) 116 | else: 117 | test_suite = dict() 118 | 119 | return 
{**exact_match, **test_suite} 120 | -------------------------------------------------------------------------------- /src/metrics/cosql/spider_exact_match.py: -------------------------------------------------------------------------------- 1 | """Spider exact match metric.""" 2 | 3 | from typing import Dict, Any 4 | from third_party.spider import evaluation as spider_evaluation 5 | 6 | 7 | def compute_exact_match_metric(predictions, references) -> Dict[str, Any]: 8 | foreign_key_maps = dict() 9 | for reference in references: 10 | if reference["db_id"] not in foreign_key_maps: 11 | foreign_key_maps[reference["db_id"]] = spider_evaluation.build_foreign_key_map( 12 | { 13 | "table_names_original": reference["db_table_names"], 14 | "column_names_original": list( 15 | zip( 16 | reference["db_column_names"]["table_id"], 17 | reference["db_column_names"]["column_name"], 18 | ) 19 | ), 20 | "foreign_keys": list( 21 | zip( 22 | reference["db_foreign_keys"]["column_id"], 23 | reference["db_foreign_keys"]["other_column_id"], 24 | ) 25 | ), 26 | } 27 | ) 28 | evaluator = spider_evaluation.Evaluator(references[0]["db_path"], foreign_key_maps, "match") 29 | for prediction, reference in zip(predictions, references): 30 | turn_idx = reference.get("turn_idx", 0) 31 | # skip final utterance-query pairs 32 | if turn_idx < 0: 33 | continue 34 | _ = evaluator.evaluate_one(reference["db_id"], reference["query"], prediction) 35 | evaluator.finalize() 36 | return { 37 | "exact_match": evaluator.scores["all"]["exact"], 38 | } 39 | -------------------------------------------------------------------------------- /src/metrics/cosql/spider_test_suite.py: -------------------------------------------------------------------------------- 1 | """Spider Test Suite Execution Accuracy metric.""" 2 | import logging 3 | from typing import Optional, Dict, Any 4 | from third_party.test_suite import evaluation as test_suite_evaluation 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def compute_test_suite_metric(predictions, references, db_dir: Optional[str] = None) -> Dict[str, Any]: 10 | if db_dir is None: 11 | references[0]["db_path"] 12 | 13 | foreign_key_maps = dict() 14 | for reference in references: 15 | if reference["db_id"] not in foreign_key_maps: 16 | foreign_key_maps[reference["db_id"]] = test_suite_evaluation.build_foreign_key_map( 17 | { 18 | "table_names_original": reference["db_table_names"], 19 | "column_names_original": list( 20 | zip( 21 | reference["db_column_names"]["table_id"], 22 | reference["db_column_names"]["column_name"], 23 | ) 24 | ), 25 | "foreign_keys": list( 26 | zip( 27 | reference["db_foreign_keys"]["column_id"], 28 | reference["db_foreign_keys"]["other_column_id"], 29 | ) 30 | ), 31 | } 32 | ) 33 | 34 | evaluator = test_suite_evaluation.Evaluator( 35 | db_dir=db_dir if db_dir is not None else references[0]["db_path"], 36 | kmaps=foreign_key_maps, 37 | etype="exec", 38 | plug_value=False, 39 | keep_distinct=False, 40 | progress_bar_for_each_datapoint=False, 41 | ) 42 | # Only used for Sparc/CoSQL 43 | turn_scores = {"exec": [], "exact": []} 44 | for prediction, reference in zip(predictions, references): 45 | turn_idx = reference.get("turn_idx", 0) 46 | # skip final utterance-query pairs 47 | if turn_idx < 0: 48 | continue 49 | try: 50 | _ = evaluator.evaluate_one( 51 | reference["db_id"], 52 | reference["query"], 53 | prediction, 54 | turn_scores, 55 | idx=turn_idx, 56 | ) 57 | except AssertionError as e: 58 | logger.warning(f"unexpected evaluation error: {e.args[0]}") 59 | 
evaluator.finalize() 60 | return { 61 | "exec": evaluator.scores["all"]["exec"], 62 | } 63 | -------------------------------------------------------------------------------- /src/metrics/spider/spider.py: -------------------------------------------------------------------------------- 1 | """Spider metrics.""" 2 | 3 | from typing import Optional, Union 4 | from .spider_test_suite import compute_test_suite_metric 5 | from .spider_exact_match import compute_exact_match_metric 6 | import datasets 7 | 8 | 9 | _DESCRIPTION = """ 10 | Spider metrics. 11 | """ 12 | 13 | _KWARGS_DESCRIPTION = """ 14 | """ 15 | 16 | _CITATION = """\ 17 | @article{yu2018spider, 18 | title={Spider: A large-scale human-labeled dataset for complex and cross-domain semantic parsing and text-to-sql task}, 19 | author={Yu, Tao and Zhang, Rui and Yang, Kai and Yasunaga, Michihiro and Wang, Dongxu and Li, Zifan and Ma, James and Li, Irene and Yao, Qingning and Roman, Shanelle and others}, 20 | journal={arXiv preprint arXiv:1809.08887}, 21 | year={2018} 22 | } 23 | @misc{zhong2020semantic, 24 | title={Semantic Evaluation for Text-to-SQL with Distilled Test Suites}, 25 | author={Ruiqi Zhong and Tao Yu and Dan Klein}, 26 | year={2020}, 27 | eprint={2010.02840}, 28 | archivePrefix={arXiv}, 29 | primaryClass={cs.CL} 30 | } 31 | """ 32 | 33 | _URL = "https://drive.google.com/uc?export=download&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0" 34 | 35 | 36 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 37 | class Spider(datasets.Metric): 38 | def __init__( 39 | self, 40 | config_name: Optional[str] = None, 41 | keep_in_memory: bool = False, 42 | cache_dir: Optional[str] = None, 43 | num_process: int = 1, 44 | process_id: int = 0, 45 | seed: Optional[int] = None, 46 | experiment_id: Optional[str] = None, 47 | max_concurrent_cache_files: int = 10000, 48 | timeout: Union[int, float] = 100, 49 | **kwargs 50 | ): 51 | super().__init__( 52 | config_name=config_name, 53 | keep_in_memory=keep_in_memory, 54 | cache_dir=cache_dir, 55 | num_process=num_process, 56 | process_id=process_id, 57 | seed=seed, 58 | experiment_id=experiment_id, 59 | max_concurrent_cache_files=max_concurrent_cache_files, 60 | timeout=timeout, 61 | **kwargs 62 | ) 63 | self.test_suite_db_dir: Optional[str] = kwargs.pop("test_suite_db_dir", None) 64 | 65 | def _info(self): 66 | if self.config_name not in [ 67 | "exact_match", 68 | "test_suite", 69 | "both", 70 | ]: 71 | raise KeyError( 72 | "You should supply a configuration name selected in " '["exact_match", "test_suite", "both"]' 73 | ) 74 | 75 | return datasets.MetricInfo( 76 | description=_DESCRIPTION, 77 | citation=_CITATION, 78 | inputs_description=_KWARGS_DESCRIPTION, 79 | 80 | features=datasets.Features( 81 | { 82 | "predictions": datasets.Value("string"), 83 | "references": { 84 | "query": datasets.Value("string"), 85 | "question": datasets.Value("string"), 86 | "context": datasets.Value("string"), 87 | "label": datasets.Value("string"), 88 | "db_id": datasets.Value("string"), 89 | "db_path": datasets.Value("string"), 90 | "db_table_names": datasets.features.Sequence(datasets.Value("string")), 91 | "db_column_names": datasets.features.Sequence( 92 | { 93 | "table_id": datasets.Value("int32"), 94 | "column_name": datasets.Value("string"), 95 | } 96 | ), 97 | "db_foreign_keys": datasets.features.Sequence( 98 | { 99 | "column_id": datasets.Value("int32"), 100 | "other_column_id": datasets.Value("int32"), 101 | } 102 | ), 103 | }, 104 | } 105 | ), 106 | reference_urls=[_URL], 107 | ) 
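# NOTE (editorial): the triple-quoted block below is unreachable -- it sits
# after the `return` statement above. It records a simpler, string-only
# variant of the reference features and appears to be kept only for reference.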
108 | ''' 109 | return datasets.MetricInfo( 110 | description=_DESCRIPTION, 111 | citation=_CITATION, 112 | inputs_description=_KWARGS_DESCRIPTION, 113 | 114 | features=datasets.Features( 115 | { 116 | "predictions": datasets.Value("string"), 117 | "references": datasets.Value("string"), 118 | } 119 | ), 120 | reference_urls=[_URL], 121 | ) 122 | ''' 123 | 124 | def _compute(self, predictions, references): 125 | if self.config_name == "exact_match" or self.config_name == "both": 126 | exact_match = compute_exact_match_metric(predictions, references) 127 | else: 128 | exact_match = dict() 129 | 130 | if self.config_name == "test_suite" or self.config_name == "both": 131 | test_suite = compute_test_suite_metric(predictions, references, db_dir=self.test_suite_db_dir) 132 | else: 133 | test_suite = dict() 134 | 135 | return {**exact_match, **test_suite} 136 | -------------------------------------------------------------------------------- /src/metrics/spider/spider_exact_match.py: -------------------------------------------------------------------------------- 1 | """Spider exact match metric.""" 2 | 3 | from typing import Dict, Any 4 | from third_party.spider import evaluation as spider_evaluation 5 | 6 | 7 | def compute_exact_match_metric(predictions, references) -> Dict[str, Any]: 8 | foreign_key_maps = dict() 9 | for reference in references: 10 | if reference["db_id"] not in foreign_key_maps: 11 | foreign_key_maps[reference["db_id"]] = spider_evaluation.build_foreign_key_map( 12 | { 13 | "table_names_original": reference["db_table_names"], 14 | "column_names_original": list( 15 | zip( 16 | reference["db_column_names"]["table_id"], 17 | reference["db_column_names"]["column_name"], 18 | ) 19 | ), 20 | "foreign_keys": list( 21 | zip( 22 | reference["db_foreign_keys"]["column_id"], 23 | reference["db_foreign_keys"]["other_column_id"], 24 | ) 25 | ), 26 | } 27 | ) 28 | evaluator = spider_evaluation.Evaluator(references[0]["db_path"], foreign_key_maps, "match") 29 | for prediction, reference in zip(predictions, references): 30 | turn_idx = reference.get("turn_idx", 0) 31 | # skip final utterance-query pairs 32 | if turn_idx < 0: 33 | continue 34 | _ = evaluator.evaluate_one(reference["db_id"], reference["query"], prediction) 35 | evaluator.finalize() 36 | return { 37 | "exact_match": evaluator.scores["all"]["exact"], 38 | } 39 | -------------------------------------------------------------------------------- /src/metrics/spider/spider_test_suite.py: -------------------------------------------------------------------------------- 1 | """Spider Test Suite Execution Accuracy metric.""" 2 | import logging 3 | from typing import Optional, Dict, Any 4 | from third_party.test_suite import evaluation as test_suite_evaluation 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def compute_test_suite_metric(predictions, references, db_dir: Optional[str] = None) -> Dict[str, Any]: 10 | if db_dir is None: 11 | references[0]["db_path"] 12 | 13 | foreign_key_maps = dict() 14 | for reference in references: 15 | if reference["db_id"] not in foreign_key_maps: 16 | foreign_key_maps[reference["db_id"]] = test_suite_evaluation.build_foreign_key_map( 17 | { 18 | "table_names_original": reference["db_table_names"], 19 | "column_names_original": list( 20 | zip( 21 | reference["db_column_names"]["table_id"], 22 | reference["db_column_names"]["column_name"], 23 | ) 24 | ), 25 | "foreign_keys": list( 26 | zip( 27 | reference["db_foreign_keys"]["column_id"], 28 | reference["db_foreign_keys"]["other_column_id"], 29 | 
) 30 | ), 31 | } 32 | ) 33 | 34 | evaluator = test_suite_evaluation.Evaluator( 35 | db_dir=db_dir if db_dir is not None else references[0]["db_path"], 36 | kmaps=foreign_key_maps, 37 | etype="exec", 38 | plug_value=False, 39 | keep_distinct=False, 40 | progress_bar_for_each_datapoint=False, 41 | ) 42 | # Only used for Sparc/CoSQL 43 | turn_scores = {"exec": [], "exact": []} 44 | for prediction, reference in zip(predictions, references): 45 | turn_idx = reference.get("turn_idx", 0) 46 | # skip final utterance-query pairs 47 | if turn_idx < 0: 48 | continue 49 | try: 50 | _ = evaluator.evaluate_one( 51 | reference["db_id"], 52 | reference["query"], 53 | prediction, 54 | turn_scores, 55 | idx=turn_idx, 56 | ) 57 | except AssertionError as e: 58 | logger.warning(f"unexpected evaluation error: {e.args[0]}") 59 | evaluator.finalize() 60 | return { 61 | "exec": evaluator.scores["all"]["exec"], 62 | } 63 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | # Set up logging 2 | import sys 3 | import logging 4 | 5 | logging.basicConfig( 6 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 7 | datefmt="%m/%d/%Y %H:%M:%S", 8 | handlers=[logging.StreamHandler(sys.stdout)], 9 | level=logging.WARNING, 10 | ) 11 | logger = logging.getLogger(__name__) 12 | 13 | import os 14 | import json 15 | from pathlib import Path 16 | import torch 17 | from contextlib import nullcontext 18 | from dataclasses import asdict, fields 19 | from transformers.hf_argparser import HfArgumentParser 20 | from transformers.training_args_seq2seq import Seq2SeqTrainingArguments 21 | from transformers.models.auto import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM 22 | from transformers.data.data_collator import DataCollatorForSeq2Seq 23 | from transformers.trainer_utils import get_last_checkpoint, set_seed 24 | from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration 25 | from transformers.models.t5.tokenization_t5_fast import T5TokenizerFast 26 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 27 | from tokenizers import AddedToken 28 | from utils.args import ModelArguments, DataTrainingArguments, DataArguments 29 | from utils.decode_wrapper import PicardArguments, PicardLauncher, with_picard 30 | from utils.dataset_loader import load_dataset 31 | from utils.spider import SpiderTrainer 32 | from utils.cosql import CoSQLTrainer 33 | from utils.geoquery import GeoQueryTrainer 34 | from utils.PT_wrapper import PromptWrapper 35 | 36 | set_seed(1) 37 | def main() -> None: 38 | # See all possible arguments by passing the --help flag to this script. 39 | parser = HfArgumentParser( 40 | (PicardArguments, ModelArguments, DataArguments, DataTrainingArguments, Seq2SeqTrainingArguments) 41 | ) 42 | picard_args: PicardArguments 43 | model_args: ModelArguments 44 | data_args: DataArguments 45 | data_training_args: DataTrainingArguments 46 | training_args: Seq2SeqTrainingArguments 47 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 48 | # If we pass only one argument to the script and it's the path to a json file, 49 | # let's parse it to get our arguments. 
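# A minimal sketch of such a JSON file (hypothetical values; the accepted
# keys are the field names of the five dataclasses handed to HfArgumentParser
# above, e.g. ModelArguments, DataArguments, DataTrainingArguments):
# {
#     "model_name_or_path": "t5-large",
#     "dataset": "spider",
#     "output_dir": "experimental_outputs/spider",
#     "do_train": true,
#     "stage": "structure",
#     "training_method": "PFT"
# }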
50 | picard_args, model_args, data_args, data_training_args, training_args = parser.parse_json_file( 51 | json_file=os.path.abspath(sys.argv[1]) 52 | ) 53 | elif len(sys.argv) == 3 and sys.argv[1].startswith("--local_rank") and sys.argv[2].endswith(".json"): 54 | data = json.loads(Path(os.path.abspath(sys.argv[2])).read_text()) 55 | data.update({"local_rank": int(sys.argv[1].split("=")[1])}) 56 | picard_args, model_args, data_args, data_training_args, training_args = parser.parse_dict(args=data) 57 | else: 58 | picard_args, model_args, data_args, data_training_args, training_args = parser.parse_args_into_dataclasses() 59 | 60 | 61 | combined_args_dict = { 62 | **asdict(picard_args), 63 | **asdict(model_args), 64 | **asdict(data_args), 65 | **asdict(data_training_args), 66 | **training_args.to_sanitized_dict(), 67 | } 68 | combined_args_dict.pop("local_rank", None) 69 | 70 | # Detect last checkpoint 71 | last_checkpoint = None 72 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 73 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 74 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 75 | raise ValueError( 76 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 77 | "Use --overwrite_output_dir to overcome." 78 | ) 79 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 80 | logger.info( 81 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 82 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 83 | ) 84 | training_args.report_to = [] 85 | # set path 86 | fewshot_identifier = f'R{str(int(data_training_args.train_samples_ratio*100))}' 87 | training_args.output_dir = os.path.join(training_args.output_dir, f'{fewshot_identifier}_results') 88 | if training_args.do_train: 89 | if data_training_args.structure_path == "": 90 | data_training_args.structure_path = os.path.join(training_args.output_dir, f'structure/structure.json') 91 | if data_training_args.training_method == "PFT" and data_training_args.initial_vectors_path == "": 92 | data_training_args.initial_vectors_path = os.path.join(training_args.output_dir, f'{data_training_args.stage}/head.npy') 93 | if data_training_args.use_decomposition: 94 | training_args.output_dir = os.path.join(training_args.output_dir, data_training_args.stage) 95 | else: 96 | training_args.output_dir = os.path.join(training_args.output_dir, 'seq2seq') 97 | elif training_args.do_eval: 98 | model_args.model_name_or_path = os.path.join(training_args.output_dir, f'{data_training_args.stage}/BEST_MODEL') 99 | training_args.output_dir = os.path.join(training_args.output_dir, 'prediction') 100 | if data_training_args.stage == "content" and os.path.exists(os.path.join(training_args.output_dir, "hypotheses.json")): 101 | os.remove(os.path.join(training_args.output_dir, "hypotheses.json")) 102 | if data_training_args.structure_path == "": 103 | data_training_args.structure_path = os.path.join(training_args.output_dir, f'structure.json') 104 | if data_training_args.training_method == "PFT" and data_training_args.initial_vectors_path == "": 105 | data_training_args.initial_vectors_path = os.path.join(model_args.model_name_or_path, 'head.npy') 106 | 107 | os.makedirs(training_args.output_dir, exist_ok=True) 108 | 109 | if training_args.local_rank <= 0: 110 | with open(f"{training_args.output_dir}/combined_args.json", "w") as f: 111 | 
json.dump(combined_args_dict, f, indent=4) 112 | 113 | # Initialize config 114 | config = AutoConfig.from_pretrained( 115 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 116 | cache_dir=model_args.cache_dir, 117 | revision=model_args.model_revision, 118 | use_auth_token=True if model_args.use_auth_token else None, 119 | max_length=data_training_args.max_target_length, 120 | num_beams=data_training_args.num_beams, 121 | num_beam_groups=data_training_args.num_beam_groups, 122 | diversity_penalty=data_training_args.diversity_penalty, 123 | gradient_checkpointing=training_args.gradient_checkpointing, 124 | use_cache=not training_args.gradient_checkpointing, 125 | num_return_sequences=data_training_args.num_beams if data_training_args.use_constrained_decoding and data_training_args.stage == "content" else 1, 126 | ) 127 | 128 | # Initialize tokenizer 129 | tokenizer = AutoTokenizer.from_pretrained( 130 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 131 | cache_dir=model_args.cache_dir, 132 | use_fast=model_args.use_fast_tokenizer, 133 | revision=model_args.model_revision, 134 | use_auth_token=True if model_args.use_auth_token else None, 135 | ) 136 | 137 | assert isinstance(tokenizer, PreTrainedTokenizerFast), "Only fast tokenizers are currently supported" 138 | if isinstance(tokenizer, T5TokenizerFast): 139 | # In T5 `<` is OOV, see https://github.com/google-research/language/blob/master/language/nqg/tasks/spider/restore_oov.py 140 | tokenizer.add_tokens([AddedToken(" <="), AddedToken(" <")]) 141 | 142 | print("Load dataset") 143 | metric, dataset_splits = load_dataset( 144 | data_args=data_args, 145 | model_args=model_args, 146 | data_training_args=data_training_args, 147 | training_args=training_args, 148 | tokenizer=tokenizer, 149 | ) 150 | print("Load dataset") 151 | 152 | if training_args.do_train: 153 | if data_training_args.training_method == 'PT': 154 | training_args.eval_steps = 100*int(dataset_splits.train_split.dataset.num_rows/(training_args.per_device_train_batch_size*training_args.gradient_accumulation_steps)) 155 | else: 156 | training_args.eval_steps = 2*int(dataset_splits.train_split.dataset.num_rows/(training_args.per_device_train_batch_size*training_args.gradient_accumulation_steps)) 157 | training_args.save_steps = training_args.eval_steps*100000 158 | 159 | with PicardLauncher() if picard_args.launch_picard and training_args.local_rank <= 0 else nullcontext(None): 160 | if data_training_args.use_constrained_decoding: 161 | model_cls_wrapper = lambda model_cls: with_picard( 162 | model_cls=model_cls, picard_args=picard_args, tokenizer=tokenizer, schemas=dataset_splits.schemas, stage=data_training_args.stage, 163 | ) 164 | else: 165 | model_cls_wrapper = lambda model_cls: model_cls 166 | 167 | print("Initialize model") 168 | model_ = model_cls_wrapper(AutoModelForSeq2SeqLM).from_pretrained( 169 | model_args.model_name_or_path, 170 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 171 | config=config, 172 | cache_dir=model_args.cache_dir, 173 | revision=model_args.model_revision, 174 | use_auth_token=True if model_args.use_auth_token else None, 175 | ) 176 | if isinstance(model_, T5ForConditionalGeneration): 177 | model_.resize_token_embeddings(len(tokenizer)) 178 | if data_training_args.stage == 'structure': 179 | if data_args.dataset in ["spider", "cosql"]: 180 | prompt_length_list = [60, 15, 15, 60] 181 | elif data_args.dataset in ["geoquery"]: 182 | prompt_length_list = [1, 1, 1, 
1] 183 | elif data_training_args.stage == 'content': 184 | prompt_length_list = [60, 15, 15, 60] 185 | if data_training_args.training_method == 'PT': 186 | print(f"-------PT--------") 187 | model = PromptWrapper( 188 | model_, 189 | prompt_length_list=prompt_length_list, 190 | freeze_model=True, 191 | initialize_from_vocab=True, 192 | ) 193 | model.main_input_name = 'input_ids' 194 | elif data_training_args.training_method == 'PFT': 195 | print(f"-------PFT--------") 196 | model = PromptWrapper( 197 | model_, 198 | prompt_length_list=prompt_length_list, 199 | freeze_model=False, 200 | stage=data_training_args.stage, 201 | initial_vectors_path=data_training_args.initial_vectors_path, 202 | initialize_from_pretrain=True, 203 | ) 204 | model.main_input_name = 'input_ids' 205 | else: 206 | print("-------FT--------") 207 | model = model_ 208 | 209 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 210 | logger.warning( 211 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 212 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 213 | ) 214 | 215 | print("Initialize Trainer") 216 | trainer_kwargs = { 217 | "model": model, 218 | "args": training_args, 219 | "use_decomposition": data_training_args.use_decomposition, 220 | "training_method": data_training_args.training_method, 221 | "stage": data_training_args.stage, 222 | "metric": metric, 223 | "train_dataset": dataset_splits.train_split.dataset if training_args.do_train else None, 224 | "eval_dataset": dataset_splits.eval_split.dataset if training_args.do_eval else None, 225 | "eval_examples": dataset_splits.eval_split.examples if training_args.do_eval else None, 226 | "tokenizer": tokenizer, 227 | "data_collator": DataCollatorForSeq2Seq( 228 | tokenizer, 229 | model=model, 230 | label_pad_token_id=(-100 if data_training_args.ignore_pad_token_for_loss else tokenizer.pad_token_id), 231 | pad_to_multiple_of=8 if training_args.fp16 else None, 232 | ), 233 | "ignore_pad_token_for_loss": data_training_args.ignore_pad_token_for_loss, 234 | "target_with_db_id": data_training_args.target_with_db_id, 235 | } 236 | 237 | if data_args.dataset in ["spider"]: 238 | trainer = SpiderTrainer(**trainer_kwargs) 239 | elif data_args.dataset in ["cosql"]: 240 | trainer = CoSQLTrainer(**trainer_kwargs) 241 | elif data_args.dataset in ["geoquery"]: 242 | trainer = GeoQueryTrainer(**trainer_kwargs) 243 | else: 244 | raise NotImplementedError() 245 | 246 | # Training 247 | if training_args.do_train: 248 | logger.info("*** Train ***") 249 | 250 | checkpoint = None 251 | 252 | if training_args.resume_from_checkpoint is not None: 253 | checkpoint = training_args.resume_from_checkpoint 254 | elif last_checkpoint is not None: 255 | checkpoint = last_checkpoint 256 | 257 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 258 | 259 | # Evaluation 260 | if training_args.do_eval: 261 | logger.info("*** Evaluate ***") 262 | 263 | metrics = trainer.evaluate( 264 | max_length=data_training_args.val_max_target_length, 265 | max_time=data_training_args.val_max_time, 266 | num_beams=data_training_args.num_beams, 267 | metric_key_prefix="eval", 268 | ) 269 | metrics["eval_samples"] = dataset_splits.eval_split.dataset.num_rows 270 | 271 | trainer.log_metrics("eval", metrics) 272 | trainer.save_metrics("eval", metrics) 273 | 274 | # Testing 275 | if training_args.do_predict: 276 | logger.info("*** 
Predict ***") 277 | for section, test_split in dataset_splits.test_splits.items(): 278 | results = trainer.predict( 279 | test_split.dataset, 280 | test_split.examples, 281 | max_length=data_training_args.val_max_target_length, 282 | max_time=data_training_args.val_max_time, 283 | num_beams=data_training_args.num_beams, 284 | metric_key_prefix=section) 285 | metrics = results.metrics 286 | 287 | metrics[f"{section}_samples"] = len(test_split.dataset) 288 | 289 | trainer.log_metrics(section, metrics) 290 | trainer.save_metrics(section, metrics) 291 | 292 | 293 | if __name__ == "__main__": 294 | main() 295 | -------------------------------------------------------------------------------- /src/utils/PT_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from torch import nn 3 | import torch 4 | from transformers.modeling_utils import PreTrainedModel 5 | import numpy as np 6 | 7 | 8 | class PromptWrapper(nn.Module): 9 | def __init__( 10 | self, 11 | model: PreTrainedModel, 12 | prompt_length_list: list = [60, 15, 15, 60], 13 | initial_vectors_path: str = '', 14 | freeze_model: bool = True, 15 | stage: str = 'content', 16 | initialize_from_pretrain: bool = False, 17 | random_range: float = 0.5, 18 | initialize_from_vocab: bool = True 19 | ): 20 | super().__init__() 21 | 22 | self.prompt_length = sum(prompt_length_list) 23 | self.model = model 24 | if freeze_model == True: 25 | for p in model.parameters(): 26 | p.requires_grad = False 27 | if not initialize_from_pretrain: 28 | self.prompt_head = nn.Parameter( 29 | self.initialize_embedding( 30 | model.get_input_embeddings(), 31 | prompt_length_list[0], 32 | random_range, 33 | initialize_from_vocab, 34 | ) 35 | ) 36 | self.prompt_mid1 = nn.Parameter( 37 | self.initialize_embedding( 38 | model.get_input_embeddings(), 39 | prompt_length_list[1], 40 | random_range, 41 | initialize_from_vocab, 42 | ) 43 | ) 44 | self.prompt_mid2 = nn.Parameter( 45 | self.initialize_embedding( 46 | model.get_input_embeddings(), 47 | prompt_length_list[2], 48 | random_range, 49 | initialize_from_vocab, 50 | ) 51 | ) 52 | self.prompt_tail = nn.Parameter( 53 | self.initialize_embedding( 54 | model.get_input_embeddings(), 55 | prompt_length_list[3], 56 | random_range, 57 | initialize_from_vocab, 58 | ) 59 | ) 60 | else: 61 | print(f"initialize from {initial_vectors_path}") 62 | 63 | self.prompt_head = nn.Parameter( 64 | torch.from_numpy(np.load(initial_vectors_path)) 65 | ) 66 | self.prompt_tail = nn.Parameter( 67 | torch.from_numpy(np.load(initial_vectors_path.replace('head', 'tail'))) 68 | ) 69 | self.prompt_mid1 = nn.Parameter( 70 | torch.from_numpy(np.load(initial_vectors_path.replace('head', 'mid1'))) 71 | ) 72 | self.prompt_mid2 = nn.Parameter( 73 | torch.from_numpy(np.load(initial_vectors_path.replace('head', 'mid2'))) 74 | ) 75 | 76 | 77 | def initialize_embedding( 78 | self, 79 | embedding: nn.Embedding, 80 | prompt_length: int = 10, 81 | random_range: float = 0.5, 82 | initialize_from_vocab: bool = True, 83 | initialize_from_keywords: bool = True, 84 | ): 85 | 86 | if initialize_from_vocab: 87 | indices = torch.randint(0, 5000, (prompt_length,)) 88 | return embedding.weight[indices].clone().detach() 89 | 90 | return torch.FloatTensor(prompt_length, embedding.weight.size(1)).uniform_( 91 | -random_range, random_range 92 | ) 93 | 94 | def build_inputs(self, input_ids, attention_mask, labels=None): 95 | batch_size = input_ids.shape[0] 96 | device = input_ids.device 97 | 98 | prompt_length = 
self.prompt_head.size(0) + self.prompt_mid1.size(0) + self.prompt_mid2.size(0) + self.prompt_tail.size(0) 99 | if prompt_length and attention_mask is not None: 100 | padding = torch.full((batch_size, (prompt_length)), 1).to(device) 101 | attention_mask = torch.cat((padding, attention_mask), dim=1) 102 | 103 | inputs_embeds = self.model.get_input_embeddings()(input_ids) 104 | """ 105 | Input Format: 106 | [prompt_head] Translate the question into sql according to the database: xxx 107 | [prompt_mid1] | question: xxx 108 | [prompt_mid2] | database: xxx 109 | [prompt_tail] 110 | """ 111 | extend_embeds = [] 112 | for idx0 in range(inputs_embeds.shape[0]): 113 | mid1 = None 114 | mid2 = None 115 | end = None 116 | for idx1 in range(inputs_embeds.shape[1]): 117 | if input_ids[idx0][idx1] == 1820 and input_ids[idx0][idx1+1] == 822 and mid1 == None: 118 | mid1 = idx1 119 | elif input_ids[idx0][idx1] == 1820 and input_ids[idx0][idx1+1] == 3501 and mid2 == None: 120 | mid2 = idx1 121 | elif input_ids[idx0][idx1] == 1: 122 | end = idx1 123 | if mid2 == None: 124 | mid2 = end 125 | 126 | extend_embeds.append(torch.cat([self.prompt_head, inputs_embeds[idx0][:mid1], self.prompt_mid1, inputs_embeds[idx0][mid1:mid2], self.prompt_mid2, inputs_embeds[idx0][mid2:end], self.prompt_tail, inputs_embeds[idx0][end:]], 0)) 127 | 128 | 129 | inputs_embeds = torch.stack(extend_embeds,dim=0) 130 | 131 | return inputs_embeds, attention_mask, labels 132 | 133 | def forward(self, input_ids, attention_mask, labels=None, **kwargs): 134 | inputs_embeds, attention_mask, labels = self.build_inputs( 135 | input_ids, 136 | attention_mask, 137 | labels, 138 | ) 139 | 140 | return self.model( 141 | inputs_embeds=inputs_embeds, 142 | attention_mask=attention_mask, 143 | labels=labels, 144 | **kwargs, 145 | ) 146 | 147 | @torch.no_grad() 148 | def generate(self, input_ids=None, attention_mask=None, structures=None, **kwargs): 149 | inputs_embeds, attention_mask, _ = self.build_inputs( 150 | input_ids, 151 | attention_mask, 152 | labels=None, 153 | ) 154 | 155 | model_kwargs = { 156 | "encoder_outputs": self.model.get_encoder()(inputs_embeds=inputs_embeds) 157 | } 158 | 159 | return self.model.generate( 160 | input_ids=None, 161 | use_cache=True, 162 | no_repeat_ngram_size=0, 163 | structures=structures, 164 | **model_kwargs, 165 | **kwargs, 166 | ) 167 | 168 | @property 169 | def config(self): 170 | return self.model.config 171 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruc-datalab/SC-prompt/860491faef9dfb4a711380b23ec2e902b8a7250d/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/args.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, Dict, Callable 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
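    All of these fields are filled in by the HfArgumentParser call in
    src/run.py, either from command-line flags or from a JSON configuration
    file.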
9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | config_name: Optional[str] = field( 15 | default=None, 16 | metadata={"help": "Pretrained config name or path if not the same as model_name"}, 17 | ) 18 | tokenizer_name: Optional[str] = field( 19 | default=None, 20 | metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, 21 | ) 22 | cache_dir: Optional[str] = field( 23 | default=None, 24 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 25 | ) 26 | use_fast_tokenizer: bool = field( 27 | default=True, 28 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 29 | ) 30 | model_revision: str = field( 31 | default="main", 32 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 33 | ) 34 | use_auth_token: bool = field( 35 | default=False, 36 | metadata={ 37 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 38 | "with private models)." 39 | }, 40 | ) 41 | 42 | @dataclass 43 | class DataTrainingArguments: 44 | """ 45 | Arguments pertaining to what data we are going to input our model for training and eval. 46 | """ 47 | 48 | overwrite_cache: bool = field( 49 | default=False, 50 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 51 | ) 52 | preprocessing_num_workers: Optional[int] = field( 53 | default=None, 54 | metadata={"help": "The number of processes to use for the preprocessing."}, 55 | ) 56 | max_source_length: Optional[int] = field( 57 | default=512, 58 | metadata={ 59 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 60 | "than this will be truncated, sequences shorter will be padded." 61 | }, 62 | ) 63 | max_target_length: Optional[int] = field( 64 | default=512, 65 | metadata={ 66 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 67 | "than this will be truncated, sequences shorter will be padded." 68 | }, 69 | ) 70 | val_max_target_length: Optional[int] = field( 71 | default=None, 72 | metadata={ 73 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 74 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 75 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 76 | "during ``evaluate`` and ``predict``." 77 | }, 78 | ) 79 | val_max_time: Optional[int] = field( 80 | default=None, 81 | metadata={ 82 | "help": "The maximum allowed time in seconds for generation of one example. This setting can be used to stop " 83 | "generation whenever the full generation exceeds the specified amount of time." 84 | }, 85 | ) 86 | train_samples_ratio: Optional[float] = field( 87 | default=1., 88 | metadata={ 89 | "help": "For few-shot learning" 90 | "value if set." 91 | }, 92 | ) 93 | max_val_samples: Optional[int] = field( 94 | default=None, 95 | metadata={ 96 | "help": "For debugging purposes or quicker training, truncate the number of validation or test examples to this " 97 | "value if set." 98 | }, 99 | ) 100 | use_constrained_decoding: bool = field( 101 | default=True, 102 | metadata={ 103 | "help": "Whether constrained decoding is used." 
104 | "which is used during ``structure-stage`` and ``content-stage``." 105 | }, 106 | ) 107 | num_beams: int = field( 108 | default=1, 109 | metadata={ 110 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 111 | "which is used during ``evaluate`` and ``predict``." 112 | }, 113 | ) 114 | num_return_sequences: int = field( 115 | default=1, 116 | metadata={ 117 | "help": "Number of return_sequences to use for evaluation. This argument will be passed to ``model.generate``, " 118 | "which is used during ``evaluate`` and ``predict``." 119 | }, 120 | ) 121 | num_beam_groups: int = field( 122 | default=1, 123 | metadata={ 124 | "help": "Number of beam groups to use for evaluation. This argument will be passed to ``model.generate``, " 125 | "which is used during ``evaluate`` and ``predict``." 126 | }, 127 | ) 128 | diversity_penalty: Optional[float] = field( 129 | default=None, 130 | metadata={ 131 | "help": "Diversity penalty to use for evaluation. This argument will be passed to ``model.generate``, " 132 | "which is used during ``evaluate`` and ``predict``." 133 | }, 134 | ) 135 | num_return_sequences: Optional[int] = field( 136 | default=None, 137 | metadata={ 138 | "help": "The number of sequences to generate during evaluation. This argument will be passed to " 139 | "``model.generate``, which is used during ``evaluate`` and ``predict``." 140 | }, 141 | ) 142 | ignore_pad_token_for_loss: bool = field( 143 | default=True, 144 | metadata={ 145 | "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation or not." 146 | }, 147 | ) 148 | source_prefix: Optional[str] = field( 149 | default=None, 150 | metadata={"help": "A prefix to add before every source text (useful for T5 models)."}, 151 | ) 152 | schema_serialization_type: str = field( 153 | default="peteshaw", 154 | metadata={"help": "Choose between ``verbose`` and ``peteshaw`` schema serialization."}, 155 | ) 156 | schema_serialization_randomized: bool = field( 157 | default=False, 158 | metadata={"help": "Whether or not to randomize the order of tables."}, 159 | ) 160 | schema_serialization_with_db_id: bool = field( 161 | default=True, 162 | metadata={"help": "Whether or not to add the database id to the context. Needed for Picard."}, 163 | ) 164 | schema_serialization_with_prompt: str = field( 165 | default="", 166 | metadata={"help": "Whether or not to use prompt."} 167 | ) 168 | schema_serialization_with_db_content: bool = field( 169 | default=True, 170 | metadata={"help": "Whether or not to use the database content to resolve field matches."}, 171 | ) 172 | normalize_query: bool = field(default=True, metadata={"help": "Whether to normalize the SQL queries."}) 173 | target_with_db_id: bool = field( 174 | default=True, 175 | metadata={"help": "Whether or not to add the database id to the target. Needed for Picard."}, 176 | ) 177 | use_decomposition: bool = field(default=False, metadata={"help": "Whether to use decomposition."}) 178 | stage: str = field(default='structure', metadata={"help": "Training structure prediction module or content prediction module."}) 179 | training_method: str = field(default='FT', metadata={"help": "Training with PT or FT."}) 180 | structure_path: str = field( 181 | default="", 182 | metadata={"help": "the path to the sql structure. only use in the content-fill stage."} 183 | ) 184 | initial_vectors_path: str = field( 185 | default="", 186 | metadata={"help": "the path to the initial learnable vectors. 
only use in the 'PFT' mode."} 187 | ) 188 | def __post_init__(self): 189 | if self.val_max_target_length is None: 190 | self.val_max_target_length = self.max_target_length 191 | 192 | 193 | @dataclass 194 | class DataArguments: 195 | dataset: str = field( 196 | metadata={"help": "The dataset to be used. Choose between ``spider``, ``squall``, ``cosql``, or ``cosql+spider``, or ``spider_realistic``, or ``spider_syn``, or ``spider_dk``."}, 197 | ) 198 | dataset_paths: Dict[str, str] = field( 199 | default_factory=lambda: { 200 | "spider": "./src/datasets/spider", 201 | "cosql": "./src/datasets/cosql", 202 | "geoquery": "./src/datasets/geoquery", 203 | }, 204 | metadata={"help": "Paths of the dataset modules."}, 205 | ) 206 | metric_config: str = field( 207 | default="both", 208 | metadata={"help": "Choose between ``exact_match``, ``test_suite``, or ``both``."}, 209 | ) 210 | metric_paths: Dict[str, str] = field( 211 | default_factory=lambda: { 212 | "spider": "./src/metrics/spider", 213 | "cosql": "./src/metrics/cosql", 214 | }, 215 | metadata={"help": "Paths of the metric modules."}, 216 | ) 217 | test_suite_db_dir: Optional[str] = field( 218 | default=None, 219 | metadata={"help": "Path to the test-suite databases."}) 220 | data_config_file : Optional[str] = field( 221 | default=None, 222 | metadata={"help": "Path to data configuration file (specifying the database splits)"} 223 | ) 224 | test_sections : Optional[List[str]] = field( 225 | default=None, 226 | metadata={"help": "Sections from the data config to use for testing"} 227 | ) -------------------------------------------------------------------------------- /src/utils/bridge_content_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Encode DB content. 
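The helpers below fuzzy-match spans of a natural-language question against
database cell values (see get_database_matches), which is how database
content is brought into the input when schema serialization with database
content is enabled.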
8 | """ 9 | 10 | import difflib 11 | from typing import List, Optional, Tuple 12 | from rapidfuzz import fuzz 13 | import sqlite3 14 | import functools 15 | 16 | # fmt: off 17 | _stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again', 18 | 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but', 19 | "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is', 20 | "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being', 21 | 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be', 22 | 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor', 23 | 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because', 24 | 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have', 25 | 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours', 26 | 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his', 27 | "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what', 28 | 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan', 29 | 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here', 30 | 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself', 31 | 'then', 'did', 'just', "aren't"} 32 | # fmt: on 33 | 34 | _commonwords = {"no", "yes", "many"} 35 | 36 | 37 | def is_number(s: str) -> bool: 38 | try: 39 | float(s.replace(",", "")) 40 | return True 41 | except: 42 | return False 43 | 44 | 45 | def is_stopword(s: str) -> bool: 46 | return s.strip() in _stopwords 47 | 48 | 49 | def is_commonword(s: str) -> bool: 50 | return s.strip() in _commonwords 51 | 52 | 53 | def is_common_db_term(s: str) -> bool: 54 | return s.strip() in ["id"] 55 | 56 | 57 | class Match(object): 58 | def __init__(self, start: int, size: int) -> None: 59 | self.start = start 60 | self.size = size 61 | 62 | 63 | def is_span_separator(c: str) -> bool: 64 | return c in "'\"()`,.?! 
" 65 | 66 | 67 | def split(s: str) -> List[str]: 68 | return [c.lower() for c in s.strip()] 69 | 70 | 71 | def prefix_match(s1: str, s2: str) -> bool: 72 | i, j = 0, 0 73 | for i in range(len(s1)): 74 | if not is_span_separator(s1[i]): 75 | break 76 | for j in range(len(s2)): 77 | if not is_span_separator(s2[j]): 78 | break 79 | if i < len(s1) and j < len(s2): 80 | return s1[i] == s2[j] 81 | elif i >= len(s1) and j >= len(s2): 82 | return True 83 | else: 84 | return False 85 | 86 | 87 | def get_effective_match_source(s: str, start: int, end: int) -> Match: 88 | _start = -1 89 | 90 | for i in range(start, start - 2, -1): 91 | if i < 0: 92 | _start = i + 1 93 | break 94 | if is_span_separator(s[i]): 95 | _start = i 96 | break 97 | 98 | if _start < 0: 99 | return None 100 | 101 | _end = -1 102 | for i in range(end - 1, end + 3): 103 | if i >= len(s): 104 | _end = i - 1 105 | break 106 | if is_span_separator(s[i]): 107 | _end = i 108 | break 109 | 110 | if _end < 0: 111 | return None 112 | 113 | while _start < len(s) and is_span_separator(s[_start]): 114 | _start += 1 115 | while _end >= 0 and is_span_separator(s[_end]): 116 | _end -= 1 117 | 118 | return Match(_start, _end - _start + 1) 119 | 120 | 121 | def get_matched_entries( 122 | s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85 123 | ) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]: 124 | if not field_values: 125 | return None 126 | 127 | if isinstance(s, str): 128 | n_grams = split(s) 129 | else: 130 | n_grams = s 131 | 132 | matched = dict() 133 | for field_value in field_values: 134 | if not isinstance(field_value, str): 135 | continue 136 | fv_tokens = split(field_value) 137 | sm = difflib.SequenceMatcher(None, n_grams, fv_tokens) 138 | match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens)) 139 | if match.size > 0: 140 | source_match = get_effective_match_source( 141 | n_grams, match.a, match.a + match.size 142 | ) 143 | if source_match and source_match.size > 1: 144 | match_str = field_value[match.b : match.b + match.size] 145 | source_match_str = s[ 146 | source_match.start : source_match.start + source_match.size 147 | ] 148 | c_match_str = match_str.lower().strip() 149 | c_source_match_str = source_match_str.lower().strip() 150 | c_field_value = field_value.lower().strip() 151 | if ( 152 | c_match_str 153 | and not is_number(c_match_str) 154 | and not is_common_db_term(c_match_str) 155 | ): 156 | if ( 157 | is_stopword(c_match_str) 158 | or is_stopword(c_source_match_str) 159 | or is_stopword(c_field_value) 160 | ): 161 | continue 162 | if c_source_match_str.endswith(c_match_str + "'s"): 163 | match_score = 1.0 164 | else: 165 | if prefix_match(c_field_value, c_source_match_str): 166 | match_score = ( 167 | fuzz.ratio(c_field_value, c_source_match_str) / 100 168 | ) 169 | else: 170 | match_score = 0 171 | if ( 172 | is_commonword(c_match_str) 173 | or is_commonword(c_source_match_str) 174 | or is_commonword(c_field_value) 175 | ) and match_score < 1: 176 | continue 177 | s_match_score = match_score 178 | if match_score >= m_theta and s_match_score >= s_theta: 179 | if field_value.isupper() and match_score * s_match_score < 1: 180 | continue 181 | matched[match_str] = ( 182 | field_value, 183 | source_match_str, 184 | match_score, 185 | s_match_score, 186 | match.size, 187 | ) 188 | 189 | if not matched: 190 | return None 191 | else: 192 | return sorted( 193 | matched.items(), 194 | key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]), 195 | reverse=True, 196 | ) 197 | 
198 | 199 | @functools.lru_cache(maxsize=1000, typed=False) 200 | def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list: 201 | fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name) 202 | try: 203 | conn = sqlite3.connect(db_path) 204 | conn.text_factory = bytes 205 | c = conn.cursor() 206 | c.execute(fetch_sql) 207 | picklist = set() 208 | for x in c.fetchall(): 209 | if isinstance(x[0], str): 210 | picklist.add(x[0].encode("utf-8")) 211 | elif isinstance(x[0], bytes): 212 | try: 213 | picklist.add(x[0].decode("utf-8")) 214 | except UnicodeDecodeError: 215 | picklist.add(x[0].decode("latin-1")) 216 | else: 217 | picklist.add(x[0]) 218 | picklist = list(picklist) 219 | finally: 220 | conn.close() 221 | return picklist 222 | 223 | 224 | def get_database_matches( 225 | question: str, 226 | table_name: str, 227 | column_name: str, 228 | db_path: str, 229 | top_k_matches: int = 2, 230 | match_threshold: float = 0.85, 231 | ) -> List[str]: 232 | picklist = get_column_picklist( 233 | table_name=table_name, column_name=column_name, db_path=db_path 234 | ) 235 | matches = [] 236 | if picklist and isinstance(picklist[0], str): 237 | matched_entries = get_matched_entries( 238 | s=question, 239 | field_values=picklist, 240 | m_theta=match_threshold, 241 | s_theta=match_threshold, 242 | ) 243 | if matched_entries: 244 | num_values_inserted = 0 245 | for _match_str, ( 246 | field_value, 247 | _s_match_str, 248 | match_score, 249 | s_match_score, 250 | _match_size, 251 | ) in matched_entries: 252 | if "name" in column_name and match_score * s_match_score < 1: 253 | continue 254 | if table_name != "sqlite_sequence": # Spider database artifact 255 | matches.append(field_value) 256 | num_values_inserted += 1 257 | if num_values_inserted >= top_k_matches: 258 | break 259 | return matches 260 | -------------------------------------------------------------------------------- /src/utils/cosql.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from typing import Optional, List 4 | from datasets.arrow_dataset import Dataset 5 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 6 | from .args import * 7 | from .dataset import normalize, serialize_schema, combine_SC 8 | from .trainer import Seq2SeqTrainer, EvalPrediction 9 | from .process_sql import get_schema, Schema, get_sql 10 | import re 11 | import shlex 12 | import random 13 | import os 14 | 15 | sql_clauses = ['select', 'from', 'where', 'group', 'having', 'order', 'limit', 'intersect', 'union', 'except'] 16 | sql_ops_space = ['>', '<', '=', 'like', 'between', 'and', 'or', 'not', 'in', ')', 'by', 'distinct', '+', '>=', '<='] 17 | sql_ops_no_space = ['count', 'avg', 'sum', 'max', 'min', '(', '!', 'desc', 'asc'] 18 | sql_marks = [','] 19 | pattern_1 = re.compile('(\(|\))') 20 | pattern_4 = re.compile('(\()') 21 | pattern_5 = re.compile('(\))') 22 | pattern_2 = re.compile('(,)') 23 | pattern_3 = re.compile('(>=|<=|>|<|=)') 24 | 25 | def lower_( 26 | word: str, 27 | ) -> str: 28 | if '"' in word or "'" in word: 29 | return word 30 | else: 31 | return word.lower() 32 | 33 | def tok_process( 34 | toks: list, 35 | ) -> list: 36 | processed_tok_list = [] 37 | i = 0 38 | while i < len(toks): 39 | if toks[i] == "``" and toks[i+2] == "''": 40 | temp = f'"{toks[i+1]}"' 41 | processed_tok_list.append(temp) 42 | i += 3 43 | continue 44 | else: 45 | processed_tok_list.append(toks[i]) 46 | i += 1 47 | return [lower_(x) for x in 
processed_tok_list] 48 | 49 | def cosql_get_input( 50 | utterances: List[str], 51 | serialized_schema: str, 52 | prefix: str, 53 | sep: str = " | ", 54 | ) -> str: 55 | utterances = (utterance.strip() for utterance in utterances) 56 | serialize_utterances = sep.join(utterances) 57 | return prefix + serialize_utterances + " | " + serialized_schema.strip() 58 | 59 | 60 | def cosql_get_target( 61 | query: str, 62 | db_id: str, 63 | normalize_query: bool, 64 | target_with_db_id: bool, 65 | ) -> str: 66 | _normalize = normalize if normalize_query else (lambda x: x) 67 | return f"{db_id} | {_normalize(query)}" if target_with_db_id else _normalize(query) 68 | 69 | 70 | def cosql_add_serialized_schema( 71 | ex: dict, 72 | mode: str, 73 | data_training_args: DataTrainingArguments, 74 | ) -> dict: 75 | serialized_schema = serialize_schema( 76 | question=" | ".join(ex["utterances"]), 77 | db_path=ex["db_path"], 78 | db_id=ex["db_id"], 79 | db_column_names=ex["db_column_names"], 80 | db_table_names=ex["db_table_names"], 81 | schema_serialization_type=data_training_args.schema_serialization_type, 82 | schema_serialization_randomized=data_training_args.schema_serialization_randomized, 83 | schema_serialization_with_db_id=data_training_args.schema_serialization_with_db_id, 84 | schema_serialization_with_db_content=data_training_args.schema_serialization_with_db_content, 85 | normalize_query=data_training_args.normalize_query, 86 | ) 87 | return {"serialized_schema": serialized_schema} 88 | 89 | 90 | def cosql_pre_process_function( 91 | batch: dict, 92 | max_source_length: Optional[int], 93 | max_target_length: Optional[int], 94 | mode: Optional[str], 95 | data_training_args: DataTrainingArguments, 96 | tokenizer: PreTrainedTokenizerBase, 97 | ) -> dict: 98 | prefix = data_training_args.source_prefix if data_training_args.source_prefix is not None else "question: " 99 | if data_training_args.use_decomposition: 100 | inputs = [] 101 | targets = [] 102 | if data_training_args.stage == 'content': 103 | eval_format_list = [] 104 | with open(data_training_args.structure_path) as f: 105 | info = json.load(f) 106 | for item in info: 107 | eval_format_list.append(item['prediction']) 108 | print(f"load {len(eval_format_list)} eval_formats") 109 | if len(batch['utterances']) == 1000: 110 | count = 0 111 | else: 112 | count = 1000 113 | for question, serialized_schema, db_id, query, db_column_names in zip(batch["utterances"], batch["serialized_schema"], batch["db_id"], batch["query"], batch["db_column_names"]): 114 | input_str = cosql_get_input(utterances=question, serialized_schema=serialized_schema, prefix=prefix) 115 | column_names = [x.lower() for x in db_column_names['column_name']] 116 | #query_toks = query.split() 117 | lex = shlex.shlex(query) 118 | lex.whitespace = ' ' 119 | lex.quotes=['"', "'"] 120 | lex.whitespace_split = True 121 | query_toks = list(lex) 122 | query_tok_list = tok_process(query_toks) 123 | for idx, tok in enumerate(query_tok_list): 124 | if '"' in tok or "'" in tok: 125 | continue 126 | if len(tok) > 1 and ',' in tok and tok not in column_names and query_tok_list[idx-1] not in sql_ops_space: 127 | res = pattern_2.split(tok) 128 | query_tok_list[idx:idx+1] = res 129 | if '(' in query_tok_list[idx] and ')' in query_tok_list[idx] and ('sum' in query_tok_list[idx] or 'count' in query_tok_list[idx] or 'avg' in query_tok_list[idx] or 'max' in query_tok_list[idx] or 'min' in query_tok_list[idx]): 130 | res = pattern_1.split(query_tok_list[idx]) 131 | query_tok_list[idx:idx+1] = res 132 | elif 
'(' in query_tok_list[idx] and 'select' in query_tok_list[idx]: 133 | res = pattern_4.split(query_tok_list[idx]) 134 | query_tok_list[idx:idx+1] = res 135 | elif len(query_tok_list[idx]) > 1 and ')' == query_tok_list[idx][-1]: 136 | res = pattern_5.split(query_tok_list[idx]) 137 | query_tok_list[idx:idx+1] = res 138 | if ('>' in query_tok_list[idx] or '<' in query_tok_list[idx] or '>=' in query_tok_list[idx] or '<=' in query_tok_list[idx] or '=' in query_tok_list[idx]) and query_tok_list[idx][0] not in ['>', '<', '=']: 139 | res = pattern_3.split(query_tok_list[idx]) 140 | query_tok_list[idx:idx+1] = res 141 | for idx, tok in enumerate(query_tok_list): 142 | if tok == '': 143 | del query_tok_list[idx] 144 | #print(query) 145 | sub_query_format_list = [] 146 | content_label = '' 147 | sub_query_list = [] 148 | select_from_record = [] 149 | sub_query_format = '' 150 | sub_query = '' 151 | last_tok = None 152 | select_num = 0 153 | left_bracket = 0 154 | right_bracket = 0 155 | if query_tok_list[-1] == ';': 156 | query_tok_list = query_tok_list[:-1] 157 | for idx, query_tok in enumerate(query_tok_list): 158 | if query_tok in sql_clauses: 159 | if query_tok == 'select': 160 | select_num += 1 161 | elif (query_tok == 'group' or query_tok == 'order') and query_tok_list[idx+1] != 'by': 162 | if query_tok in column_names: 163 | if idx + 1 == len(query_tok_list) or query_tok_list[idx+1] in [',', ')']: 164 | sub_query_format += '[col]' 165 | else: 166 | sub_query_format += '[col] ' 167 | content_label += '[col] ' + query_tok + ' ' 168 | sub_query += query_tok + ' ' 169 | else: 170 | print("error:", query_tok) 171 | continue 172 | 173 | if sub_query_format != '' and sub_query != '': 174 | if 'select' in sub_query_format: 175 | select_from_record.append(1) 176 | elif 'from' in sub_query_format: 177 | select_from_record.append(2) 178 | else: 179 | select_from_record.append(0) 180 | sub_query_format_list.append(sub_query_format) 181 | sub_query_list.append(sub_query) 182 | sub_query_format = '' 183 | sub_query = '' 184 | if query_tok == 'from': 185 | sub_query_format += 'from [tab]' 186 | content_label += '[tab] ' 187 | else: 188 | sub_query_format += query_tok + ' ' 189 | last_tok = 'sql_clauses' 190 | sub_query += query_tok + ' ' 191 | elif sub_query_format == 'from [tab]': 192 | last_tok = 'from [tab]' 193 | sub_query += query_tok + ' ' 194 | if query_tok not in [')', '(']: 195 | content_label += query_tok + ' ' 196 | continue 197 | elif query_tok in sql_ops_space: 198 | if query_tok == ')': 199 | right_bracket += 1 200 | if ((query_tok == '>' or query_tok == '<') and query_tok_list[idx+1] == '=') or (query_tok == ')' and (idx + 1 == len(query_tok_list) or query_tok_list[idx+1] == ',')): 201 | # >= or <= ), 202 | sub_query_format += query_tok 203 | sub_query += query_tok 204 | else: 205 | sub_query_format += query_tok + ' ' 206 | sub_query += query_tok + ' ' 207 | last_tok = 'op' 208 | elif query_tok in sql_ops_no_space: 209 | if query_tok == '(': 210 | left_bracket += 1 211 | sub_query_format += query_tok 212 | sub_query += query_tok 213 | elif query_tok in column_names or '.' 
in query_tok: 214 | if last_tok == 'val': 215 | content_label += query_tok + ' ' 216 | continue 217 | if idx + 1 == len(query_tok_list) or query_tok_list[idx+1] in [',', ')']: 218 | sub_query_format += '[col]' 219 | sub_query += query_tok 220 | else: 221 | sub_query_format += '[col] ' 222 | sub_query += query_tok + ' ' 223 | content_label += '[col] ' + query_tok + ' ' 224 | last_tok = 'col' 225 | elif query_tok in sql_marks: 226 | sub_query_format += query_tok + ' ' 227 | sub_query += query_tok + ' ' 228 | last_tok = 'mark' 229 | else: 230 | if last_tok != 'val': 231 | sub_query_format += '[val] ' 232 | content_label += '[val] ' 233 | if query_tok == '``': 234 | sub_query += '"' 235 | content_label += '"' 236 | elif query_tok == "''": 237 | sub_query += '" ' 238 | content_label += '" ' 239 | elif query_tok == "'": 240 | sub_query += "' " 241 | content_label += "' " 242 | elif last_tok == 'val' and (idx + 1 == len(query_tok_list) or query_tok_list[idx+1] in ["'", '"', '``', "''"]): 243 | sub_query += query_tok 244 | content_label += query_tok 245 | else: 246 | sub_query += query_tok + ' ' 247 | content_label += query_tok + ' ' 248 | last_tok = 'val' 249 | 250 | if select_num > 1 and left_bracket > right_bracket: 251 | sub_query_format += ')' 252 | if 'select' in sub_query_format: 253 | select_from_record.append(1) 254 | elif 'from' in sub_query_format: 255 | select_from_record.append(2) 256 | else: 257 | select_from_record.append(0) 258 | sub_query_format_list.append(sub_query_format) 259 | sub_query_list.append(sub_query) 260 | if data_training_args.stage == 'structure': 261 | structure = normalize(' '.join(sub_query_format_list)) 262 | inputs.append(data_training_args.schema_serialization_with_prompt + ' | ' + input_str) 263 | target = cosql_get_target( 264 | query=' '.join(sub_query_format_list), 265 | db_id=db_id, 266 | normalize_query=True, 267 | target_with_db_id=False, 268 | ) 269 | targets.append(target) 270 | 271 | elif data_training_args.stage == 'content': 272 | if mode == 'eval': 273 | input_str = data_training_args.schema_serialization_with_prompt + eval_format_list[count] + ' | ' + input_str 274 | else: 275 | input_str = data_training_args.schema_serialization_with_prompt + ' '.join(sub_query_format_list) + ' | ' + input_str 276 | inputs.append(input_str) 277 | target = content_label 278 | targets.append(target) 279 | count += 1 280 | else: 281 | inputs = [ 282 | cosql_get_input(utterances=utterances, serialized_schema=serialized_schema, prefix=prefix) 283 | for utterances, serialized_schema in zip(batch["utterances"], batch["serialized_schema"]) 284 | ] 285 | targets = [ 286 | cosql_get_target( 287 | query=query, 288 | db_id=db_id, 289 | normalize_query=data_training_args.normalize_query, 290 | target_with_db_id=data_training_args.target_with_db_id, 291 | ) 292 | for db_id, query in zip(batch["db_id"], batch["query"]) 293 | ] 294 | print(f"{mode}: {len(inputs)}") 295 | 296 | model_inputs: dict = tokenizer( 297 | inputs, 298 | max_length=max_source_length, 299 | padding=False, 300 | truncation=True, 301 | return_overflowing_tokens=False, 302 | ) 303 | 304 | # Setup the tokenizer for targets 305 | with tokenizer.as_target_tokenizer(): 306 | labels = tokenizer( 307 | targets, 308 | max_length=max_target_length, 309 | padding=False, 310 | truncation=True, 311 | return_overflowing_tokens=False, 312 | ) 313 | 314 | model_inputs["labels"] = labels["input_ids"] 315 | return model_inputs 316 | 317 | 318 | class CoSQLTrainer(Seq2SeqTrainer): 319 | def _post_process_function( 320 | 
self, examples: Dataset, features: Dataset, predictions: np.ndarray, stage: str 321 | ) -> EvalPrediction: 322 | inputs = self.tokenizer.batch_decode([f["input_ids"] for f in features], skip_special_tokens=True) 323 | label_ids = [f["labels"] for f in features] 324 | if self.ignore_pad_token_for_loss: 325 | # Replace -100 in the labels as we can't decode them. 326 | _label_ids = np.where(label_ids != -100, label_ids, self.tokenizer.pad_token_id) 327 | decoded_label_ids = self.tokenizer.batch_decode(_label_ids, skip_special_tokens=True) 328 | metas = [ 329 | { 330 | "query": x["query"], 331 | "utterances": x["utterances"], 332 | "turn_idx": x["turn_idx"], 333 | "context": context, 334 | "label": label, 335 | "db_id": x["db_id"], 336 | "db_path": x["db_path"], 337 | "db_table_names": x["db_table_names"], 338 | "db_column_names": x["db_column_names"], 339 | "db_foreign_keys": x["db_foreign_keys"], 340 | } 341 | for x, context, label in zip(examples, inputs, decoded_label_ids) 342 | ] 343 | predictions = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) 344 | 345 | assert len(metas) == len(predictions) 346 | if self.stage == 'content': 347 | final_pred_sqls = [] 348 | hypotheses_path = os.path.join(self.args.output_dir, "hypotheses.json") 349 | if os.path.exists(hypotheses_path): 350 | # sentence-level check 351 | with open(hypotheses_path) as f: 352 | hypotheses = json.load(f) 353 | for idx, item in enumerate(hypotheses): 354 | db_id, structure = item["structure"].split(" | ") 355 | db = os.path.join(metas[idx]["db_path"], db_id, f'{db_id}.sqlite') 356 | schema = Schema(get_schema(db)) 357 | final_pred_sql = None 358 | for hypothesis in item["topk_preds"]: 359 | try: 360 | pred_sql = combine_SC(content=hypothesis, input="", structure=structure) 361 | parse_sql = get_sql(schema, pred_sql) 362 | final_pred_sql = pred_sql 363 | break 364 | except: 365 | continue 366 | if final_pred_sql == None: 367 | # default to the first one 368 | final_pred_sql = combine_SC(content=item["topk_preds"][0], input="", structure=structure) 369 | final_pred_sqls.append(final_pred_sql) 370 | 371 | os.remove(hypotheses_path) 372 | else: 373 | for pred_content, meta in zip(predictions, metas): 374 | final_pred_sqls.append(combine_SC(pred_content, meta['context'])) 375 | # write predict sql 376 | with open(f"{self.args.output_dir}/predict_sql.txt", "w") as f: 377 | for final_pred_sql in final_pred_sqls: 378 | f.write(final_pred_sql+"\n") 379 | 380 | with open(f"{self.args.output_dir}/content_{stage}.json", "w") as f: 381 | json.dump( 382 | [dict(**{"input": meta['context']}, **{"prediction": prediction}, **{"label": label}, **{"score": prediction==label}, **{"pred_sql": final_pred_sql}, **{"gold_sql": meta['query'], **{"turn_idx": meta["turn_idx"]}}) for meta, prediction, final_pred_sql, label in zip(metas, predictions, final_pred_sqls, decoded_label_ids)], 383 | f, 384 | indent=4, 385 | ) 386 | return EvalPrediction(predictions=final_pred_sqls, label_ids=decoded_label_ids, metas=metas) 387 | elif self.stage == 'structure': 388 | for idx in range(len(predictions)): 389 | if 'before' in predictions[idx]: 390 | predictions[idx] = predictions[idx].replace('before', '<') 391 | if 'after' in predictions[idx]: 392 | predictions[idx] = predictions[idx].replace('after', '>') 393 | return EvalPrediction(predictions=predictions, label_ids=decoded_label_ids, metas=metas) 394 | 395 | def _compute_metrics(self, eval_prediction: EvalPrediction) -> dict: 396 | predictions, label_ids, metas = eval_prediction 397 | if 
self.target_with_db_id: 398 | # Remove database id from all predictions 399 | predictions = [pred.split("|", 1)[-1].strip() for pred in predictions] 400 | 401 | references = metas 402 | if self.stage == 'structure': 403 | accuracy = [] 404 | accuracy.extend( 405 | ( 406 | pred == actual 407 | for pred, actual in zip(predictions, label_ids) 408 | ) 409 | ) 410 | eval_metric = np.mean(accuracy) 411 | test_suite = dict() 412 | if eval_metric >= self.best_acc: 413 | with open(f"{self.args.output_dir}/structure.json", "w") as f: 414 | json.dump( 415 | [dict(**{"input": meta['context']}, **{"prediction": prediction}, **{"label": label}, **{"score": prediction==label}) for meta, prediction, label in zip(metas, predictions, label_ids)], 416 | f, 417 | indent=4, 418 | ) 419 | return {**{"exact_match": eval_metric}, **test_suite} 420 | elif self.stage == 'content': 421 | return self.metric.compute(predictions=predictions, references=references) 422 | else: 423 | raise NotImplementedError() 424 | -------------------------------------------------------------------------------- /src/utils/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Dict, Callable 2 | from dataclasses import dataclass, field 3 | from datasets.dataset_dict import DatasetDict 4 | from datasets.arrow_dataset import Dataset 5 | from transformers.training_args import TrainingArguments 6 | from .bridge_content_encoder import get_database_matches 7 | import re 8 | import random 9 | import json 10 | from .args import * 11 | 12 | @dataclass 13 | class TrainSplit(object): 14 | dataset: Dataset 15 | schemas: Dict[str, dict] 16 | 17 | 18 | @dataclass 19 | class EvalSplit(object): 20 | dataset: Dataset 21 | examples: Dataset 22 | schemas: Dict[str, dict] 23 | 24 | 25 | @dataclass 26 | class DatasetSplits(object): 27 | train_split: Optional[TrainSplit] 28 | eval_split: Optional[EvalSplit] 29 | test_splits: Optional[Dict[str, EvalSplit]] 30 | schemas: Dict[str, dict] 31 | 32 | 33 | def _get_schemas(examples: Dataset) -> Dict[str, dict]: 34 | schemas: Dict[str, dict] = dict() 35 | for ex in examples: 36 | if ex["db_id"] not in schemas: 37 | schemas[ex["db_id"]] = { 38 | "db_table_names": ex["db_table_names"], 39 | "db_column_names": ex["db_column_names"], 40 | "db_column_types": ex["db_column_types"], 41 | "db_primary_keys": ex["db_primary_keys"], 42 | "db_foreign_keys": ex["db_foreign_keys"], 43 | } 44 | return schemas 45 | 46 | def _prepare_train_split( 47 | dataset: Dataset, 48 | data_training_args: DataTrainingArguments, 49 | data_args: DataArguments, 50 | add_serialized_schema: Callable[[dict], dict], 51 | pre_process_function: Callable[[dict, Optional[int], Optional[int]], dict], 52 | ) -> TrainSplit: 53 | 54 | if data_args.dataset in ['']: 55 | schemas = _get_schemas_geoquery(examples=dataset) 56 | else: 57 | schemas = _get_schemas(examples=dataset) 58 | dataset = dataset.map( 59 | lambda ex: add_serialized_schema( 60 | ex=ex, 61 | mode='train'), 62 | batched=False, 63 | num_proc=data_training_args.preprocessing_num_workers, 64 | load_from_cache_file=True, 65 | ) 66 | if data_training_args.train_samples_ratio != 1.0: 67 | if data_args.dataset in ['geoquery']: 68 | if data_training_args.train_samples_ratio == 0.094: 69 | indexs = [19,514,375,341,274,492,221,515,360,58,413,418,333,487,28,122,344,208,475,108,264,155,0,23,4,73,129,27,61,8,74,469,500,396,362,430,17,203,171,33,139,80,503,206,243,5,486,423,244,130] 70 | else: 71 | indexs = random.sample(range(536), 
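# NOTE: 536 is the size of the GeoQuery training split (cf. "1. (536)" in the
# README table), and the hard-coded index list in the 0.094 branch above pins
# a fixed 50-example few-shot split, presumably for reproducibility; every
# other ratio falls back to this uniform random sample.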
int(dataset.num_rows*data_training_args.train_samples_ratio)) 72 | print(indexs) 73 | print(f"use {len(set(indexs))} training samples.") 74 | dataset = dataset.select(indexs) 75 | else: 76 | dataset = dataset.select(range(int(dataset.num_rows*data_training_args.train_samples_ratio))) 77 | column_names = dataset.column_names 78 | dataset = dataset.map( 79 | lambda batch: pre_process_function( 80 | batch=batch, 81 | max_source_length=data_training_args.max_source_length, 82 | max_target_length=data_training_args.max_target_length, 83 | mode='train', 84 | ), 85 | batched=True, 86 | num_proc=data_training_args.preprocessing_num_workers, 87 | remove_columns=column_names, 88 | load_from_cache_file=False, 89 | ) 90 | return TrainSplit(dataset=dataset, schemas=schemas) 91 | 92 | 93 | def _prepare_eval_split( 94 | dataset: Dataset, 95 | data_training_args: DataTrainingArguments, 96 | data_args: DataArguments, 97 | add_serialized_schema: Callable[[dict], dict], 98 | pre_process_function: Callable[[dict, Optional[int], Optional[int]], dict], 99 | ) -> EvalSplit: 100 | 101 | eval_examples = dataset 102 | if data_args.dataset in ['']: 103 | schemas = _get_schemas_geoquery(examples=dataset) 104 | else: 105 | schemas = _get_schemas(examples=eval_examples) 106 | eval_dataset = eval_examples.map( 107 | lambda ex: add_serialized_schema( 108 | ex=ex, 109 | mode='eval'), 110 | batched=False, 111 | num_proc=data_training_args.preprocessing_num_workers, 112 | load_from_cache_file=False, 113 | ) 114 | if data_training_args.max_val_samples is not None: 115 | eval_dataset = eval_dataset.select(range(data_training_args.max_val_samples)) 116 | column_names = eval_dataset.column_names 117 | eval_dataset = eval_dataset.map( 118 | lambda batch: pre_process_function( 119 | batch=batch, 120 | max_source_length=data_training_args.max_source_length, 121 | max_target_length=data_training_args.val_max_target_length, 122 | mode='eval', 123 | ), 124 | batched=True, 125 | num_proc=data_training_args.preprocessing_num_workers, 126 | remove_columns=column_names, 127 | load_from_cache_file=False, 128 | ) 129 | return EvalSplit(dataset=eval_dataset, examples=eval_examples, schemas=schemas) 130 | 131 | 132 | def prepare_splits( 133 | dataset_dict: DatasetDict, 134 | data_args: DataArguments, 135 | training_args: TrainingArguments, 136 | data_training_args: DataTrainingArguments, 137 | add_serialized_schema: Callable[[dict], dict], 138 | pre_process_function: Callable[[dict, Optional[int], Optional[int]], dict], 139 | ) -> DatasetSplits: 140 | train_split, eval_split, test_splits = None, None, None 141 | if training_args.do_train: 142 | train_split = _prepare_train_split( 143 | dataset_dict["train"], 144 | data_training_args=data_training_args, 145 | data_args=data_args, 146 | add_serialized_schema=add_serialized_schema, 147 | pre_process_function=pre_process_function, 148 | ) 149 | if training_args.do_eval: 150 | eval_split = _prepare_eval_split( 151 | dataset_dict["validation"], 152 | data_args=data_args, 153 | data_training_args=data_training_args, 154 | add_serialized_schema=add_serialized_schema, 155 | pre_process_function=pre_process_function, 156 | ) 157 | if training_args.do_predict: 158 | eval_split = _prepare_eval_split( 159 | dataset_dict["test"], 160 | data_args=data_args, 161 | data_training_args=data_training_args, 162 | add_serialized_schema=add_serialized_schema, 163 | pre_process_function=pre_process_function, 164 | ) 165 | schemas = { 166 | **(train_split.schemas if train_split is not None else {}), 167 | 
**(eval_split.schemas if eval_split is not None else {}), 168 | **({k: v for split in test_splits.values() for k, v in split.schemas.items()} if test_splits is not None else {}), 169 | } 170 | 171 | return DatasetSplits( 172 | train_split=train_split, 173 | eval_split=eval_split, 174 | test_splits=test_splits, 175 | schemas=schemas 176 | ) 177 | 178 | 179 | def normalize(query: str) -> str: 180 | def comma_fix(s): 181 | # Remove spaces in front of commas 182 | return s.replace(" , ", ", ") 183 | 184 | def white_space_fix(s): 185 | # Remove double and triple spaces 186 | return " ".join(s.split()) 187 | 188 | def lower(s): 189 | # Convert everything except text between (single or double) quotation marks to lower case 190 | return re.sub(r"\b(?<!['\"])(\w+)(?!['\"])\b", lambda match: match.group(1).lower(), s) 191 | 192 | return comma_fix(white_space_fix(lower(query))) 193 | 194 | 195 | def serialize_schema( 196 | question: str, 197 | db_path: str, 198 | db_id: str, 199 | db_column_names: Dict[str, str], 200 | db_table_names: List[str], 201 | schema_serialization_type: str = "peteshaw", 202 | schema_serialization_randomized: bool = False, 203 | schema_serialization_with_db_id: bool = True, 204 | schema_serialization_with_db_content: bool = False, 205 | normalize_query: bool = True, 206 | ) -> str: 207 | if schema_serialization_type == "verbose": 208 | db_id_str = "database: {db_id}. " 209 | table_sep = ". " 210 | table_str = "table: {table}. columns: {columns}" 211 | column_sep = ", " 212 | column_str_with_values = "{column} ({values})" 213 | column_str_without_values = "{column}" 214 | value_sep = ", " 215 | elif schema_serialization_type == "peteshaw": 216 | db_id_str = " | {db_id}" 217 | table_sep = "" 218 | table_str = " | {table} : {columns}" 219 | column_sep = " , " 220 | column_str_with_values = "{column} ( {values} )" 221 | column_str_without_values = "{column}" 222 | value_sep = " , " 223 | else: 224 | raise NotImplementedError 225 | def get_column_str(table_name: str, column_name: str) -> str: 226 | column_name_str = column_name.lower() if normalize_query else column_name 227 | if schema_serialization_with_db_content: 228 | matches = get_database_matches( 229 | question=question, 230 | table_name=table_name, 231 | column_name=column_name, 232 | db_path=(db_path + "/" + db_id + "/" + db_id + ".sqlite"), 233 | ) 234 | if matches: 235 | string = column_str_with_values.format(column=column_name_str, values=value_sep.join(matches)) 236 | return string 237 | else: 238 | return column_str_without_values.format(column=column_name_str) 239 | else: 240 | return column_str_without_values.format(column=column_name_str) 241 | 242 | tables = [ 243 | table_str.format( 244 | table=table_name.lower() if normalize_query else table_name, 245 | columns=column_sep.join( 246 | map( 247 | lambda y: get_column_str(table_name=table_name, column_name=y[1]), 248 | filter( 249 | lambda y: y[0] == table_id, 250 | zip( 251 | db_column_names["table_id"], 252 | db_column_names["column_name"], 253 | ), 254 | ), 255 | ) 256 | ), 257 | ) 258 | for table_id, table_name in enumerate(db_table_names) 259 | ] 260 | 261 | reorder_tables = [] 262 | for table in tables: 263 | if '(' in table: 264 | reorder_tables = [table] + reorder_tables 265 | else: 266 | reorder_tables.append(table) 267 | 268 | tables = reorder_tables 269 | #if schema_serialization_randomized: 270 | # random.shuffle(tables) 271 | if schema_serialization_with_db_id: 272 | serialized_schema = db_id_str.format(db_id=db_id) + table_sep.join(tables) 273 | else: 274 | serialized_schema = 'database: ' + table_sep.join(tables) 275 | 276 | return serialized_schema 277 | 278 | def combine_SC(content, input, structure=None): 279 | if structure is None: 280 | structure = input.replace("Translate the question into sql according to the database:", "") 281 | end_index = structure.index(" | question:") 282 | structure = structure[:end_index] 283 | col_num = structure.count('[col]') 284 | tab_num = structure.count('[tab]') 285 | val_num = structure.count('[val]') 286 | if (content.count('[col]') != col_num) or (content.count('[tab]') != tab_num) or (content.count('[val]') 
!= val_num): 287 | return structure 288 | 289 | content_dict = {"[col]": [], "[tab]": [], "[val]": []} 290 | tok = None 291 | temp_str = '' 292 | i = 0 293 | while i < len(content): 294 | if content[i] == '[' and i+4 < len(content) and content[i+4] == ']' and (content[i:i+5] in ['[col]', '[tab]', '[val]']): 295 | if tok != None: 296 | content_dict[tok].append(temp_str.strip()) 297 | tok = content[i:i+5] 298 | temp_str = '' 299 | i += 6 300 | continue 301 | temp_str += content[i] 302 | i += 1 303 | if tok != None: 304 | content_dict[tok].append(temp_str.strip()) 305 | 306 | pred_sql = structure 307 | # replace [col] 308 | end_index = 0 309 | for i in range(col_num): 310 | begin_index = pred_sql[end_index:].index('[col]') + end_index 311 | pred_sql = pred_sql[:begin_index] + content_dict['[col]'][i] + pred_sql[begin_index+5:] 312 | end_index = begin_index + len(content_dict['[col]'][i]) + 1 313 | 314 | # replace [tab] 315 | end_index = 0 316 | for i in range(tab_num): 317 | begin_index = pred_sql[end_index:].index('[tab]') + end_index 318 | pred_sql = pred_sql[:begin_index] + content_dict['[tab]'][i] + pred_sql[begin_index+5:] 319 | end_index = begin_index + len(content_dict['[tab]'][i]) + 1 320 | 321 | # replace [val] 322 | end_index = 0 323 | for i in range(val_num): 324 | begin_index = pred_sql[end_index:].index('[val]') + end_index 325 | pred_sql = pred_sql[:begin_index] + content_dict['[val]'][i] + pred_sql[begin_index+5:] 326 | end_index = begin_index + len(content_dict['[val]'][i]) + 1 327 | if pred_sql[0] == ' ': 328 | pred_sql = pred_sql[1:] 329 | return pred_sql 330 | 331 | -------------------------------------------------------------------------------- /src/utils/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Callable, Tuple 3 | import logging 4 | import datasets.load 5 | from datasets.dataset_dict import DatasetDict 6 | from datasets.metric import Metric 7 | from datasets.arrow_dataset import Dataset, concatenate_datasets 8 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 9 | from transformers.training_args import TrainingArguments 10 | from .args import * 11 | from .dataset import ( 12 | DatasetSplits, 13 | TrainSplit, 14 | _prepare_train_split, 15 | prepare_splits, 16 | ) 17 | from .spider import spider_add_serialized_schema, spider_pre_process_function 18 | from .geoquery import geoquery_add_serialized_schema, geoquery_pre_process_function 19 | from .cosql import cosql_add_serialized_schema, cosql_pre_process_function 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | 25 | def _log_duplicate_count(dataset: Dataset, dataset_name: str, split: str) -> None: 26 | d = dataset.to_dict() 27 | d_t = [tuple((k, tuple(v)) for k, v in zip(d.keys(), vs)) for vs in zip(*d.values())] 28 | d_t_ = set(d_t) 29 | num_examples = len(d_t) 30 | duplicate_count = num_examples - len(d_t_) 31 | if duplicate_count > 0: 32 | logger.warning( 33 | f"The split ``{split}`` of the dataset ``{dataset_name}`` contains {duplicate_count} duplicates out of {num_examples} examples" 34 | ) 35 | 36 | 37 | def load_dataset( 38 | data_args: DataArguments, 39 | model_args: ModelArguments, 40 | data_training_args: DataTrainingArguments, 41 | training_args: TrainingArguments, 42 | tokenizer: PreTrainedTokenizerFast, 43 | ) -> Tuple[Metric, DatasetSplits]: 44 | _spider_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset( 45 | path=data_args.dataset_paths["spider"], 
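# NOTE: data_args.dataset_paths points at dataset loading scripts, presumably
# the builders under src/datasets/ (e.g. src/datasets/spider/spider.py), so
# datasets.load_dataset() runs the repo's own builders rather than fetching a
# hub dataset.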
cache_dir=model_args.cache_dir 46 | ) 47 | _spider_metric: Callable[[], Metric] = lambda: datasets.load.load_metric( 48 | path=data_args.metric_paths["spider"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir 49 | ) 50 | _spider_add_serialized_schema = lambda ex, mode: spider_add_serialized_schema( 51 | ex=ex, 52 | mode=mode, 53 | data_training_args=data_training_args, 54 | ) 55 | _spider_pre_process_function = lambda batch, max_source_length, max_target_length, mode: spider_pre_process_function( 56 | batch=batch, 57 | max_source_length=max_source_length, 58 | max_target_length=max_target_length, 59 | mode=mode, 60 | data_training_args=data_training_args, 61 | tokenizer=tokenizer, 62 | ) 63 | 64 | 65 | 66 | _geoquery_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset( 67 | path=data_args.dataset_paths["geoquery"], cache_dir=model_args.cache_dir 68 | ) 69 | _geoquery_metric: Callable[[], Metric] = lambda: datasets.load.load_metric( 70 | path=data_args.metric_paths["spider"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir 71 | ) 72 | _geoquery_add_serialized_schema = lambda ex, mode: geoquery_add_serialized_schema( 73 | ex=ex, 74 | mode=mode, 75 | data_training_args=data_training_args, 76 | ) 77 | _geoquery_pre_process_function = lambda batch, max_source_length, max_target_length, mode: geoquery_pre_process_function( 78 | batch=batch, 79 | max_source_length=max_source_length, 80 | max_target_length=max_target_length, 81 | mode=mode, 82 | data_training_args=data_training_args, 83 | tokenizer=tokenizer, 84 | ) 85 | 86 | 87 | _cosql_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset( 88 | path=data_args.dataset_paths["cosql"], cache_dir=model_args.cache_dir 89 | ) 90 | _cosql_metric: Callable[[], Metric] = lambda: datasets.load.load_metric( 91 | path=data_args.metric_paths["cosql"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir 92 | ) 93 | _cosql_add_serialized_schema = lambda ex, mode: cosql_add_serialized_schema( 94 | ex=ex, 95 | mode=mode, 96 | data_training_args=data_training_args, 97 | ) 98 | _cosql_pre_process_function = lambda batch, max_source_length, max_target_length, mode: cosql_pre_process_function( 99 | batch=batch, 100 | max_source_length=max_source_length, 101 | max_target_length=max_target_length, 102 | data_training_args=data_training_args, 103 | mode=mode, 104 | tokenizer=tokenizer, 105 | ) 106 | 107 | _prepare_splits_kwargs = { 108 | "data_args": data_args, 109 | "training_args": training_args, 110 | "data_training_args": data_training_args, 111 | } 112 | 113 | if data_args.dataset == "spider": 114 | metric = _spider_metric() 115 | dataset_splits = prepare_splits( 116 | dataset_dict=_spider_dataset_dict(), 117 | add_serialized_schema=_spider_add_serialized_schema, 118 | pre_process_function=_spider_pre_process_function, 119 | **_prepare_splits_kwargs, 120 | ) 121 | elif data_args.dataset == "geoquery": 122 | metric = _geoquery_metric() 123 | dataset_splits = prepare_splits( 124 | dataset_dict=_geoquery_dataset_dict(), 125 | add_serialized_schema=_geoquery_add_serialized_schema, 126 | pre_process_function=_geoquery_pre_process_function, 127 | **_prepare_splits_kwargs 128 | ) 129 | elif data_args.dataset == "cosql": 130 | metric = _cosql_metric() 131 | dataset_splits = prepare_splits( 132 | dataset_dict=_cosql_dataset_dict(), 133 | add_serialized_schema=_cosql_add_serialized_schema, 134 | 
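# NOTE: every dataset wires up the same four pieces (a DatasetDict loader, a
# metric, an add_serialized_schema hook, and a pre_process_function hook);
# the dispatch below just selects the matching quadruple for
# data_args.dataset. GeoQuery reuses the Spider metric via
# data_args.metric_paths["spider"] above.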
pre_process_function=_cosql_pre_process_function, 135 | **_prepare_splits_kwargs, 136 | ) 137 | else: 138 | raise NotImplementedError() 139 | if dataset_splits.train_split is not None: 140 | _log_duplicate_count(dataset=dataset_splits.train_split.dataset, dataset_name=data_args.dataset, split="train") 141 | if dataset_splits.eval_split is not None: 142 | _log_duplicate_count(dataset=dataset_splits.eval_split.dataset, dataset_name=data_args.dataset, split="eval") 143 | if dataset_splits.test_splits is not None: 144 | for section, split in dataset_splits.test_splits.items(): 145 | _log_duplicate_count(dataset=split.dataset, dataset_name=data_args.dataset, split=section) 146 | 147 | return metric, dataset_splits 148 | -------------------------------------------------------------------------------- /src/utils/get_tables.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import sqlite3 5 | from os import listdir, makedirs 6 | from os.path import isfile, isdir, join, split, exists, splitext 7 | from nltk import word_tokenize, tokenize 8 | import traceback 9 | 10 | EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"} 11 | 12 | 13 | def convert_fk_index(data): 14 | fk_holder = [] 15 | for fk in data["foreign_keys"]: 16 | tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1] 17 | ref_cid, cid = None, None 18 | try: 19 | tid = data["table_names_original"].index(tn) 20 | ref_tid = data["table_names_original"].index(ref_tn) 21 | 22 | for i, (tab_id, col_org) in enumerate(data["column_names_original"]): 23 | if tab_id == ref_tid and ref_col == col_org: 24 | ref_cid = i 25 | elif tid == tab_id and col == col_org: 26 | cid = i 27 | if ref_cid and cid: 28 | fk_holder.append([cid, ref_cid]) 29 | except: 30 | traceback.print_exc() 31 | print("table_names_original: ", data["table_names_original"]) 32 | print("finding tab name: ", tn, ref_tn) 33 | sys.exit() 34 | return fk_holder 35 | 36 | 37 | def dump_db_json_schema(db, f): 38 | """read table and column info""" 39 | 40 | conn = sqlite3.connect(db) 41 | conn.execute("pragma foreign_keys=ON") 42 | cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';") 43 | 44 | data = { 45 | "db_id": f, 46 | "table_names_original": [], 47 | "table_names": [], 48 | "column_names_original": [(-1, "*")], 49 | "column_names": [(-1, "*")], 50 | "column_types": ["text"], 51 | "primary_keys": [], 52 | "foreign_keys": [], 53 | } 54 | 55 | fk_holder = [] 56 | for i, item in enumerate(cursor.fetchall()): 57 | table_name = item[0] 58 | data["table_names_original"].append(table_name) 59 | data["table_names"].append(table_name.lower().replace("_", " ")) 60 | fks = conn.execute( 61 | "PRAGMA foreign_key_list('{}') ".format(table_name) 62 | ).fetchall() 63 | # print("db:{} table:{} fks:{}".format(f,table_name,fks)) 64 | fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks]) 65 | cur = conn.execute("PRAGMA table_info('{}') ".format(table_name)) 66 | for j, col in enumerate(cur.fetchall()): 67 | data["column_names_original"].append((i, col[1])) 68 | data["column_names"].append((i, col[1].lower().replace("_", " "))) 69 | # varchar, '' -> text, int, numeric -> integer, 70 | col_type = col[2].lower() 71 | if ( 72 | "char" in col_type 73 | or col_type == "" 74 | or "text" in col_type 75 | or "var" in col_type 76 | ): 77 | data["column_types"].append("text") 78 | elif ( 79 | "int" in col_type 80 | or "numeric" in col_type 81 | or "decimal" 
in col_type 82 | or "number" in col_type 83 | or "id" in col_type 84 | or "real" in col_type 85 | or "double" in col_type 86 | or "float" in col_type 87 | ): 88 | data["column_types"].append("number") 89 | elif "date" in col_type or "time" in col_type or "year" in col_type: 90 | data["column_types"].append("time") 91 | elif "boolean" in col_type: 92 | data["column_types"].append("boolean") 93 | else: 94 | data["column_types"].append("others") 95 | 96 | if col[5] == 1: 97 | data["primary_keys"].append(len(data["column_names"]) - 1) 98 | 99 | data["foreign_keys"] = fk_holder 100 | data["foreign_keys"] = convert_fk_index(data) 101 | 102 | return data 103 | 104 | -------------------------------------------------------------------------------- /src/utils/process_sql.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Assumptions: 3 | # 1. sql is correct 4 | # 2. only table name has alias 5 | # 3. only one intersect/union/except 6 | # 7 | # val: number(float)/string(str)/sql(dict) 8 | # col_unit: (agg_id, col_id, isDistinct(bool)) 9 | # val_unit: (unit_op, col_unit1, col_unit2) 10 | # table_unit: (table_type, col_unit/sql) 11 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 13 | # sql { 14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 16 | # 'where': condition 17 | # 'groupBy': [col_unit1, col_unit2, ...] 18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 19 | # 'having': condition 20 | # 'limit': None/limit value 21 | # 'intersect': None/sql 22 | # 'except': None/sql 23 | # 'union': None/sql 24 | # } 25 | ################################ 26 | 27 | import json 28 | import sqlite3 29 | import sys 30 | 31 | from nltk import word_tokenize 32 | 33 | 34 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 35 | JOIN_KEYWORDS = ('join', 'on', 'as') 36 | 37 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 38 | UNIT_OPS = ('none', '-', '+', "*", '/') 39 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 40 | TABLE_TYPE = { 41 | 'sql': "sql", 42 | 'table_unit': "table_unit", 43 | } 44 | 45 | COND_OPS = ('and', 'or') 46 | SQL_OPS = ('intersect', 'union', 'except') 47 | ORDER_OPS = ('desc', 'asc') 48 | 49 | 50 | 51 | class Schema: 52 | """ 53 | Simple schema which maps table&column to a unique identifier 54 | """ 55 | def __init__(self, schema): 56 | self._schema = schema 57 | self._idMap = self._map(self._schema) 58 | 59 | @property 60 | def schema(self): 61 | return self._schema 62 | 63 | @property 64 | def idMap(self): 65 | return self._idMap 66 | 67 | def _map(self, schema): 68 | idMap = {'*': "__all__"} 69 | id = 1 70 | for key, vals in schema.items(): 71 | for val in vals: 72 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." 
+ val.lower() + "__" 73 | id += 1 74 | 75 | for key in schema: 76 | idMap[key.lower()] = "__" + key.lower() + "__" 77 | id += 1 78 | 79 | return idMap 80 | 81 | 82 | def get_schema(db): 83 | """ 84 | Get database's schema, which is a dict with table name as key 85 | and list of column names as value 86 | :param db: database path 87 | :return: schema dict 88 | """ 89 | 90 | schema = {} 91 | conn = sqlite3.connect(db) 92 | cursor = conn.cursor() 93 | 94 | # fetch table names 95 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 96 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 97 | 98 | # fetch table info 99 | for table in tables: 100 | cursor.execute("PRAGMA table_info({})".format(table)) 101 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 102 | 103 | return schema 104 | 105 | 106 | def get_schema_from_json(fpath): 107 | with open(fpath) as f: 108 | data = json.load(f) 109 | 110 | schema = {} 111 | for entry in data: 112 | table = str(entry['table'].lower()) 113 | cols = [str(col['column_name'].lower()) for col in entry['col_data']] 114 | schema[table] = cols 115 | 116 | return schema 117 | 118 | 119 | def tokenize(string): 120 | string = str(string) 121 | string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 122 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 123 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 124 | 125 | # keep string value as token 126 | vals = {} 127 | for i in range(len(quote_idxs)-1, -1, -2): 128 | qidx1 = quote_idxs[i-1] 129 | qidx2 = quote_idxs[i] 130 | val = string[qidx1: qidx2+1] 131 | key = "__val_{}_{}__".format(qidx1, qidx2) 132 | string = string[:qidx1] + key + string[qidx2+1:] 133 | vals[key] = val 134 | 135 | toks = [word.lower() for word in word_tokenize(string)] 136 | # replace with string value token 137 | for i in range(len(toks)): 138 | if toks[i] in vals: 139 | toks[i] = vals[toks[i]] 140 | 141 | # find if there exists !=, >=, <= 142 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 143 | eq_idxs.reverse() 144 | prefix = ('!', '>', '<') 145 | for eq_idx in eq_idxs: 146 | pre_tok = toks[eq_idx-1] 147 | if pre_tok in prefix: 148 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 149 | 150 | return toks 151 | 152 | 153 | def scan_alias(toks): 154 | """Scan the index of 'as' and build the map for all alias""" 155 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 156 | alias = {} 157 | for idx in as_idxs: 158 | alias[toks[idx+1]] = toks[idx-1] 159 | return alias 160 | 161 | 162 | def get_tables_with_alias(schema, toks): 163 | tables = scan_alias(toks) 164 | for key in schema: 165 | assert key not in tables, "Alias {} has the same name in table".format(key) 166 | tables[key] = key 167 | return tables 168 | 169 | 170 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 171 | """ 172 | :returns next idx, column id 173 | """ 174 | tok = toks[start_idx] 175 | if tok == "*": 176 | return start_idx + 1, schema.idMap[tok] 177 | 178 | if '.' in tok: # if token is a composite 179 | alias, col = tok.split('.') 180 | key = tables_with_alias[alias] + "." + col 181 | return start_idx+1, schema.idMap[key] 182 | 183 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 184 | 185 | for alias in default_tables: 186 | table = tables_with_alias[alias] 187 | if tok in schema.schema[table]: 188 | key = table + "." 
+ tok 189 | return start_idx+1, schema.idMap[key] 190 | 191 | assert False, "Error col: {}".format(tok) 192 | 193 | 194 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 195 | """ 196 | :returns next idx, (agg_op id, col_id) 197 | """ 198 | idx = start_idx 199 | len_ = len(toks) 200 | isBlock = False 201 | isDistinct = False 202 | if toks[idx] == '(': 203 | isBlock = True 204 | idx += 1 205 | 206 | if toks[idx] in AGG_OPS: 207 | agg_id = AGG_OPS.index(toks[idx]) 208 | idx += 1 209 | assert idx < len_ and toks[idx] == '(' 210 | idx += 1 211 | if toks[idx] == "distinct": 212 | idx += 1 213 | isDistinct = True 214 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 215 | assert idx < len_ and toks[idx] == ')' 216 | idx += 1 217 | return idx, (agg_id, col_id, isDistinct) 218 | 219 | if toks[idx] == "distinct": 220 | idx += 1 221 | isDistinct = True 222 | agg_id = AGG_OPS.index("none") 223 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 224 | 225 | if isBlock: 226 | assert toks[idx] == ')' 227 | idx += 1 # skip ')' 228 | 229 | return idx, (agg_id, col_id, isDistinct) 230 | 231 | 232 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 233 | idx = start_idx 234 | len_ = len(toks) 235 | isBlock = False 236 | if toks[idx] == '(': 237 | isBlock = True 238 | idx += 1 239 | 240 | col_unit1 = None 241 | col_unit2 = None 242 | unit_op = UNIT_OPS.index('none') 243 | 244 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 245 | if idx < len_ and toks[idx] in UNIT_OPS: 246 | unit_op = UNIT_OPS.index(toks[idx]) 247 | idx += 1 248 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 249 | 250 | if isBlock: 251 | assert toks[idx] == ')' 252 | idx += 1 # skip ')' 253 | 254 | return idx, (unit_op, col_unit1, col_unit2) 255 | 256 | 257 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 258 | """ 259 | :returns next idx, table id, table name 260 | """ 261 | idx = start_idx 262 | len_ = len(toks) 263 | key = tables_with_alias[toks[idx]] 264 | 265 | if idx + 1 < len_ and toks[idx+1] == "as": 266 | idx += 3 267 | else: 268 | idx += 1 269 | 270 | return idx, schema.idMap[key], key 271 | 272 | 273 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 274 | idx = start_idx 275 | len_ = len(toks) 276 | 277 | isBlock = False 278 | if toks[idx] == '(': 279 | isBlock = True 280 | idx += 1 281 | 282 | if toks[idx] == 'select': 283 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 284 | elif "\"" in toks[idx]: # token is a string value 285 | val = toks[idx] 286 | idx += 1 287 | else: 288 | try: 289 | val = float(toks[idx]) 290 | idx += 1 291 | except: 292 | end_idx = idx 293 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 294 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: 295 | end_idx += 1 296 | 297 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 298 | idx = end_idx 299 | 300 | if isBlock: 301 | assert toks[idx] == ')' 302 | idx += 1 303 | 304 | return idx, val 305 | 306 | 307 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 308 | idx = start_idx 309 | len_ = len(toks) 310 | conds = [] 311 | 312 | while idx < len_: 313 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, 
default_tables) 314 | not_op = False 315 | if toks[idx] == 'not': 316 | not_op = True 317 | idx += 1 318 | 319 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 320 | op_id = WHERE_OPS.index(toks[idx]) 321 | idx += 1 322 | val1 = val2 = None 323 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 324 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 325 | assert toks[idx] == 'and' 326 | idx += 1 327 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 328 | else: # normal case: single value 329 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 330 | val2 = None 331 | 332 | conds.append((not_op, op_id, val_unit, val1, val2)) 333 | 334 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): 335 | break 336 | 337 | if idx < len_ and toks[idx] in COND_OPS: 338 | conds.append(toks[idx]) 339 | idx += 1 # skip and/or 340 | 341 | return idx, conds 342 | 343 | 344 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 345 | idx = start_idx 346 | len_ = len(toks) 347 | 348 | assert toks[idx] == 'select', "'select' not found" 349 | idx += 1 350 | isDistinct = False 351 | if idx < len_ and toks[idx] == 'distinct': 352 | idx += 1 353 | isDistinct = True 354 | val_units = [] 355 | 356 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 357 | agg_id = AGG_OPS.index("none") 358 | if toks[idx] in AGG_OPS: 359 | agg_id = AGG_OPS.index(toks[idx]) 360 | idx += 1 361 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 362 | val_units.append((agg_id, val_unit)) 363 | if idx < len_ and toks[idx] == ',': 364 | idx += 1 # skip ',' 365 | 366 | return idx, (isDistinct, val_units) 367 | 368 | 369 | def parse_from(toks, start_idx, tables_with_alias, schema): 370 | """ 371 | Assume in the from clause, all table units are combined with join 372 | """ 373 | assert 'from' in toks[start_idx:], "'from' not found" 374 | 375 | len_ = len(toks) 376 | idx = toks.index('from', start_idx) + 1 377 | default_tables = [] 378 | table_units = [] 379 | conds = [] 380 | 381 | while idx < len_: 382 | isBlock = False 383 | if toks[idx] == '(': 384 | isBlock = True 385 | idx += 1 386 | 387 | if toks[idx] == 'select': 388 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 389 | table_units.append((TABLE_TYPE['sql'], sql)) 390 | else: 391 | if idx < len_ and toks[idx] == 'join': 392 | idx += 1 # skip join 393 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 394 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 395 | default_tables.append(table_name) 396 | if idx < len_ and toks[idx] == "on": 397 | idx += 1 # skip on 398 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 399 | if len(conds) > 0: 400 | conds.append('and') 401 | conds.extend(this_conds) 402 | 403 | if isBlock: 404 | assert toks[idx] == ')' 405 | idx += 1 406 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 407 | break 408 | 409 | return idx, table_units, conds, default_tables 410 | 411 | 412 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 413 | idx = start_idx 414 | len_ = len(toks) 415 | 416 | if idx >= len_ or toks[idx] != 'where': 417 | return idx, [] 418 | 419 | idx += 1 420 | idx, conds = parse_condition(toks, idx, 
tables_with_alias, schema, default_tables) 421 | return idx, conds 422 | 423 | 424 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 425 | idx = start_idx 426 | len_ = len(toks) 427 | col_units = [] 428 | 429 | if idx >= len_ or toks[idx] != 'group': 430 | return idx, col_units 431 | 432 | idx += 1 433 | assert toks[idx] == 'by' 434 | idx += 1 435 | 436 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 437 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 438 | col_units.append(col_unit) 439 | if idx < len_ and toks[idx] == ',': 440 | idx += 1 # skip ',' 441 | else: 442 | break 443 | 444 | return idx, col_units 445 | 446 | 447 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 448 | idx = start_idx 449 | len_ = len(toks) 450 | val_units = [] 451 | order_type = 'asc' # default type is 'asc' 452 | 453 | if idx >= len_ or toks[idx] != 'order': 454 | return idx, val_units 455 | 456 | idx += 1 457 | assert toks[idx] == 'by' 458 | idx += 1 459 | 460 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 461 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 462 | val_units.append(val_unit) 463 | if idx < len_ and toks[idx] in ORDER_OPS: 464 | order_type = toks[idx] 465 | idx += 1 466 | if idx < len_ and toks[idx] == ',': 467 | idx += 1 # skip ',' 468 | else: 469 | break 470 | 471 | return idx, (order_type, val_units) 472 | 473 | 474 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 475 | idx = start_idx 476 | len_ = len(toks) 477 | 478 | if idx >= len_ or toks[idx] != 'having': 479 | return idx, [] 480 | 481 | idx += 1 482 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 483 | return idx, conds 484 | 485 | 486 | def parse_limit(toks, start_idx): 487 | idx = start_idx 488 | len_ = len(toks) 489 | 490 | if idx < len_ and toks[idx] == 'limit': 491 | idx += 2 492 | # make limit value can work, cannot assume put 1 as a fake limit number 493 | if type(toks[idx-1]) != int: 494 | return idx, 1 495 | 496 | return idx, int(toks[idx-1]) 497 | 498 | return idx, None 499 | 500 | 501 | def parse_sql(toks, start_idx, tables_with_alias, schema): 502 | isBlock = False # indicate whether this is a block of sql/sub-sql 503 | len_ = len(toks) 504 | idx = start_idx 505 | 506 | sql = {} 507 | if toks[idx] == '(': 508 | isBlock = True 509 | idx += 1 510 | 511 | # parse from clause in order to get default tables 512 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) 513 | sql['from'] = {'table_units': table_units, 'conds': conds} 514 | # select clause 515 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) 516 | idx = from_end_idx 517 | sql['select'] = select_col_units 518 | # where clause 519 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 520 | sql['where'] = where_conds 521 | # group by clause 522 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) 523 | sql['groupBy'] = group_col_units 524 | # having clause 525 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) 526 | sql['having'] = having_conds 527 | # order by clause 528 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) 529 | sql['orderBy'] = 
order_col_units 530 | # limit clause 531 | idx, limit_val = parse_limit(toks, idx) 532 | sql['limit'] = limit_val 533 | 534 | idx = skip_semicolon(toks, idx) 535 | if isBlock: 536 | assert toks[idx] == ')' 537 | idx += 1 # skip ')' 538 | idx = skip_semicolon(toks, idx) 539 | 540 | # intersect/union/except clause 541 | for op in SQL_OPS: # initialize IUE 542 | sql[op] = None 543 | if idx < len_ and toks[idx] in SQL_OPS: 544 | sql_op = toks[idx] 545 | idx += 1 546 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 547 | sql[sql_op] = IUE_sql 548 | return idx, sql 549 | 550 | 551 | def load_data(fpath): 552 | with open(fpath) as f: 553 | data = json.load(f) 554 | return data 555 | 556 | 557 | def get_sql(schema, query): 558 | toks = tokenize(query) 559 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 560 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 561 | 562 | return sql 563 | 564 | 565 | def skip_semicolon(toks, start_idx): 566 | idx = start_idx 567 | while idx < len(toks) and toks[idx] == ";": 568 | idx += 1 569 | return idx 570 | -------------------------------------------------------------------------------- /src/utils/spider.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from typing import Optional 4 | from datasets.arrow_dataset import Dataset 5 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 6 | from .dataset import normalize, serialize_schema, combine_SC 7 | from .args import * 8 | from .trainer import Seq2SeqTrainer, EvalPrediction 9 | from .process_sql import get_schema, Schema, get_sql 10 | import os 11 | import random 12 | import re 13 | import shlex 14 | 15 | pattern_1 = re.compile('(\(|\))') 16 | pattern_4 = re.compile('(\()') 17 | pattern_5 = re.compile('(\))') 18 | pattern_2 = re.compile('(,)') 19 | pattern_3 = re.compile('(>=|<=|>|<|=)') 20 | 21 | sql_clauses = ['select', 'from', 'where', 'group', 'having', 'order', 'limit', 'intersect', 'union', 'except'] 22 | sql_ops_space = ['>', '<', '=', 'like', '!=', '-', '+', 'between', 'and', 'or', 'not', 'in', ')', 'by', 'distinct', '>=', '<=', '<>'] 23 | sql_ops_no_space = ['count', 'avg', 'sum', 'max', 'min', '(', '!', 'desc', 'asc'] 24 | sql_marks = [','] 25 | 26 | def lower_( 27 | word: str, 28 | ) -> str: 29 | if '"' in word or "'" in word: 30 | return word 31 | else: 32 | return word.lower() 33 | 34 | def tok_process( 35 | toks: list, 36 | ) -> list: 37 | processed_tok_list = [] 38 | i = 0 39 | while i < len(toks): 40 | if toks[i] == "``" and toks[i+2] == "''": 41 | temp = f'"{toks[i+1]}"' 42 | processed_tok_list.append(temp) 43 | i += 3 44 | continue 45 | else: 46 | processed_tok_list.append(toks[i]) 47 | i += 1 48 | return [lower_(x) for x in processed_tok_list] 49 | 50 | 51 | 52 | def spider_get_input( 53 | question: str, 54 | serialized_schema: str, 55 | prefix: str, 56 | ) -> str: 57 | return prefix + question.strip() + " | " + serialized_schema.strip() 58 | 59 | 60 | def spider_get_target( 61 | query: str, 62 | db_id: str, 63 | normalize_query: bool, 64 | target_with_db_id: bool, 65 | ) -> str: 66 | _normalize = normalize if normalize_query else (lambda x: x) 67 | return f"{db_id} | {_normalize(query)}" if target_with_db_id else _normalize(query) 68 | 69 | 70 | def spider_add_serialized_schema(ex: dict, mode: str, data_training_args: DataTrainingArguments) -> dict: 71 | serialized_schema = serialize_schema( 72 | question=ex["question"], 73 | db_path=ex["db_path"], 74 | 
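# NOTE: with the default "peteshaw" serialization type, the value produced by
# this call looks like the following (illustrative only, assuming db_id
# "concert_singer" and one bridged cell value):
#   | concert_singer | singer : singer_id , name ( Joe ) , age | concert : concert_id , theme
# serialize_schema moves tables containing a value match, i.e. a "( ... )"
# span, to the front of the listing.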
db_id=ex["db_id"], 75 | db_column_names=ex["db_column_names"], 76 | db_table_names=ex["db_table_names"], 77 | schema_serialization_type=data_training_args.schema_serialization_type, 78 | schema_serialization_randomized=data_training_args.schema_serialization_randomized, 79 | schema_serialization_with_db_id=data_training_args.schema_serialization_with_db_id, 80 | schema_serialization_with_db_content=data_training_args.schema_serialization_with_db_content, 81 | normalize_query=data_training_args.normalize_query, 82 | ) 83 | return {"serialized_schema": serialized_schema} 84 | 85 | 86 | def spider_pre_process_function( 87 | batch: dict, 88 | max_source_length: Optional[int], 89 | max_target_length: Optional[int], 90 | mode: Optional[str], 91 | data_training_args: DataTrainingArguments, 92 | tokenizer: PreTrainedTokenizerBase, 93 | ) -> dict: 94 | prefix = data_training_args.source_prefix if data_training_args.source_prefix is not None else "question: " 95 | if data_training_args.use_decomposition: 96 | inputs = [] 97 | targets = [] 98 | if data_training_args.stage == 'content': 99 | eval_format_list = [] 100 | with open(data_training_args.structure_path) as f: 101 | info = json.load(f) 102 | for item in info: 103 | eval_format_list.append(item['prediction']) 104 | print(f"load {len(eval_format_list)} eval_formats from {data_training_args.structure_path}") 105 | if len(batch['question']) == 1000: 106 | count = 0 107 | else: 108 | count = 1000 109 | 110 | for question, serialized_schema, db_id, query, query_toks, db_column_names in zip(batch["question"], batch["serialized_schema"], batch["db_id"], batch["query"], batch["query_toks"], batch["db_column_names"]): 111 | input_str = spider_get_input(question=question, serialized_schema=serialized_schema, prefix=prefix) 112 | #input_str = input_str + ' | Translate ' 113 | column_names = [x.lower() for x in db_column_names['column_name']] 114 | 115 | lex = shlex.shlex(query) 116 | lex.whitespace = ' ' 117 | lex.quotes=['"', "'"] 118 | lex.whitespace_split = True 119 | query_toks = list(lex) 120 | query_tok_list = tok_process(query_toks) 121 | for idx, tok in enumerate(query_tok_list): 122 | if '"' in tok or "'" in tok: 123 | continue 124 | if len(tok) > 1 and ',' in tok and tok not in column_names and query_tok_list[idx-1] not in sql_ops_space: 125 | res = pattern_2.split(tok) 126 | query_tok_list[idx:idx+1] = res 127 | if '(' in query_tok_list[idx] and ')' in query_tok_list[idx]: 128 | res = pattern_1.split(query_tok_list[idx]) 129 | query_tok_list[idx:idx+1] = res 130 | elif '(' in query_tok_list[idx] and ('select' in query_tok_list[idx] or 'distinct' in query_tok_list[idx] or 'in' in query_tok_list[idx] or 'count' in query_tok_list[idx]): 131 | res = pattern_4.split(query_tok_list[idx]) 132 | query_tok_list[idx:idx+1] = res 133 | elif len(query_tok_list[idx]) > 1 and ')' == query_tok_list[idx][-1]: 134 | res = pattern_5.split(query_tok_list[idx]) 135 | query_tok_list[idx:idx+1] = res 136 | if ('>' in query_tok_list[idx] or '<' in query_tok_list[idx] or '>=' in query_tok_list[idx] or '<=' in query_tok_list[idx] or '=' in query_tok_list[idx]) and query_tok_list[idx][0] not in ['>', '<', '=']: 137 | res = pattern_3.split(query_tok_list[idx]) 138 | query_tok_list[idx:idx+1] = res 139 | for idx, tok in enumerate(query_tok_list): 140 | if tok == '': 141 | del query_tok_list[idx] 142 | 143 | # query_tok_list = tok_process(query_toks) 144 | sub_query_format_list = [] 145 | content_label = '' 146 | sub_query_list = [] 147 | select_from_record = [] 148 | 
sub_query_format = '' 149 | sub_query = '' 150 | last_tok = None 151 | select_num = 0 152 | left_bracket = 0 153 | right_bracket = 0 154 | if query_tok_list[-1] == ';': 155 | query_tok_list = query_tok_list[:-1] 156 | for idx, query_tok in enumerate(query_tok_list): 157 | if query_tok in sql_clauses: 158 | if query_tok == 'select': 159 | select_num += 1 160 | elif (query_tok == 'group' or query_tok == 'order') and query_tok_list[idx+1] != 'by': 161 | if query_tok in column_names: 162 | if idx + 1 == len(query_tok_list) or query_tok_list[idx+1] in [',', ')']: 163 | sub_query_format += '[col]' 164 | else: 165 | sub_query_format += '[col] ' 166 | content_label += '[col] ' + query_tok + ' ' 167 | sub_query += query_tok + ' ' 168 | else: 169 | print("error:", query_tok) 170 | continue 171 | 172 | if sub_query_format != '' and sub_query != '': 173 | if 'select' in sub_query_format: 174 | select_from_record.append(1) 175 | elif 'from' in sub_query_format: 176 | select_from_record.append(2) 177 | else: 178 | select_from_record.append(0) 179 | sub_query_format_list.append(sub_query_format) 180 | sub_query_list.append(sub_query) 181 | sub_query_format = '' 182 | sub_query = '' 183 | if query_tok == 'from': 184 | sub_query_format += 'from [tab]' 185 | content_label += '[tab] ' 186 | else: 187 | sub_query_format += query_tok + ' ' 188 | last_tok = 'sql_clauses' 189 | sub_query += query_tok + ' ' 190 | elif sub_query_format == 'from [tab]': 191 | last_tok = 'from [tab]' 192 | sub_query += query_tok + ' ' 193 | if query_tok not in [')', '(']: 194 | content_label += query_tok + ' ' 195 | continue 196 | elif query_tok in sql_ops_space: 197 | if query_tok == ')': 198 | right_bracket += 1 199 | if ((query_tok == '>' or query_tok == '<') and query_tok_list[idx+1] == '=') or (query_tok == ')' and (idx + 1 == len(query_tok_list) or query_tok_list[idx+1] == ',')): 200 | # >= or <= ), 201 | sub_query_format += query_tok 202 | sub_query += query_tok 203 | else: 204 | sub_query_format += query_tok + ' ' 205 | sub_query += query_tok + ' ' 206 | last_tok = 'op' 207 | elif query_tok in sql_ops_no_space: 208 | if query_tok == '(': 209 | left_bracket += 1 210 | sub_query_format += query_tok 211 | sub_query += query_tok 212 | elif query_tok in column_names or '.' 
in query_tok: 213 | if last_tok == 'val': 214 | content_label += query_tok + ' ' 215 | continue 216 | if idx + 1 == len(query_tok_list) or query_tok_list[idx+1] in [',', ')']: 217 | sub_query_format += '[col]' 218 | sub_query += query_tok 219 | else: 220 | sub_query_format += '[col] ' 221 | sub_query += query_tok + ' ' 222 | content_label += '[col] ' + query_tok + ' ' 223 | last_tok = 'col' 224 | elif query_tok in sql_marks: 225 | sub_query_format += query_tok + ' ' 226 | sub_query += query_tok + ' ' 227 | last_tok = 'mark' 228 | else: 229 | if last_tok != 'val': 230 | sub_query_format += '[val] ' 231 | content_label += '[val] ' 232 | if query_tok == '``': 233 | sub_query += '"' 234 | content_label += '"' 235 | elif query_tok == "''": 236 | sub_query += '" ' 237 | content_label += '" ' 238 | elif query_tok == "'": 239 | sub_query += "' " 240 | content_label += "' " 241 | elif last_tok == 'val' and (idx + 1 == len(query_tok_list) or query_tok_list[idx+1] in ["'", '"', '``', "''"]): 242 | sub_query += query_tok 243 | content_label += query_tok 244 | else: 245 | sub_query += query_tok + ' ' 246 | content_label += query_tok + ' ' 247 | last_tok = 'val' 248 | 249 | if select_num > 1 and left_bracket > right_bracket: 250 | sub_query_format += ')' 251 | sub_query_format_list.append(sub_query_format) 252 | sub_query_list.append(sub_query) 253 | if data_training_args.stage == 'structure': 254 | structure = normalize(' '.join(sub_query_format_list)) 255 | inputs.append(data_training_args.schema_serialization_with_prompt + ' | ' + input_str) 256 | target = spider_get_target( 257 | query=structure, 258 | db_id=db_id, 259 | normalize_query=True, 260 | target_with_db_id=False, 261 | ) 262 | targets.append(target) 263 | 264 | elif data_training_args.stage == 'content': 265 | if mode == 'eval': 266 | input_str = data_training_args.schema_serialization_with_prompt + eval_format_list[count] + ' | ' + input_str 267 | else: 268 | input_str = data_training_args.schema_serialization_with_prompt + ' '.join(sub_query_format_list) + ' | ' + input_str 269 | inputs.append(input_str) 270 | target = content_label 271 | targets.append(target) 272 | count += 1 273 | 274 | else: 275 | inputs = [ 276 | spider_get_input(question=question, serialized_schema=serialized_schema, prefix=prefix) 277 | for question, serialized_schema in zip(batch["question"], batch["serialized_schema"]) 278 | ] 279 | targets = [ 280 | spider_get_target( 281 | query=query, 282 | db_id=db_id, 283 | normalize_query=data_training_args.normalize_query, 284 | target_with_db_id=data_training_args.target_with_db_id, 285 | ) 286 | for db_id, query in zip(batch["db_id"], batch["query"]) 287 | ] 288 | print(f"{mode}: {len(inputs)}") 289 | 290 | model_inputs: dict = tokenizer( 291 | inputs, 292 | max_length=max_source_length, 293 | padding=False, 294 | truncation=True, 295 | return_overflowing_tokens=False, 296 | ) 297 | 298 | # Setup the tokenizer for targets 299 | with tokenizer.as_target_tokenizer(): 300 | labels = tokenizer( 301 | targets, 302 | max_length=max_target_length, 303 | padding=False, 304 | truncation=True, 305 | return_overflowing_tokens=False, 306 | ) 307 | 308 | model_inputs["labels"] = labels["input_ids"] 309 | return model_inputs 310 | 311 | 312 | class SpiderTrainer(Seq2SeqTrainer): 313 | def _post_process_function( 314 | self, examples: Dataset, features: Dataset, predictions: np.ndarray, stage: str 315 | ) -> EvalPrediction: 316 | inputs = self.tokenizer.batch_decode([f["input_ids"] for f in features], skip_special_tokens=True) 317 | 
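# NOTE: `features` holds the tokenized eval inputs; decoding them back to text
# recovers the serialized prompt, which is stored as "context" in `metas`
# below so that combine_SC() can re-extract the predicted structure skeleton
# when no hypotheses.json file is present.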
label_ids = [f["labels"] for f in features] 318 | if self.ignore_pad_token_for_loss: 319 | # Replace -100 in the labels as we can't decode them. 320 | _label_ids = np.where(label_ids != -100, label_ids, self.tokenizer.pad_token_id) 321 | decoded_label_ids = self.tokenizer.batch_decode(_label_ids, skip_special_tokens=True) 322 | metas = [ 323 | { 324 | "query": x["query"], 325 | "question": x["question"], 326 | "context": context, 327 | "label": label, 328 | "db_id": x["db_id"], 329 | "db_path": x["db_path"], 330 | "db_table_names": x["db_table_names"], 331 | "db_column_names": x["db_column_names"], 332 | "db_foreign_keys": x["db_foreign_keys"], 333 | } 334 | for x, context, label in zip(examples, inputs, decoded_label_ids) 335 | ] 336 | predictions = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) 337 | 338 | assert len(metas) == len(predictions) 339 | if self.stage == 'content': 340 | final_pred_sqls = [] 341 | hypotheses_path = os.path.join(self.args.output_dir, "hypotheses.json") 342 | if os.path.exists(hypotheses_path): 343 | # sentence-level check 344 | with open(hypotheses_path) as f: 345 | hypotheses = json.load(f) 346 | for idx, item in enumerate(hypotheses): 347 | db_id, structure = item["structure"].split(" | ") 348 | db = os.path.join(metas[idx]["db_path"], db_id, f'{db_id}.sqlite') 349 | schema = Schema(get_schema(db)) 350 | final_pred_sql = None 351 | for hypothesis in item["topk_preds"]: 352 | try: 353 | pred_sql = combine_SC(content=hypothesis, input="", structure=structure) 354 | parse_sql = get_sql(schema, pred_sql) 355 | final_pred_sql = pred_sql 356 | break 357 | except: 358 | continue 359 | if final_pred_sql == None: 360 | # default to the first one 361 | final_pred_sql = combine_SC(content=item["topk_preds"][0], input="", structure=structure) 362 | final_pred_sqls.append(final_pred_sql) 363 | 364 | os.remove(hypotheses_path) 365 | else: 366 | for pred_content, meta in zip(predictions, metas): 367 | final_pred_sqls.append(combine_SC(pred_content, meta['context'])) 368 | # write predict sql 369 | with open(f"{self.args.output_dir}/predict_sql.txt", "w") as f: 370 | for final_pred_sql in final_pred_sqls: 371 | f.write(final_pred_sql+"\n") 372 | 373 | with open(f"{self.args.output_dir}/content_{stage}.json", "w") as f: 374 | json.dump( 375 | [dict(**{"input": meta['context']}, **{"prediction": prediction}, **{"label": label}, **{"score": prediction==label}, **{"pred_sql": final_pred_sql}, **{"gold_sql": meta['query']}) for meta, prediction, final_pred_sql, label in zip(metas, predictions, final_pred_sqls, decoded_label_ids)], 376 | f, 377 | indent=4, 378 | ) 379 | return EvalPrediction(predictions=final_pred_sqls, label_ids=decoded_label_ids, metas=metas) 380 | elif self.stage == 'structure': 381 | for idx in range(len(predictions)): 382 | if 'before' in predictions[idx]: 383 | predictions[idx] = predictions[idx].replace('before', '<') 384 | if 'after' in predictions[idx]: 385 | predictions[idx] = predictions[idx].replace('after', '>') 386 | return EvalPrediction(predictions=predictions, label_ids=decoded_label_ids, metas=metas) 387 | 388 | 389 | def _compute_metrics(self, eval_prediction: EvalPrediction) -> dict: 390 | #predictions, label_ids, metas = eval_prediction 391 | predictions, label_ids, metas = eval_prediction 392 | if self.target_with_db_id: 393 | # Remove database id from all predictions 394 | predictions = [pred.split("|", 1)[-1].strip() for pred in predictions] 395 | references = metas 396 | if self.stage == 'structure': 397 | accuracy = 
389 | def _compute_metrics(self, eval_prediction: EvalPrediction) -> dict: 390 | 391 | predictions, label_ids, metas = eval_prediction 392 | if self.target_with_db_id: 393 | # Remove database id from all predictions 394 | predictions = [pred.split("|", 1)[-1].strip() for pred in predictions] 395 | references = metas 396 | if self.stage == 'structure': 397 | accuracy = [] 398 | accuracy.extend( 399 | ( 400 | pred.lower() == actual.lower() 401 | for pred, actual in zip(predictions, label_ids) 402 | ) 403 | ) 404 | eval_metric = np.mean(accuracy) 405 | test_suite = dict() 406 | if eval_metric >= self.best_acc: 407 | with open(f"{self.args.output_dir}/structure.json", "w") as f: 408 | json.dump( 409 | [dict(**{"input": meta['context']}, **{"prediction": prediction}, **{"label": label}, **{"score": prediction==label}) for meta, prediction, label in zip(metas, predictions, label_ids)], 410 | f, 411 | indent=4, 412 | ) 413 | return {**{"exact_match": eval_metric}, **test_suite} 414 | elif self.stage == 'content': 415 | return self.metric.compute(predictions=predictions, references=references) 416 | else: 417 | raise NotImplementedError() 418 | -------------------------------------------------------------------------------- /src/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Dict, List, Optional, NamedTuple, Union, Any, Tuple 3 | import transformers.trainer_seq2seq 4 | from transformers.deepspeed import is_deepspeed_zero3_enabled 5 | from transformers.trainer_utils import PredictionOutput, speed_metrics 6 | from torch.utils.data import DataLoader 7 | from datasets.arrow_dataset import Dataset 8 | from datasets.metric import Metric 9 | import numpy as np 10 | import time 11 | import torch 12 | import torch.nn as nn 13 | from tqdm import tqdm 14 | import json 15 | import os 16 | 17 | def del_file(path): # delete every file directly under path 18 | ls = os.listdir(path) 19 | for i in ls: 20 | c_path = os.path.join(path, i) 21 | os.remove(c_path) 22 | 23 | class EvalPrediction(NamedTuple): 24 | predictions: List[str] 25 | label_ids: List[str] 26 | metas: List[dict] 27 | 28 | 29 | class Seq2SeqTrainer(transformers.trainer_seq2seq.Seq2SeqTrainer): 30 | def __init__( 31 | self, 32 | metric: Metric, 33 | *args, 34 | eval_examples: Optional[Dataset] = None, 35 | ignore_pad_token_for_loss: bool = True, 36 | target_with_db_id: bool = False, 37 | use_decomposition: bool = False, 38 | stage: str = 'structure', 39 | training_method: str = 'FT', 40 | **kwargs, 41 | ) -> None: 42 | super().__init__(*args, **kwargs) 43 | self.metric = metric 44 | self.eval_examples = eval_examples 45 | self.compute_metrics = self._compute_metrics 46 | self.ignore_pad_token_for_loss = ignore_pad_token_for_loss 47 | self.target_with_db_id = target_with_db_id 48 | self.use_decomposition = use_decomposition 49 | self.stage = stage 50 | self.training_method = training_method 51 | self.best_acc = 0 52 | 53 | def _compute_metrics(self, eval_prediction: EvalPrediction) -> dict: 54 | # implemented per dataset, e.g. in src/utils/spider.py 55 | raise NotImplementedError() 56 | 57 | def _post_process_function( 58 | self, examples: Dataset, features: Dataset, predictions: np.ndarray, stage: str 59 | ) -> EvalPrediction: 60 | # implemented per dataset, e.g. in src/utils/spider.py 61 | raise NotImplementedError() 62 | 63 | 64 | def evaluate( 65 | self, 66 | eval_dataset: Optional[Dataset] = None, 67 | eval_examples: Optional[Dataset] = None, 68 | ignore_keys: Optional[List[str]] = None, 69 | metric_key_prefix: str = "eval", 70 | max_length: Optional[int] = None, 71 | max_time: Optional[int] = None, 72 | num_beams: Optional[int] = None, 73 | ) -> Dict[str, float]: 74 | self._max_length = max_length 75 | self._max_time = max_time 76 | self._num_beams = num_beams 77 | 78 | # memory metrics - must set up as early as possible 79 | self._memory_tracker.start() 80 | 81 | eval_dataset = self.eval_dataset if eval_dataset is None else
eval_dataset 82 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 83 | raise ValueError("eval_dataset must implement __len__") 84 | 85 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 86 | eval_examples = self.eval_examples if eval_examples is None else eval_examples 87 | start_time = time.time() 88 | 89 | print(f"epoch:{self.state.epoch}") 90 | if self.training_method == 'PT' and self.args.do_train: 91 | save_dir = self.args.output_dir 92 | os.makedirs(save_dir, exist_ok=True) 93 | 94 | del_file(save_dir) 95 | np.save(os.path.join(save_dir, f'head.npy'), self.model.prompt_head.cpu().detach().numpy()) 96 | np.save(os.path.join(save_dir, f'tail.npy'), self.model.prompt_tail.cpu().detach().numpy()) 97 | np.save(os.path.join(save_dir, f'mid1.npy'), self.model.prompt_mid1.cpu().detach().numpy()) 98 | np.save(os.path.join(save_dir, f'mid2.npy'), self.model.prompt_mid2.cpu().detach().numpy()) 99 | print(f"save embs to {save_dir}") 100 | 101 | # Temporarily disable metric computation, we will do it in the loop here. 102 | compute_metrics = self.compute_metrics 103 | self.compute_metrics = None 104 | try: 105 | output: PredictionOutput = self.evaluation_loop( 106 | eval_dataloader, 107 | description="Evaluation", 108 | # No point gathering the predictions if there are no metrics, otherwise we defer to 109 | # self.args.prediction_loss_only 110 | prediction_loss_only=True if compute_metrics is None else None, 111 | ignore_keys=ignore_keys, 112 | metric_key_prefix=metric_key_prefix, 113 | ) 114 | finally: 115 | self.compute_metrics = compute_metrics 116 | 117 | # We might have removed columns from the dataset so we put them back. 118 | if isinstance(eval_dataset, Dataset): 119 | eval_dataset.set_format( 120 | type=eval_dataset.format["type"], 121 | columns=list(eval_dataset.features.keys()), 122 | ) 123 | if eval_examples is not None and eval_dataset is not None and self.compute_metrics is not None: 124 | eval_preds = self._post_process_function( 125 | eval_examples, 126 | eval_dataset, 127 | output.predictions, 128 | "eval_{}".format(self.state.epoch), 129 | ) 130 | output.metrics.update(self.compute_metrics(eval_preds)) 131 | 132 | n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset) 133 | output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples)) 134 | 135 | # Prefix all keys with metric_key_prefix + '_' 136 | for key in list(output.metrics.keys()): 137 | if not key.startswith(f"{metric_key_prefix}_"): 138 | output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key) 139 | 140 | self.log(output.metrics) 141 | 142 | if output.metrics['eval_exact_match'] > self.best_acc and self.training_method != 'PT' and self.args.do_train: 143 | self.best_acc = output.metrics['eval_exact_match'] 144 | save_dir = os.path.join(self.args.output_dir, "BEST_MODEL") 145 | os.makedirs(save_dir, exist_ok=True) 146 | del_file(save_dir) 147 | print(f"save model to {save_dir} acc={output.metrics['eval_exact_match']}") 148 | if self.training_method == 'FT': 149 | state_dict = self.model.state_dict() 150 | self.model.save_pretrained(save_dir, state_dict=state_dict) 151 | self.tokenizer.save_pretrained(save_dir) 152 | torch.save(self.args, os.path.join(save_dir, "training_args.bin")) 153 | else: 154 | state_dict = self.model.model.state_dict() 155 | self.model.model.save_pretrained(save_dir, state_dict=state_dict) 156 | self.tokenizer.save_pretrained(save_dir) 157 | torch.save(self.args, os.path.join(save_dir, 
"training_args.bin")) 158 | np.save(f'{save_dir}/head.npy', self.model.prompt_head.cpu().detach().numpy()) 159 | np.save(f'{save_dir}/tail.npy', self.model.prompt_tail.cpu().detach().numpy()) 160 | np.save(f'{save_dir}/mid1.npy', self.model.prompt_mid1.cpu().detach().numpy()) 161 | np.save(f'{save_dir}/mid2.npy', self.model.prompt_mid2.cpu().detach().numpy()) 162 | 163 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) 164 | 165 | self._memory_tracker.stop_and_update_metrics(output.metrics) 166 | if self.args.do_train and self.training_method != 'PT': 167 | print("BEST EM result:", self.best_acc) 168 | print("BEST model save at:", os.path.join(self.args.output_dir, "BEST_MODEL")) 169 | 170 | return output.metrics 171 | 172 | def predict( 173 | self, 174 | test_dataset: Dataset, 175 | test_examples: Dataset, 176 | ignore_keys: Optional[List[str]] = None, 177 | metric_key_prefix: str = "eval", 178 | max_length: Optional[int] = None, 179 | max_time: Optional[int] = None, 180 | num_beams: Optional[int] = None, 181 | ) -> PredictionOutput: 182 | self._max_length = max_length 183 | self._max_time = max_time 184 | self._num_beams = num_beams 185 | 186 | # memory metrics - must set up as early as possible 187 | self._memory_tracker.start() 188 | 189 | if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized): 190 | raise ValueError("test_dataset must implement __len__") 191 | 192 | test_dataloader = self.get_test_dataloader(test_dataset) 193 | start_time = time.time() 194 | 195 | # Temporarily disable metric computation, we will do it in the loop here. 196 | compute_metrics = self.compute_metrics 197 | self.compute_metrics = None 198 | try: 199 | output: PredictionOutput = self.evaluation_loop( 200 | test_dataloader, 201 | description="Prediction", 202 | ignore_keys=ignore_keys, 203 | metric_key_prefix=metric_key_prefix, 204 | ) 205 | finally: 206 | self.compute_metrics = compute_metrics 207 | 208 | if self.compute_metrics is not None: 209 | # We might have removed columns from the dataset so we put them back. 210 | if isinstance(test_dataset, Dataset): 211 | test_dataset.set_format( 212 | type=test_dataset.format["type"], 213 | columns=list(test_dataset.features.keys()), 214 | ) 215 | 216 | eval_preds = self._post_process_function( 217 | test_examples, test_dataset, output.predictions, metric_key_prefix) 218 | output.metrics.update(self.compute_metrics(eval_preds)) 219 | 220 | output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset))) 221 | 222 | # Prefix all keys with metric_key_prefix + '_' 223 | for key in list(output.metrics.keys()): 224 | if not key.startswith(f"{metric_key_prefix}_"): 225 | output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key) 226 | 227 | self.log(output.metrics) 228 | 229 | self._memory_tracker.stop_and_update_metrics(output.metrics) 230 | 231 | return output 232 | 233 | def prediction_step( 234 | self, 235 | model: nn.Module, 236 | inputs: Dict[str, Union[torch.Tensor, Any]], 237 | prediction_loss_only: bool, 238 | ignore_keys: Optional[List[str]] = None, 239 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 240 | """ 241 | Perform an evaluation step on `model` using `inputs`. 242 | 243 | Subclass and override to inject custom behavior. 244 | 245 | Args: 246 | model (`nn.Module`): 247 | The model to evaluate. 248 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 249 | The inputs and targets of the model. 
250 | 251 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 252 | argument `labels`. Check your model's documentation for all accepted arguments. 253 | prediction_loss_only (`bool`): 254 | Whether or not to return the loss only. 255 | 256 | Return: 257 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 258 | labels (each being optional). 259 | """ 260 | 261 | if not self.args.predict_with_generate or prediction_loss_only: 262 | return super().prediction_step( 263 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 264 | ) 265 | 266 | has_labels = "labels" in inputs 267 | inputs = self._prepare_inputs(inputs) 268 | 269 | # XXX: adapt synced_gpus for fairscale as well 270 | gen_kwargs = { 271 | "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, 272 | "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, 273 | "synced_gpus": True if is_deepspeed_zero3_enabled() else False, 274 | "num_return_sequences": self.model.config.num_return_sequences, 275 | } 276 | 277 | if "attention_mask" in inputs: 278 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 279 | 280 | # prepare generation inputs 281 | # some encoder-decoder models can have varying encoders and thus 282 | # varying model input names 283 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 284 | generation_inputs = inputs[self.model.encoder.main_input_name] 285 | else: 286 | generation_inputs = inputs[self.model.main_input_name] 287 | input_strs = self.tokenizer.batch_decode(inputs['input_ids'], skip_special_tokens=True) 288 | structures = [] 289 | for input_str in input_strs: 290 | start_index = 59 # skip the fixed schema_serialization_with_prompt prefix (59 characters) 291 | end_index = input_str.index(' | question:') 292 | structure = input_str[start_index:end_index] 293 | start_index = input_str.index('| database: ')+12 294 | end_index = input_str.index('. table: ') 295 | db_name = input_str[start_index:end_index] 296 | structures.append(db_name+' | '+structure)
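# Illustrative example (assumed layout): a decoded content-stage input looks roughly like
# "<59-char prompt prefix>select [col] from [col] where [col] [val] | question: ... | database: concert_singer. table: ...",
# so the loop above yields structure strings such as "concert_singer | select [col] from [col] where [col] [val]".
# The db id and skeleton here are hypothetical; only the delimiters come from the parsing code above.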
297 | generated_tokens = self.model.generate( 298 | generation_inputs, 299 | structures=structures, 300 | **gen_kwargs, 301 | ) 302 | # in case the batch is shorter than max length, the output should be padded 303 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 304 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 305 | 306 | if self.model.config.num_return_sequences > 1: 307 | preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 308 | save_path = os.path.join('/app', self.args.output_dir, "hypotheses.json") # paths are rooted at /app inside the docker container 309 | try: 310 | with open(save_path) as f: 311 | data = json.load(f) 312 | except (FileNotFoundError, json.JSONDecodeError): 313 | # no hypotheses file yet (or it is unreadable): start fresh 314 | data = [] 315 | for idx in range(inputs['input_ids'].size(0)): 316 | new_data.append({'structure': structures[idx], 'topk_preds': []}) 317 | for idx, pred in enumerate(preds): 318 | new_data[idx//self.model.config.num_beams]['topk_preds'].append(pred) # assumes num_return_sequences == num_beams 319 | data.extend(new_data) 320 | if data is not None: 321 | with open(save_path, 'w') as f: 322 | json.dump( 323 | data, 324 | f, 325 | indent=4, 326 | ) 327 | pick_list = [] # keep only the top beam per input; the full top-k list lives in hypotheses.json 328 | for idx in range(inputs['input_ids'].size(0)): 329 | pick_list.append(self.model.config.num_beams*idx) 330 | 331 | generated_tokens = torch.index_select(generated_tokens, 0, torch.tensor(pick_list).to(generated_tokens.device)) 332 | 333 | with torch.no_grad(): 334 | with self.autocast_smart_context_manager(): 335 | outputs = model(**inputs) 336 | if has_labels: 337 | if self.label_smoother is not None: 338 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 339 | else: 340 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 341 | else: 342 | loss = None 343 | 344 | if self.args.prediction_loss_only: 345 | return (loss, None, None) 346 | 347 | if has_labels: 348 | labels = inputs["labels"] 349 | if labels.shape[-1] < gen_kwargs["max_length"]: 350 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 351 | else: 352 | labels = None 353 | 354 | return (loss, generated_tokens, labels) --------------------------------------------------------------------------------
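The recombination helper `combine_SC`, called from `_post_process_function` above, lives elsewhere in `src/utils/` and is not part of this excerpt. As a minimal sketch of the idea, assuming the structure stage emits a skeleton with `[col]`/`[val]` placeholders and the content stage emits an ordered `[col] ... [val] ...` stream (as built by the preprocessing in `src/utils/spider.py`), the slot-filling could look like the code below. The function name, signature, and left-to-right filling strategy are assumptions for illustration, not the repository's implementation:

```python
import re

def combine_sc_sketch(content: str, structure: str) -> str:
    """Hypothetical sketch: fill the [col]/[val] slots of a structure
    skeleton with the spans predicted by the content model."""
    # Split the content prediction into (placeholder, span) pairs, e.g.
    # "[col] name [col] singer [col] age [val] 20"
    #   -> [("[col]", "name"), ("[col]", "singer"), ("[col]", "age"), ("[val]", "20")]
    parts = re.split(r"(\[col\]|\[val\])", content)
    fillers = [(tag, span.strip()) for tag, span in zip(parts[1::2], parts[2::2])]
    out, i = [], 0
    for tok in structure.split():
        if tok in ("[col]", "[val]") and i < len(fillers):
            out.append(fillers[i][1])  # slots are filled in left-to-right order
            i += 1
        else:
            out.append(tok)  # SQL keywords and marks pass through unchanged
    return " ".join(out)

print(combine_sc_sketch(
    content="[col] name [col] singer [col] age [val] 20",
    structure="select [col] from [col] where [col] > [val]",
))  # -> select name from singer where age > 20
```

A production version would also need the serialized-schema `input` argument that the real `combine_SC` accepts, plus some policy for mismatched slot counts; both are ignored here to keep the sketch short.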