├── .gitignore ├── README.md ├── cleaned_eval_queries.jsonl.example ├── cleaned_train_queries.jsonl.example ├── eval_grpo.py ├── llm_train.py ├── requirements.txt └── sql_reward_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version 
control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-to-SQL GRPO Fine-tuning Pipeline 2 | 3 | This repository contains a pipeline for fine-tuning Large Language Models (LLMs) for Text-to-SQL conversion using General Reward Proximal Optimization (GRPO). The implementation focuses on Qwen2.5-Coder models but can be adapted for other LLMs. 4 | 5 | ## Overview 6 | 7 | Text-to-SQL is the task of converting natural language questions into SQL queries. 
This project uses GRPO to fine-tune models, optimizing for: 8 | - SQL correctness 9 | - Clear reasoning 10 | - Proper formatting 11 | - Query complexity alignment 12 | 13 | ## Key Features 14 | 15 | - **GRPO Fine-tuning**: Optimize models with multiple reward functions 16 | - **Evaluation**: Comprehensive evaluation framework using gold queries and GPT-4o-mini 17 | - **SQL Reward Functions**: Multiple reward metrics for SQL quality assessment 18 | - **Contrastive Learning**: Improve natural language understanding for SQL generation 19 | 20 | ## Project Structure 21 | 22 | - `llm_train.py`: Main training script for GRPO fine-tuning 23 | - `sql_reward_utils.py`: SQL execution and reward functions 24 | - `eval_grpo.py`: Evaluation of fine-tuned models 25 | - `requirements.txt`: Required dependencies 26 | 27 | ## Installation 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ## Data Preparation 34 | 35 | 1. Clean the dataset: 36 | ```bash 37 | python cleanse_dataset.py 38 | ``` 39 | 40 | This script filters the dataset to ensure: 41 | - Valid SQL queries 42 | - Correctly matched schema contexts 43 | - Executable queries with proper syntax 44 | 45 | ## Training 46 | 47 | Run the GRPO training: 48 | 49 | ```bash 50 | python llm_train.py 51 | ``` 52 | 53 | Key parameters (can be modified in the script): 54 | - `MODEL_NAME`: Base model to fine-tune (default: "Qwen/Qwen2.5-Coder-7B-Instruct") 55 | - `MAX_SEQ_LENGTH`: Maximum sequence length (default: 1024) 56 | - `LORA_RANK`: LoRA rank for parameter-efficient fine-tuning (default: 32) 57 | - `BATCH_SIZE`: Training batch size (default: 4) 58 | - `NUM_GENERATIONS`: Number of generations per prompt for GRPO (default: 8) 59 | - `MAX_STEPS`: Maximum training steps (default: 225) 60 | 61 | ## Evaluation 62 | 63 | Evaluate your trained model: 64 | 65 | ```bash 66 | python eval_grpo.py 67 | ``` 68 | 69 | This script: 70 | 1. Loads your fine-tuned model 71 | 2. Generates SQL queries from test prompts 72 | 3. 
Evaluates the outputs using GPT-4o-mini 73 | 4. Produces detailed metrics and error analysis 74 | 5. Saves results as JSON and CSV 75 | 76 | ## Reward Functions 77 | 78 | The training uses multiple reward components: 79 | 80 | - **Format Reward**: Ensures proper XML tag structure 81 | - **SQL Correctness**: Tests executable accuracy against gold standard 82 | - **Complexity Reward**: Matches complexity between generated and gold queries 83 | - **Reasoning Quality**: Assesses explanation quality and schema references 84 | 85 | ## Model Outputs 86 | 87 | The model is trained to output in the following format: 88 | 89 | ``` 90 | 91 | This database has a users table with columns for id, name, and age. 92 | The question asks for all users over 30, so I need to query the users table with a WHERE condition. 93 | 94 | 95 | SELECT * FROM users WHERE age > 30; 96 | 97 | ``` 98 | -------------------------------------------------------------------------------- /cleaned_eval_queries.jsonl.example: -------------------------------------------------------------------------------- 1 | {"sql_prompt": "How many wells were drilled in the Gulf of Mexico before 2010, and what is the total amount of oil they produced?", "sql": "SELECT COUNT(*) as total_wells, SUM(production_oil) as total_oil_produced FROM gulf_of_mexico WHERE drill_date < '2010-01-01';", "sql_context": "CREATE TABLE gulf_of_mexico (id INT, well_name VARCHAR(255), drill_date DATE, production_oil INT);"} 2 | {"sql_prompt": "What is the total production from wells owned by 'PetroCanada'?", "sql": "SELECT SUM(production) FROM wells WHERE company = 'PetroCanada';", "sql_context": "CREATE TABLE wells (well_id INT, well_name VARCHAR(50), company VARCHAR(50), production FLOAT); INSERT INTO wells (well_id, well_name, company, production) VALUES (1, 'Well A', 'PetroCanada', 10000), (2, 'Well B', 'ExxonMobil', 15000), (3, 'Well C', 'PetroCanada', 20000);"} 3 | {"sql_prompt": "What was the total revenue from concert ticket sales in 
each city?", "sql": "SELECT location, SUM(revenue) FROM Concerts GROUP BY location;", "sql_context": "CREATE TABLE Concerts (location VARCHAR(50), revenue FLOAT); INSERT INTO Concerts (location, revenue) VALUES ('New York', 50000.00), ('Los Angeles', 75000.00), ('Chicago', 60000.00);"} 4 | {"sql_prompt": "Add new defense diplomacy event to 'defense_diplomacy' table", "sql": "INSERT INTO defense_diplomacy (id, event, country, date) VALUES (2, 'Joint Military Exercise', 'Japan', '2023-02-14');", "sql_context": "CREATE TABLE defense_diplomacy (id INT PRIMARY KEY, event VARCHAR(255), country VARCHAR(255), date DATE); INSERT INTO defense_diplomacy (id, event, country, date) VALUES (1, 'Military Attache', 'France', '2022-06-01');"} 5 | {"sql_prompt": "What is the average depth of all marine protected areas in the Antarctic region?", "sql": "SELECT AVG(avg_depth) FROM marine_protected_areas_antarctic WHERE region = 'Antarctic';", "sql_context": "CREATE TABLE marine_protected_areas_antarctic (name VARCHAR(255), region VARCHAR(255), avg_depth FLOAT); INSERT INTO marine_protected_areas_antarctic (name, region, avg_depth) VALUES ('Ross Sea', 'Antarctic', 150.0), ('Weddell Sea', 'Antarctic', 250.0);"} 6 | {"sql_prompt": "Rank the water treatment plants in India by the amount of wastewater treated daily in descending order.", "sql": "SELECT plant_name, daily_wastewater_treated, RANK() OVER (ORDER BY daily_wastewater_treated DESC) as rank FROM india_wastewater_treatment;", "sql_context": "CREATE TABLE india_wastewater_treatment (id INT, plant_name VARCHAR(50), daily_wastewater_treated FLOAT); INSERT INTO india_wastewater_treatment (id, plant_name, daily_wastewater_treated) VALUES (1, 'Bangalore Plant', 500), (2, 'Mumbai Plant', 600), (3, 'Delhi Plant', 400), (4, 'Chennai Plant', 450);"} 7 | {"sql_prompt": "Which Pacific Islander authors have published more than one book between 2000 and 2010?", "sql": "SELECT a.name FROM authors a INNER JOIN books b ON a.id = b.author_id GROUP 
BY a.name HAVING COUNT(b.id) > 1 AND MIN(b.publication_year) BETWEEN 2000 AND 2010;", "sql_context": "CREATE TABLE authors (id INT PRIMARY KEY, name VARCHAR(255), ethnicity VARCHAR(255)); INSERT INTO authors (id, name, ethnicity) VALUES (1, 'Alice Te Punga Somerville', 'Pacific Islander'); INSERT INTO authors (id, name, ethnicity) VALUES (2, 'Sia Figiel', 'Pacific Islander'); CREATE TABLE books (id INT PRIMARY KEY, title VARCHAR(255), author_id INT, publication_year INT); INSERT INTO books (id, title, author_id, publication_year) VALUES (1, 'Once Were Pacific', 1, 2009); INSERT INTO books (id, title, author_id, publication_year) VALUES (2, 'Two Dreams in a Row', 2, 2000); INSERT INTO books (id, title, author_id, publication_year) VALUES (3, 'The Girl in the Moon Circle', 2, 2008);"} 8 | {"sql_prompt": "What is the average citizen feedback score for parks?", "sql": "SELECT AVG(Score) FROM Feedback WHERE Service = 'Park';", "sql_context": "CREATE TABLE Feedback (Service VARCHAR(25), Score INT); INSERT INTO Feedback (Service, Score) VALUES ('Library', 8), ('Park', 7), ('Recreation Center', 9);"} 9 | {"sql_prompt": "Find the average guest_rating for hotel_id 5", "sql": "SELECT AVG(guest_rating) FROM hotel_reviews WHERE hotel_id = 5;", "sql_context": "CREATE TABLE hotel_reviews (hotel_id INT, guest_rating FLOAT, review_text TEXT);"} 10 | {"sql_prompt": "Delete the record of victim with id 2", "sql": "DELETE FROM victims WHERE id = 2;", "sql_context": "CREATE TABLE victims (id INT PRIMARY KEY, name VARCHAR(255), age INT, state VARCHAR(2));"} 11 | {"sql_prompt": "What are the names of the top 2 artists with the most songs in the 'song_details' table?", "sql": "SELECT artists.artist_name, COUNT(song_details.song_id) as song_count FROM artists INNER JOIN song_details ON artists.artist_id = song_details.artist_id GROUP BY artists.artist_id ORDER BY song_count DESC LIMIT 2;", "sql_context": "CREATE TABLE song_details (song_id INT, artist_id INT, genre VARCHAR(20)); INSERT 
INTO song_details (song_id, artist_id, genre) VALUES (1, 1, 'Pop'), (2, 2, 'Rock'), (3, 3, 'Jazz'), (4, 1, 'Pop'), (5, 2, 'Rock'), (6, 3, 'Jazz'), (7, 1, 'Pop'), (8, 2, 'Rock'); CREATE TABLE artists (artist_id INT, artist_name VARCHAR(50)); INSERT INTO artists (artist_id, artist_name) VALUES (1, 'Taylor Swift'), (2, 'BTS'), (3, 'Coldplay');"} 12 | {"sql_prompt": "Show the number of cultural heritage sites in each continent.", "sql": "SELECT continent, site_count FROM site_summary;", "sql_context": "CREATE TABLE heritage_sites (site_id INT, name VARCHAR(255), continent VARCHAR(255)); CREATE VIEW site_summary AS SELECT continent, COUNT(site_id) as site_count FROM heritage_sites GROUP BY continent;"} 13 | {"sql_prompt": "What is the average safety rating of sedans released since 2018?", "sql": "SELECT AVG(rating) FROM SafetyTesting WHERE vehicle_type = 'Sedan' AND release_year >= 2018;", "sql_context": "CREATE TABLE SafetyTesting (id INT, vehicle_type VARCHAR(50), rating INT, release_year INT); INSERT INTO SafetyTesting (id, vehicle_type, rating, release_year) VALUES (1, 'Sedan', 5, 2018), (2, 'Sedan', 5, 2019), (3, 'Sedan', 4, 2018), (4, 'Sedan', 5, 2020), (5, 'Sedan', 4, 2019), (6, 'Sedan', 4, 2021), (7, 'Sedan', 5, 2021);"} 14 | {"sql_prompt": "What is the age of each spy agency in each country?", "sql": "SELECT country, MAX(year_found) - MIN(year_found) AS age FROM SpyAgencies GROUP BY country;", "sql_context": "CREATE TABLE SpyAgencies (id INT PRIMARY KEY, name VARCHAR(50), country VARCHAR(50), year_found INT); INSERT INTO SpyAgencies (id, name, country, year_found) VALUES (1, 'CIA', 'USA', 1947); INSERT INTO SpyAgencies (id, name, country, year_found) VALUES (2, 'MI6', 'UK', 1909);"} 15 | {"sql_prompt": "What is the maximum range of military aircrafts in the 'military_tech' table for each country?", "sql": "SELECT country, MAX(range) as max_range FROM military_tech GROUP BY country;", "sql_context": "CREATE TABLE military_tech (country VARCHAR(50), aircraft_name 
VARCHAR(50), range INT); INSERT INTO military_tech (country, aircraft_name, range) VALUES ('USA', 'F-15', 3000), ('USA', 'F-22', 2960), ('Russia', 'Su-27', 3500), ('Russia', 'MiG-35', 2000), ('China', 'J-20', 2400);"} 16 | {"sql_prompt": "How many clinical trials did 'BioSolutions' conduct in 2019 and 2020?", "sql": "SELECT SUM(trials) FROM BioSolutions_ClinicalTrials WHERE company = 'BioSolutions' AND year IN (2019, 2020);", "sql_context": "CREATE TABLE BioSolutions_ClinicalTrials(company VARCHAR(20), year INT, trials INT); INSERT INTO BioSolutions_ClinicalTrials VALUES('BioSolutions', 2019, 12); INSERT INTO BioSolutions_ClinicalTrials VALUES('BioSolutions', 2020, 18);"} 17 | {"sql_prompt": "What is the average age of offenders in the justice system in New York?", "sql": "SELECT AVG(age) as avg_age FROM offenders WHERE state = 'New York';", "sql_context": "CREATE TABLE offenders (id INT, age INT, state TEXT); INSERT INTO offenders (id, age, state) VALUES (1, 25, 'New York'), (2, 30, 'New York'), (3, 35, 'California'), (4, 40, 'New York');"} 18 | {"sql_prompt": "Find the number of games won by the home_team for each city in the game_results table.", "sql": "SELECT city, home_team, COUNT(*) as num_wins FROM game_results WHERE home_score > away_score GROUP BY city, home_team;", "sql_context": "CREATE TABLE game_results (game_id INT, home_team VARCHAR(20), away_team VARCHAR(20), home_score INT, away_score INT, city VARCHAR(20), stadium VARCHAR(50));"} 19 | {"sql_prompt": "What is the total revenue generated from mobile and broadband services in the first quarter of 2021?", "sql": "SELECT SUM(revenue) FROM (SELECT revenue FROM mobile_revenue WHERE quarter = 1 UNION SELECT revenue FROM broadband_revenue WHERE quarter = 1);", "sql_context": "CREATE TABLE mobile_revenue(quarter INT, revenue FLOAT); INSERT INTO mobile_revenue(quarter, revenue) VALUES (1, 1500000), (2, 1800000), (3, 2250000), (4, 2500000); CREATE TABLE broadband_revenue(quarter INT, revenue FLOAT); INSERT 
INTO broadband_revenue(quarter, revenue) VALUES (1, 2000000), (2, 2750000), (3, 3250000), (4, 3500000);"} 20 | {"sql_prompt": "What was the total quantity of containers loaded on vessels per port for the month of January 2021?", "sql": "SELECT p.port_name, SUM(c.container_quantity) AS total_containers_loaded FROM containers c JOIN ports p ON c.port_id = p.port_id JOIN (SELECT DISTINCT vessel_id FROM containers WHERE load_date BETWEEN '2021-01-01' AND '2021-01-31') v ON c.vessel_id = v.vessel_id WHERE c.load_date BETWEEN '2021-01-01' AND '2021-01-31' GROUP BY p.port_name;", "sql_context": "CREATE TABLE ports (port_id INT, port_name VARCHAR(50));CREATE TABLE vessels (vessel_id INT, vessel_name VARCHAR(50));CREATE TABLE containers (container_id INT, container_quantity INT, port_id INT, vessel_id INT, load_date DATE); INSERT INTO ports VALUES (1, 'PortA'), (2, 'PortB'), (3, 'PortC'); INSERT INTO vessels VALUES (101, 'VesselX'), (102, 'VesselY'), (103, 'VesselZ');"} 21 | {"sql_prompt": "What is the average mass of all satellites in the \"satellite_mass\" table, grouped by launch year?", "sql": "SELECT launch_year, AVG(mass) AS avg_mass FROM satellite_mass GROUP BY launch_year;", "sql_context": "CREATE TABLE satellite_mass (id INT, satellite_name VARCHAR(50), manufacturer VARCHAR(50), mass FLOAT, launch_year INT); INSERT INTO satellite_mass (id, satellite_name, manufacturer, mass, launch_year) VALUES (1, 'Sat1', 'Manufacturer1', 1000, 2005); INSERT INTO satellite_mass (id, satellite_name, manufacturer, mass, launch_year) VALUES (2, 'Sat2', 'Manufacturer2', 2000, 2010);"} 22 | {"sql_prompt": "Add a new music genre 'ElectroLatin' into the music_genres table", "sql": "INSERT INTO music_genres (genre_name) VALUES ('ElectroLatin');", "sql_context": "CREATE TABLE music_genres (id INT, genre_name VARCHAR(50));"} 23 | {"sql_prompt": "Insert a new record for circular economy initiative in Berlin in 2022.", "sql": "INSERT INTO circular_economy (city, year, initiative) VALUES 
('Berlin', 2022, 'Plastic waste reduction campaign');", "sql_context": "CREATE TABLE circular_economy (city VARCHAR(255), year INT, initiative VARCHAR(255));"} 24 | {"sql_prompt": "What is the distribution of customer sizes in France?", "sql": "SELECT customer_size, COUNT(*) AS customer_count FROM customer_sizes WHERE customer_country = 'France' GROUP BY customer_size", "sql_context": "CREATE TABLE customer_sizes (customer_id INT, customer_size TEXT, customer_country TEXT);"} 25 | {"sql_prompt": "How many decentralized finance (DeFi) dApps are currently running on the Binance Smart Chain?", "sql": "SELECT COUNT(*) FROM BinanceDApps WHERE network = 'Binance Smart Chain' AND category = 'DeFi' AND status = 'active';", "sql_context": "CREATE TABLE BinanceDApps (id INT, name VARCHAR(100), network VARCHAR(50), category VARCHAR(50), status VARCHAR(50)); INSERT INTO BinanceDApps (id, name, network, category, status) VALUES (1, 'PancakeSwap', 'Binance Smart Chain', 'DeFi', 'active'), (2, 'BakerySwap', 'Binance Smart Chain', 'DeFi', 'active'), (3, 'BurgerSwap', 'Binance Smart Chain', 'DeFi', 'inactive');"} 26 | {"sql_prompt": "Insert a new pop song released in 2022 into the Songs table", "sql": "INSERT INTO Songs (song_id, title, genre, release_date, price) VALUES (1001, 'New Pop Song', 'pop', '2022-10-15', 0.99);", "sql_context": "CREATE TABLE Songs (song_id INT, title TEXT, genre TEXT, release_date DATE, price DECIMAL(5,2));"} 27 | {"sql_prompt": "What is the earliest date an article about 'corruption' was published, for articles that have at least 1000 words?", "sql": "SELECT MIN(published_at) FROM articles WHERE articles.category = 'corruption' AND articles.word_count >= 1000;", "sql_context": "CREATE TABLE articles (id INT, title TEXT, category TEXT, published_at DATETIME, word_count INT);"} 28 | {"sql_prompt": "Display fan demographics, pivoted by gender", "sql": "SELECT age, location, interest, SUM(CASE WHEN gender = 'Male' THEN 1 ELSE 0 END) as males, SUM(CASE WHEN 
gender = 'Female' THEN 1 ELSE 0 END) as females FROM fan_demographics GROUP BY age, location, interest;", "sql_context": "CREATE TABLE fan_demographics (id INT, age INT, gender VARCHAR(50), location VARCHAR(50), interest VARCHAR(50));"} 29 | {"sql_prompt": "Which public services received the highest and lowest budget allocations in the city of Chicago in 2022?", "sql": "SELECT department, allocated_budget FROM city_budget WHERE city = 'Chicago' AND year = 2022 ORDER BY allocated_budget DESC, department ASC LIMIT 1; SELECT department, allocated_budget FROM city_budget WHERE city = 'Chicago' AND year = 2022 ORDER BY allocated_budget ASC, department ASC LIMIT 1;", "sql_context": "CREATE TABLE city_budget (city VARCHAR(255), year INT, department VARCHAR(255), allocated_budget FLOAT); INSERT INTO city_budget (city, year, department, allocated_budget) VALUES ('Chicago', 2022, 'Education', 5000000.00), ('Chicago', 2022, 'Police', 4000000.00), ('Chicago', 2022, 'Fire Department', 3000000.00);"} 30 | {"sql_prompt": "What is the average quantity of clothes sold in each size?", "sql": "SELECT size, AVG(quantity) FROM sales GROUP BY size;", "sql_context": "CREATE TABLE sales (id INT, product_id INT, size TEXT, quantity INT, sale_date DATE); INSERT INTO sales (id, product_id, size, quantity, sale_date) VALUES (1, 1001, 'XS', 25, '2021-09-01'), (2, 1002, 'XXL', 30, '2021-09-15'), (3, 1003, 'M', 40, '2021-09-20'), (4, 1004, 'L', 50, '2021-09-25');"} 31 | {"sql_prompt": "How many decentralized applications are on the Cardano platform?", "sql": "SELECT COUNT(*) FROM dapps WHERE platform = 'Cardano';", "sql_context": "CREATE TABLE dapps (dapp_id INT, name VARCHAR(100), platform VARCHAR(50)); INSERT INTO dapps (dapp_id, name, platform) VALUES (1, 'SingularityNET', 'Cardano'), (2, 'OccamFi', 'Cardano'), (3, 'Liqwid', 'Cardano');"} 32 | {"sql_prompt": "What is the total number of patients diagnosed with Diabetes in urban areas?", "sql": "SELECT COUNT(*) FROM Patients WHERE Diagnosis = 
'Diabetes' AND Location = 'Urban';", "sql_context": "CREATE TABLE Patients (PatientID INT, Age INT, Gender VARCHAR(10), Diagnosis VARCHAR(20), Location VARCHAR(20)); INSERT INTO Patients (PatientID, Age, Gender, Diagnosis, Location) VALUES (1, 50, 'Male', 'Diabetes', 'Urban'); INSERT INTO Patients (PatientID, Age, Gender, Diagnosis, Location) VALUES (2, 55, 'Female', 'Diabetes', 'Urban'); INSERT INTO Patients (PatientID, Age, Gender, Diagnosis, Location) VALUES (3, 45, 'Male', 'Hypertension', 'Urban');"} 33 | {"sql_prompt": "Insert records for landfill_capacity table, with data for 'Argentina', 'China', 'Indonesia' and capacity values 9000, 16000, 21000 respectively", "sql": "INSERT INTO landfill_capacity (country, capacity) VALUES ('Argentina', 9000), ('China', 16000), ('Indonesia', 21000);", "sql_context": "CREATE TABLE landfill_capacity (country VARCHAR(50), capacity INT);"} 34 | {"sql_prompt": "What is the total number of streams for all songs by the artist \"Taylor Swift\" on the music streaming platform?", "sql": "SELECT SUM(streams) as total_streams FROM music_platform WHERE artist = 'Taylor Swift';", "sql_context": "CREATE TABLE music_platform (id INT, artist VARCHAR(100), song_title VARCHAR(100), streams INT);"} 35 | {"sql_prompt": "Which program type has the highest average attendance?", "sql": "SELECT type, AVG(attendance) FROM programs_attendance GROUP BY type ORDER BY AVG(attendance) DESC LIMIT 1;", "sql_context": "CREATE TABLE if not exists programs_attendance (id INT, name VARCHAR(255), type VARCHAR(255), attendance INT); INSERT INTO programs_attendance (id, name, type, attendance) VALUES (1, 'Story Time', 'Children', 300), (2, 'Art Class', 'Children', 250), (3, 'Theater Workshop', 'Youth', 150), (4, 'Jazz Night', 'Adults', 100), (5, 'Poetry Reading', 'Adults', 75);"} 36 | {"sql_prompt": "What is the total revenue of movies by release decade?", "sql": "SELECT (release_year - (release_year % 10)) AS decade, SUM(revenue) FROM movie_financials GROUP BY 
decade;", "sql_context": "CREATE TABLE movie_financials (title VARCHAR(255), revenue INT, release_year INT); INSERT INTO movie_financials (title, revenue, release_year) VALUES ('The Dark Knight', 1004, 2008), ('Star Wars', 775, 1977);"} 37 | {"sql_prompt": "What is the percentage of teachers who attended a professional development event in each region, broken down by gender?", "sql": "SELECT region, gender, 100.0 * AVG(event_attended) as percentage FROM teachers_gender GROUP BY region, gender;", "sql_context": "CREATE TABLE teachers_gender (teacher_id INT, region VARCHAR(20), event_attended BOOLEAN, gender VARCHAR(10)); INSERT INTO teachers_gender (teacher_id, region, event_attended, gender) VALUES (1, 'North', true, 'Female'), (2, 'North', false, 'Male'), (3, 'South', true, 'Female');"} 38 | {"sql_prompt": "What is the maximum and minimum height for trees in the 'tree_height' table?", "sql": "SELECT species, MAX(height) FROM tree_height;", "sql_context": "CREATE TABLE tree_height (id INT, species VARCHAR(255), height INT); INSERT INTO tree_height (id, species, height) VALUES (1, 'Oak', 80), (2, 'Maple', 70), (3, 'Pine', 60);"} 39 | {"sql_prompt": "What is the average length of articles about immigration, categorized by the author's country of origin?", "sql": "SELECT authors.country, AVG(articles.length) FROM articles INNER JOIN authors ON articles.author_id = authors.id WHERE articles.title LIKE '%immigration%' GROUP BY authors.country;", "sql_context": "CREATE TABLE articles (id INT, title VARCHAR(100), date DATE, length INT, author_id INT);CREATE TABLE authors (id INT, name VARCHAR(50), country VARCHAR(50)); INSERT INTO articles VALUES (1, 'Immigration crisis', '2022-02-01', 1000, 1); INSERT INTO authors VALUES (1, 'Jane Doe', 'USA');"} 40 | {"sql_prompt": "Update the sustainability_metrics table to reflect the latest energy consumption data for factory 3.", "sql": "UPDATE sustainability_metrics SET energy_consumption = 23000.0, measurement_date = CURRENT_DATE 
WHERE factory_id = 3;", "sql_context": "CREATE TABLE sustainability_metrics (factory_id INT, energy_consumption FLOAT, measurement_date DATE); INSERT INTO sustainability_metrics (factory_id, energy_consumption, measurement_date) VALUES (1, 25000.5, '2021-09-01'), (2, 18000.3, '2021-09-01'), (3, 22000.0, '2021-08-01');"} 41 | {"sql_prompt": "What is the minimum market price of Gadolinium in India over the past 3 years?", "sql": "SELECT MIN(market_price) FROM Gadolinium_Market_Prices WHERE country = 'India' AND year BETWEEN 2020 AND 2022;", "sql_context": "CREATE TABLE Gadolinium_Market_Prices (id INT, year INT, country VARCHAR(20), market_price DECIMAL(10,2));"} 42 | {"sql_prompt": "Which vehicle safety tests were passed by Toyota?", "sql": "SELECT DISTINCT Test FROM VehicleTesting WHERE Make = 'Toyota' AND Result = 'Pass';", "sql_context": "CREATE TABLE VehicleTesting (Id INT, Make VARCHAR(255), Model VARCHAR(255), Test VARCHAR(255), Result VARCHAR(255)); INSERT INTO VehicleTesting (Id, Make, Model, Test, Result) VALUES (3, 'Toyota', 'Corolla', 'AutoPilot', 'Pass');"} 43 | {"sql_prompt": "What is the total installed capacity (MW) of wind energy in each province of Canada?", "sql": "SELECT province, SUM(capacity) FROM canada_wind GROUP BY province;", "sql_context": "CREATE TABLE canada_wind (id INT, province VARCHAR(50), capacity FLOAT); INSERT INTO canada_wind (id, province, capacity) VALUES (1, 'Ontario', 500.5), (2, 'Quebec', 600.2), (3, 'Alberta', 800.1);"} 44 | {"sql_prompt": "How many specialists were added to the rural health clinic in Texas in Q1 2022?", "sql": "SELECT COUNT(*) FROM RuralClinic WHERE staff_type = 'specialist' AND hire_date BETWEEN '2022-01-01' AND '2022-03-31';", "sql_context": "CREATE TABLE RuralClinic (clinicID INT, staff_type VARCHAR(20), hire_date DATE); INSERT INTO RuralClinic (clinicID, staff_type, hire_date) VALUES (1, 'doctor', '2022-01-15'), (2, 'nurse', '2021-12-21'), (3, 'specialist', '2022-03-05');"} 45 | {"sql_prompt": "What are 
the explainable AI techniques used in the UK and Canada?", "sql": "SELECT DISTINCT location, technique FROM Explainable_AI WHERE location IN ('UK', 'Canada');", "sql_context": "CREATE TABLE Explainable_AI (id INT, technique TEXT, location TEXT); INSERT INTO Explainable_AI (id, technique, location) VALUES (1, 'SHAP', 'UK'), (2, 'LIME', 'Canada'), (3, 'anchors', 'UK'), (4, 'TreeExplainer', 'Canada');"} 46 | {"sql_prompt": "What is the count of patients with mental health disorders by their race/ethnicity?", "sql": "SELECT race_ethnicity, COUNT(*) as count FROM patients WHERE has_mental_health_disorder = true GROUP BY race_ethnicity;", "sql_context": "CREATE TABLE patients (id INT, has_mental_health_disorder BOOLEAN, race_ethnicity VARCHAR(50)); INSERT INTO patients (id, has_mental_health_disorder, race_ethnicity) VALUES (1, true, 'Asian'), (2, false, 'White'), (3, true, 'Hispanic'), (4, true, 'Black');"} 47 | {"sql_prompt": "Who are the top 3 countries receiving climate finance for communication projects?", "sql": "SELECT cf.country, SUM(cf.amount) FROM climate_finance cf WHERE cf.sector = 'communication' GROUP BY cf.country ORDER BY SUM(cf.amount) DESC LIMIT 3;", "sql_context": "CREATE TABLE climate_finance (id INT, country VARCHAR(50), amount FLOAT, sector VARCHAR(50));"} 48 | {"sql_prompt": "How many ethical AI initiatives were implemented in each region, ordered by the number of initiatives in descending order?", "sql": "SELECT region, COUNT(*) as total_initiatives FROM ethical_ai_initiatives GROUP BY region ORDER BY total_initiatives DESC;", "sql_context": "CREATE TABLE ethical_ai_initiatives (initiative_id INT, initiative_name VARCHAR(255), region VARCHAR(255)); INSERT INTO ethical_ai_initiatives (initiative_id, initiative_name, region) VALUES (1, 'AI for social justice', 'North America'), (2, 'Ethical AI guidelines', 'Europe'), (3, 'AI for disability', 'Asia'), (4, 'AI for healthcare equality', 'Africa'), (5, 'Fair AI in education', 'South America'), (6, 
'Ethical AI for finance', 'North America'), (7, 'AI for environmental justice', 'Europe');"} 49 | {"sql_prompt": "What is the total number of eco-friendly hotels in London and Paris combined?", "sql": "SELECT COUNT(*) FROM eco_hotels WHERE city IN ('London', 'Paris');", "sql_context": "CREATE TABLE eco_hotels (hotel_id INT, name TEXT, city TEXT); INSERT INTO eco_hotels (hotel_id, name, city) VALUES (1, 'Eco Hotel London', 'London'), (2, 'Green Haven London', 'London'), (3, 'Eco Lodge London', 'London'), (4, 'Eco Hotel Paris', 'Paris'), (5, 'Green Haven Paris', 'Paris'), (6, 'Eco Lodge Paris', 'Paris');"} 50 | {"sql_prompt": "What is the average delivery time for each courier?", "sql": "SELECT c.courier, AVG(d.delivery_time) as avg_delivery_time FROM deliveries d JOIN couriers c ON d.courier_id = c.courier_id GROUP BY c.courier;", "sql_context": "CREATE TABLE couriers (courier_id INT, courier TEXT); INSERT INTO couriers (courier_id, courier) VALUES (1, 'DHL'), (2, 'UPS'), (3, 'FedEx'); CREATE TABLE deliveries (delivery_id INT, courier_id INT, delivery_time INT); INSERT INTO deliveries (delivery_id, courier_id, delivery_time) VALUES (1, 1, 500), (2, 2, 700), (3, 3, 400), (4, 1, 600), (5, 3, 300);"} 51 | {"sql_prompt": "What is the average duration of space missions for US astronauts?", "sql": "SELECT Nationality, AVG(MissionDuration) FROM Astronauts WHERE Nationality = 'American' GROUP BY Nationality;", "sql_context": "CREATE TABLE Astronauts (AstronautID INT, FirstName VARCHAR(20), LastName VARCHAR(20), Nationality VARCHAR(20), SpaceMissions INT, MissionDuration INT); INSERT INTO Astronauts (AstronautID, FirstName, LastName, Nationality, SpaceMissions, MissionDuration) VALUES (1, 'Alan', 'Shepard', 'American', 2, 315); INSERT INTO Astronauts (AstronautID, FirstName, LastName, Nationality, SpaceMissions, MissionDuration) VALUES (2, 'Mae', 'Jemison', 'American', 1, 190);"} 52 | {"sql_prompt": "What is the total revenue for vegetarian menu items?", "sql": "SELECT 
SUM(o.revenue) as total_revenue FROM orders o JOIN menu_items mi ON o.menu_item_id = mi.id WHERE mi.vegetarian = TRUE;", "sql_context": "CREATE TABLE restaurants (id INT, name VARCHAR(255)); INSERT INTO restaurants (id, name) VALUES (1, 'Restaurant A'), (2, 'Restaurant B'), (3, 'Restaurant C'); CREATE TABLE menu_items (id INT, name VARCHAR(255), vegetarian BOOLEAN, restaurant_id INT); INSERT INTO menu_items (id, name, vegetarian, restaurant_id) VALUES (1, 'Tacos', FALSE, 1), (2, 'Pizza', TRUE, 2), (3, 'Fried Rice', FALSE, 3), (4, 'Burrito', TRUE, 1), (5, 'Spaghetti', FALSE, 2); CREATE TABLE orders (menu_item_id INT, revenue INT); INSERT INTO orders (menu_item_id, revenue) VALUES (1, 500), (2, 700), (3, 600), (4, 800), (5, 900);"} 53 | {"sql_prompt": "What is the total waste production by sustainable material category in 2019?", "sql": "SELECT category, SUM(quantity) as total_waste FROM waste_data WHERE year = 2019 GROUP BY category;", "sql_context": "CREATE TABLE waste_data (year INT, category VARCHAR(255), quantity INT); INSERT INTO waste_data (year, category, quantity) VALUES (2018, 'Organic Cotton', 1000), (2018, 'Recycled Polyester', 1500), (2018, 'Hemp', 500), (2019, 'Organic Cotton', 1200), (2019, 'Recycled Polyester', 1800), (2019, 'Hemp', 600);"} 54 | {"sql_prompt": "What was the total quantity of Samarium (Sm) supplied by each supplier in Q2 2022, ordered by supplier name?", "sql": "SELECT supplier, SUM(quantity) AS total_quantity FROM supplier_trends WHERE element = 'Sm' AND quarter = 2 AND year = 2022 GROUP BY supplier ORDER BY supplier;", "sql_context": "CREATE TABLE supplier_trends (supplier VARCHAR(25), element VARCHAR(2), quantity INT, quarter INT, year INT); INSERT INTO supplier_trends VALUES ('SupplierD', 'Sm', 400, 2, 2022), ('SupplierE', 'Sm', 600, 2, 2022), ('SupplierF', 'Sm', 500, 2, 2022);"} 55 | {"sql_prompt": "What is the total quantity of gold and silver extracted by each location?", "sql": "SELECT location, SUM(CASE WHEN mineral = 'Gold' 
THEN quantity ELSE 0 END) as total_gold, SUM(CASE WHEN mineral = 'Silver' THEN quantity ELSE 0 END) as total_silver FROM geological_survey GROUP BY location;", "sql_context": "CREATE TABLE geological_survey (location VARCHAR(255), mineral VARCHAR(255), quantity FLOAT, year INT); INSERT INTO geological_survey (location, mineral, quantity, year) VALUES ('Mine A', 'Gold', 1000, 2015), ('Mine A', 'Silver', 2000, 2015), ('Mine B', 'Gold', 1500, 2016), ('Mine B', 'Silver', 2500, 2016);"} 56 | {"sql_prompt": "What is the maximum capacity of any energy efficiency project?", "sql": "SELECT MAX(capacity) FROM energy_efficiency_projects;", "sql_context": "CREATE TABLE energy_efficiency_projects (name TEXT, capacity INTEGER); INSERT INTO energy_efficiency_projects (name, capacity) VALUES ('Project A', 200), ('Project B', 900);"} 57 | {"sql_prompt": "Set the nitrogen level to 150 for the crop with ID C005", "sql": "UPDATE crops SET nitrogen_level = 150 WHERE crop_id = 'C005';", "sql_context": "CREATE TABLE crops (crop_id VARCHAR(10), nitrogen_level INT);"} 58 | {"sql_prompt": "What are the open pedagogy initiatives and their corresponding budgets?", "sql": "SELECT i.initiative_name, b.budget_amount FROM initiatives i INNER JOIN budgets b ON i.initiative_id = b.initiative_id;", "sql_context": "CREATE TABLE initiatives (initiative_id INT, initiative_name VARCHAR(50), initiative_type VARCHAR(50)); CREATE TABLE budgets (budget_id INT, initiative_id INT, budget_amount INT); INSERT INTO initiatives (initiative_id, initiative_name, initiative_type) VALUES (101, 'Open Source Textbooks', 'Open Pedagogy'), (102, 'Peer Learning Networks', 'Open Pedagogy'), (103, 'Project-Based Learning', 'Open Pedagogy'); INSERT INTO budgets (budget_id, initiative_id, budget_amount) VALUES (201, 101, 10000), (202, 102, 15000), (203, 103, 12000);"} 59 | {"sql_prompt": "Delete all records related to wheelchair ramps from the accommodations table.", "sql": "DELETE FROM accommodations WHERE type = 'Wheelchair 
Ramp';", "sql_context": "CREATE TABLE accommodations (id INT, type VARCHAR(255), description VARCHAR(255)); INSERT INTO accommodations (id, type, description) VALUES (1, 'Wheelchair Ramp', 'Ramp with handrails and non-slip surface'); INSERT INTO accommodations (id, type, description) VALUES (2, 'Elevator', 'Standard elevator for building access');"} 60 | {"sql_prompt": "Determine the minimum salary for employees who have completed the advanced leadership training.", "sql": "SELECT MIN(Salary) FROM Employees WHERE CompletedAdvancedLeadershipTraining = TRUE;", "sql_context": "CREATE TABLE Employees (EmployeeID INT, CompletedAdvancedLeadershipTraining BOOLEAN, Salary FLOAT);"} 61 | {"sql_prompt": "How many community health centers offer telehealth services in California?", "sql": "SELECT COUNT(*) FROM TelehealthServices WHERE State = 'California' AND Telehealth = 'Yes';", "sql_context": "CREATE TABLE TelehealthServices (HealthCenterID INT, State VARCHAR(20), Telehealth VARCHAR(10)); INSERT INTO TelehealthServices (HealthCenterID, State, Telehealth) VALUES (1, 'California', 'Yes'); INSERT INTO TelehealthServices (HealthCenterID, State, Telehealth) VALUES (2, 'California', 'No');"} 62 | {"sql_prompt": "What is the minimum depth of the ocean floor in the Mariana trench?", "sql": "SELECT MIN(depth) FROM ocean_floor_depth WHERE location = 'Mariana Trench';", "sql_context": "CREATE TABLE ocean_floor_depth (location VARCHAR(255), depth FLOAT); INSERT INTO ocean_floor_depth (location, depth) VALUES ('Mariana Trench', 10994), ('Puerto Rico Trench', 8605);"} 63 | {"sql_prompt": "Find the average age of visitors who participated in online events in Argentina.", "sql": "SELECT AVG(participant_age) FROM EventParticipants WHERE country = 'Argentina' AND event_type = 'Online';", "sql_context": "CREATE TABLE EventParticipants (event_id INT, country VARCHAR(20), participant_age INT, event_type VARCHAR(10)); INSERT INTO EventParticipants (event_id, country, participant_age, event_type) 
VALUES (1, 'Argentina', 25, 'Online'), (2, 'Brazil', 30, 'Offline'), (3, 'Chile', 35, 'Offline');"} 64 | {"sql_prompt": "Update the age of archaeologist 'John Doe' in the 'archaeologists' table to 35.", "sql": "UPDATE archaeologists SET age = 35 WHERE name = 'John Doe';", "sql_context": "CREATE TABLE archaeologists (id INT, name VARCHAR(50), age INT, gender VARCHAR(10), country VARCHAR(50));"} 65 | {"sql_prompt": "Find the number of tourists from Brazil in each destination in 2022?", "sql": "SELECT destination, tourists FROM tourism_stats WHERE visitor_country = 'Brazil' AND year = 2022;", "sql_context": "CREATE TABLE tourism_stats (visitor_country VARCHAR(20), destination VARCHAR(20), year INT, tourists INT); INSERT INTO tourism_stats (visitor_country, destination, year, tourists) VALUES ('Brazil', 'Rio de Janeiro', 2022, 700), ('Brazil', 'Sao Paulo', 2022, 800), ('Brazil', 'Brasilia', 2022, 600);"} 66 | {"sql_prompt": "What is the name of the faculty member with the least number of publications in the Computer Science department?", "sql": "SELECT Name FROM Faculty WHERE Department = 'Computer Science' AND NumPublications = (SELECT MIN(NumPublications) FROM Faculty WHERE Department = 'Computer Science');", "sql_context": "CREATE TABLE Faculty (FacultyID int, Name varchar(50), Department varchar(50), NumPublications int); INSERT INTO Faculty (FacultyID, Name, Department, NumPublications) VALUES (1, 'John Doe', 'Mathematics', 15); INSERT INTO Faculty (FacultyID, Name, Department, NumPublications) VALUES (2, 'Jane Smith', 'Mathematics', 20); INSERT INTO Faculty (FacultyID, Name, Department, NumPublications) VALUES (3, 'Mary Johnson', 'Physics', 25); INSERT INTO Faculty (FacultyID, Name, Department, NumPublications) VALUES (4, 'Bob Brown', 'Physics', 10); INSERT INTO Faculty (FacultyID, Name, Department, NumPublications) VALUES (5, 'Alice Davis', 'Computer Science', 5); INSERT INTO Faculty (FacultyID, Name, Department, NumPublications) VALUES (6, 'Charlie Brown', 
'Computer Science', 10);"} 67 | {"sql_prompt": "Determine the revenue generated from the sale of dishes in a specific region in the last month.", "sql": "SELECT SUM(i.quantity * m.price) as revenue FROM inventory i JOIN orders o ON i.item_id = o.item_id JOIN menu_items m ON i.item_id = m.item_id JOIN restaurants r ON o.restaurant_id = r.restaurant_id WHERE o.order_date BETWEEN '2022-02-01' AND '2022-02-28' AND r.region = 'Midwest';", "sql_context": "CREATE TABLE inventory (item_id INT, quantity INT, unit_price DECIMAL(5,2)); INSERT INTO inventory (item_id, quantity, unit_price) VALUES (1, 10, 12.99), (2, 20, 7.50), (3, 30, 9.99), (4, 40, 15.49), (5, 50, 8.99); CREATE TABLE orders (order_id INT, item_id INT, order_date DATE, restaurant_id INT); INSERT INTO orders (order_id, item_id, order_date, restaurant_id) VALUES (1, 1, '2022-03-01', 2), (2, 3, '2022-03-02', 2), (3, 2, '2022-03-03', 1), (4, 4, '2022-03-04', 1), (5, 5, '2022-03-05', 2); CREATE TABLE menu_items (item_id INT, name TEXT, is_vegan BOOLEAN, price DECIMAL(5,2)); INSERT INTO menu_items (item_id, name, is_vegan, price) VALUES (1, 'Quinoa Salad', true, 12.99), (2, 'Beef Burger', false, 7.50), (3, 'Chickpea Curry', true, 9.99), (4, 'Cheesecake', false, 15.49), (5, 'Veggie Pizza', true, 8.99); CREATE TABLE restaurants (restaurant_id INT, name TEXT, region TEXT); INSERT INTO restaurants (restaurant_id, name, region) VALUES (1, 'Big Burger', 'East'), (2, 'Veggies R Us', 'Midwest'), (3, 'Tasty Bites', 'West');"} 68 | {"sql_prompt": "Which brands have received the most cruelty-free certifications?", "sql": "SELECT b.Brand_Name, COUNT(c.Certification_ID) AS Cruelty_Free_Certifications_Count FROM Brands b JOIN Certifications c ON b.Brand_ID = c.Brand_ID GROUP BY b.Brand_ID;", "sql_context": "CREATE TABLE Brands (Brand_ID INT PRIMARY KEY, Brand_Name TEXT); CREATE TABLE Certifications (Certification_ID INT PRIMARY KEY, Certification_Name TEXT, Brand_ID INT); INSERT INTO Brands (Brand_ID, Brand_Name) VALUES (1, 
'Ethical Beauty'), (2, 'Pure Cosmetics'), (3, 'Green Earth'), (4, 'Eco Living'), (5, 'Sustainable Solutions'); INSERT INTO Certifications (Certification_ID, Certification_Name, Brand_ID) VALUES (1, 'Leaping Bunny', 1), (2, 'Cruelty Free International', 1), (3, 'People for the Ethical Treatment of Animals (PETA)', 2), (4, 'Leaping Bunny', 2), (5, 'Cruelty Free International', 3), (6, 'Leaping Bunny', 3), (7, 'People for the Ethical Treatment of Animals (PETA)', 4), (8, 'Leaping Bunny', 4), (9, 'Cruelty Free International', 5), (10, 'People for the Ethical Treatment of Animals (PETA)', 5);"} 69 | {"sql_prompt": "What is the total biomass of marine species in the Southern Ocean, excluding species that are not yet classified?", "sql": "SELECT SUM(ms.biomass) as total_biomass FROM marine_species ms WHERE ms.region = 'Southern Ocean' AND ms.classified = TRUE;", "sql_context": "CREATE TABLE marine_species (id INT, name VARCHAR(100), region VARCHAR(50), biomass FLOAT, classified BOOLEAN);"} 70 | {"sql_prompt": "What is the maximum water temperature recorded for each farm?", "sql": "SELECT FarmID, MAX(WaterTemp) as MaxTemp FROM FarmTemp GROUP BY FarmID;", "sql_context": "CREATE TABLE FarmTemp (FarmID int, Date date, WaterTemp float); INSERT INTO FarmTemp (FarmID, Date, WaterTemp) VALUES (1, '2022-01-01', 10.5), (1, '2022-01-02', 11.2), (2, '2022-01-01', 12.1), (2, '2022-01-02', 12.6);"} 71 | {"sql_prompt": "Update the crime_statistics table to mark the 'status' column as 'Solved' for records with 'crime_type' 'Burglary' and 'date' '2022-06-14'?", "sql": "UPDATE crime_statistics SET status = 'Solved' WHERE crime_type = 'Burglary' AND date = '2022-06-14';", "sql_context": "CREATE TABLE crime_statistics (crime_type VARCHAR(255), crime_count INT, date DATE, status VARCHAR(255)); INSERT INTO crime_statistics (crime_type, crime_count, date, status) VALUES (NULL, NULL, NULL, NULL);"} 72 | {"sql_prompt": "What are the total sales of military equipment to Asia in Q1 2020?", "sql": 
"SELECT SUM(sales) FROM sales WHERE region = 'Asia' AND quarter = 1 AND year = 2020;", "sql_context": "CREATE TABLE sales(id INT, region VARCHAR(255), quarter INT, year INT, equipment VARCHAR(255), sales FLOAT);"} 73 | {"sql_prompt": "What was the total amount of funds allocated for agricultural innovation projects in Uganda in 2021?", "sql": "SELECT SUM(funds) FROM agricultural_innovation_projects WHERE country = 'Uganda' AND year = 2021;", "sql_context": "CREATE TABLE agricultural_innovation_projects (id INT, country VARCHAR(50), funds FLOAT, year INT); INSERT INTO agricultural_innovation_projects (id, country, funds, year) VALUES (1, 'Kenya', 500000.00, 2020), (2, 'Uganda', 600000.00, 2021);"} 74 | {"sql_prompt": "Insert a new treatment record into the treatments table.", "sql": "INSERT INTO treatments (id, patient_id, therapist_id, treatment_date, treatment_type, cost) VALUES (2, 3, 2, '2022-02-10', 'CBT', 200.00);", "sql_context": "CREATE TABLE treatments (id INT PRIMARY KEY, patient_id INT, therapist_id INT, treatment_date DATE, treatment_type VARCHAR(50), cost DECIMAL(5,2));INSERT INTO treatments (id, patient_id, therapist_id, treatment_date, treatment_type, cost) VALUES (1, 1, 1, '2022-01-01', 'Psychotherapy', 150.00);"} 75 | {"sql_prompt": "Find the top 3 drugs with the highest R&D expenditures for a given year.", "sql": "SELECT drug_name, SUM(amount) OVER (PARTITION BY drug_name ORDER BY SUM(amount) DESC) AS total_expenditure, year FROM rd_expenditures GROUP BY 1, 3 ORDER BY 2 DESC LIMIT 3;", "sql_context": "CREATE TABLE rd_expenditures (drug_name TEXT, year INTEGER, amount INTEGER);"} 76 | {"sql_prompt": "How many patients with mental health issues are being treated by cultural competency trained providers?", "sql": "SELECT COUNT(MentalHealthPatients.PatientID) FROM MentalHealthPatients INNER JOIN Providers ON MentalHealthPatients.PatientID = Providers.ProviderID WHERE Providers.CulturalCompetencyTraining IS NOT NULL;", "sql_context": "CREATE TABLE 
MentalHealthPatients (PatientID INT, Age INT, MentalHealthIssue VARCHAR(20)); CREATE TABLE Providers (ProviderID INT, ProviderName VARCHAR(20), CulturalCompetencyTraining DATE); INSERT INTO MentalHealthPatients (PatientID, Age, MentalHealthIssue) VALUES (1, 30, 'Anxiety'); INSERT INTO Providers (ProviderID, ProviderName, CulturalCompetencyTraining) VALUES (1, 'Dr. Smith', '2022-01-01');"} 77 | {"sql_prompt": "What is the total revenue for each cultural event category, sorted by revenue in descending order?", "sql": "SELECT category, SUM(price * quantity) as total_revenue FROM tickets GROUP BY category ORDER BY total_revenue DESC;", "sql_context": "CREATE TABLE tickets (id INT, event TEXT, category TEXT, price DECIMAL(5,2), quantity INT); INSERT INTO tickets (id, event, category, price, quantity) VALUES (1, 'Concert', 'music', 50.00, 100), (2, 'Jazz Festival', 'music', 35.00, 200), (3, 'Theatre Play', 'theatre', 75.00, 150);"} 78 | {"sql_prompt": "How many renewable energy facilities are located in the Asia-Pacific region, and what is their total capacity in MW?", "sql": "SELECT region, SUM(capacity) as total_capacity FROM renewable_facilities WHERE region = 'Asia-Pacific' GROUP BY region;", "sql_context": "CREATE TABLE renewable_facilities (region VARCHAR(50), capacity NUMERIC, technology VARCHAR(50)); INSERT INTO renewable_facilities (region, capacity, technology) VALUES ('Asia-Pacific', 500, 'Solar'), ('Asia-Pacific', 600, 'Wind'), ('Europe', 400, 'Hydro'), ('Africa', 300, 'Geothermal');"} 79 | {"sql_prompt": "What is the average salary of workers in the 'manufacturing' department across all companies?", "sql": "SELECT AVG(worker_salaries.salary) FROM worker_salaries INNER JOIN companies ON worker_salaries.company_id = companies.company_id WHERE companies.department = 'manufacturing';", "sql_context": "CREATE TABLE companies (company_id INT, department VARCHAR(20));CREATE TABLE worker_salaries (worker_id INT, company_id INT, salary INT, department VARCHAR(20));"} 
80 | {"sql_prompt": "What is the total number of policy impact records for 'City N' and 'City O'?", "sql": "SELECT COUNT(*) FROM policy_impact WHERE (city = 'City N' OR city = 'City O')", "sql_context": "CREATE TABLE policy_impact (city VARCHAR(255), policy_id INT, impact TEXT); INSERT INTO policy_impact"} 81 | {"sql_prompt": "Find the most recent mission for each spacecraft that had a female commander.", "sql": "SELECT spacecraft_id, MAX(mission_date) as most_recent_mission FROM Spacecraft_Missions WHERE commander_gender = 'Female' GROUP BY spacecraft_id", "sql_context": "CREATE TABLE Spacecraft_Missions (id INT, spacecraft_id INT, mission_name VARCHAR(100), mission_date DATE, commander_gender VARCHAR(10)); INSERT INTO Spacecraft_Missions (id, spacecraft_id, mission_name, mission_date, commander_gender) VALUES (1, 1, 'Apollo 11', '1969-07-16', 'Male');"} 82 | {"sql_prompt": "What is the average age of football players in the 'players' table?", "sql": "SELECT AVG(age) FROM players WHERE position = 'Goalkeeper' OR position = 'Defender';", "sql_context": "CREATE TABLE players (player_id INT, name VARCHAR(50), position VARCHAR(50), team VARCHAR(50), age INT); INSERT INTO players (player_id, name, position, team, age) VALUES (1, 'John Doe', 'Goalkeeper', 'Arsenal', 32), (2, 'Jane Smith', 'Defender', 'Manchester United', 28);"} 83 | {"sql_prompt": "What is the total installed capacity of renewable energy projects, broken down by country and project type?", "sql": "SELECT country, project_type, SUM(installed_capacity) FROM renewable_energy GROUP BY country, project_type;", "sql_context": "CREATE TABLE renewable_energy (country VARCHAR(50), project_type VARCHAR(50), installed_capacity INT); INSERT INTO renewable_energy (country, project_type, installed_capacity) VALUES ('USA', 'Wind', 3000), ('USA', 'Solar', 5000), ('Mexico', 'Wind', 2000), ('Mexico', 'Solar', 4000);"} 84 | {"sql_prompt": "What is the total number of workers in the manufacturing industry across all 
regions?", "sql": "SELECT SUM(number_of_workers) FROM manufacturing_industry;", "sql_context": "CREATE TABLE manufacturing_industry (id INT, region VARCHAR(255), number_of_workers INT); INSERT INTO manufacturing_industry (id, region, number_of_workers) VALUES (1, 'North', 5000), (2, 'South', 7000), (3, 'East', 6000), (4, 'West', 8000);"} 85 | {"sql_prompt": "What is the percentage of open data initiatives in the United States with public participation?", "sql": "SELECT (COUNT(*) * 100.0 / (SELECT COUNT(*) FROM open_data_initiatives WHERE initiative_country = 'United States')) AS percentage FROM open_data_initiatives WHERE initiative_country = 'United States' AND has_public_participation = true;", "sql_context": "CREATE TABLE open_data_initiatives (initiative_id INT, initiative_date DATE, initiative_country VARCHAR(50), has_public_participation BOOLEAN); INSERT INTO open_data_initiatives (initiative_id, initiative_date, initiative_country, has_public_participation) VALUES (1, '2021-01-01', 'United States', true);"} 86 | {"sql_prompt": "What is the average CO2 emission for virtual tours?", "sql": "SELECT AVG(virtual_tour.co2_eq) FROM virtual_tour;", "sql_context": "CREATE TABLE virtual_tour (id INT, co2_eq DECIMAL(10, 2), tour_id INT); INSERT INTO virtual_tour (id, co2_eq, tour_id) VALUES (1, 2.5, 1);"} 87 | {"sql_prompt": "What is the maximum environmental impact score for mining operations in Year 2000?", "sql": "SELECT MAX(score) FROM environmental_impact WHERE year = 2000 AND mining_operation LIKE '%Mining%';", "sql_context": "CREATE TABLE environmental_impact (id INT, mining_operation TEXT, year INT, score FLOAT); INSERT INTO environmental_impact (id, mining_operation, year, score) VALUES (1, 'Operation A', 1999, 45.6); INSERT INTO environmental_impact (id, mining_operation, year, score) VALUES (2, 'Operation B', 2000, 67.8);"} 88 | {"sql_prompt": "What is the total donation amount per volunteer for the year 2022?", "sql": "SELECT volunteer, SUM(donation) FROM 
donations WHERE donation_date BETWEEN '2022-01-01' AND '2022-12-31' GROUP BY volunteer;", "sql_context": "CREATE TABLE donations (id INT, volunteer TEXT, donation FLOAT, donation_date DATE); INSERT INTO donations (id, volunteer, donation, donation_date) VALUES (1, 'Jamal Williams', 50.00, '2022-01-01'), (2, 'Sophia Garcia', 100.00, '2022-02-01'), (3, 'Liam Brown', 25.00, '2022-01-15'), (4, 'Olivia Johnson', 75.00, '2022-03-05'), (5, 'William Davis', 125.00, '2022-03-20');"} 89 | {"sql_prompt": "What are the names and total transactions of all Shariah-compliant banks in the Asia region?", "sql": "SELECT bank_name, SUM(total_transactions) FROM shariah_banks WHERE region = 'Asia' GROUP BY bank_name;", "sql_context": "CREATE TABLE shariah_banks (bank_name TEXT, region TEXT, total_transactions INTEGER); INSERT INTO shariah_banks (bank_name, region, total_transactions) VALUES ('Al Rajhi Bank', 'Asia', 50000), ('Bank Islam', 'Asia', 45000), ('Maybank Islamic', 'Asia', 40000);"} 90 | {"sql_prompt": "Insert a new employee record for 'Bruce Wayne' from 'UK' into the 'employees' table.", "sql": "INSERT INTO employees (id, name, country) VALUES (4, 'Bruce Wayne', 'UK');", "sql_context": "CREATE TABLE employees (id INT, name VARCHAR(255), country VARCHAR(255)); INSERT INTO employees (id, name, country) VALUES (1, 'John Doe', 'USA'); INSERT INTO employees (id, name, country) VALUES (2, 'Jane Smith', 'Canada');"} 91 | {"sql_prompt": "Identify cities with a population under 100,000 and a recycling rate below 10%.", "sql": "SELECT name FROM cities_2 WHERE population < 100000 AND recycling_rate < 10;", "sql_context": "CREATE TABLE cities_2 (name VARCHAR(255), state VARCHAR(255), population DECIMAL(10,2), recycling_rate DECIMAL(5,2)); INSERT INTO cities_2 (name, state, population, recycling_rate) VALUES ('City1', 'California', 50000, 15), ('City2', 'Texas', 75000, 8), ('City3', 'Florida', 90000, 25);"} 92 | {"sql_prompt": "Which space agencies have launched satellites with the most 
debris?", "sql": "SELECT agency, SUM(num_debris) as total_debris FROM satellites GROUP BY agency ORDER BY total_debris DESC;", "sql_context": "CREATE TABLE satellites (agency VARCHAR(50), num_debris INT); INSERT INTO satellites (agency, num_debris) VALUES ('NASA', 3000), ('ESA', 1000), ('Roscosmos', 2000);"} 93 | {"sql_prompt": "Count the number of posts that have more than 100 likes in the 'social_media' table.", "sql": "SELECT COUNT(*) FROM social_media WHERE likes > 100;", "sql_context": "CREATE TABLE social_media (user_id INT, post_id INT, post_date DATE, likes INT);"} 94 | {"sql_prompt": "Which directors have directed more than 10 movies, and what is the average rating of their films?", "sql": "SELECT d.name, AVG(m.rating) FROM movies m JOIN directors d ON m.director = d.name GROUP BY d.name HAVING COUNT(m.id) > 10;", "sql_context": "CREATE TABLE movies (id INT, title VARCHAR(255), release_year INT, rating FLOAT, genre VARCHAR(255), director VARCHAR(255)); CREATE TABLE directors (id INT, name VARCHAR(255), age INT, gender VARCHAR(10));"} 95 | {"sql_prompt": "What is the minimum sales revenue for a specific drug in a certain year?", "sql": "SELECT MIN(sales.revenue) FROM sales JOIN drugs ON sales.drug_id = drugs.id WHERE drugs.name = 'DrugA' AND sales.year = 2020;", "sql_context": "CREATE TABLE drugs (id INT, name VARCHAR(255)); INSERT INTO drugs (id, name) VALUES (1, 'DrugA'), (2, 'DrugB'); CREATE TABLE sales (id INT, drug_id INT, year INT, revenue INT);"} 96 | {"sql_prompt": "What is the average playtime for female players in RPG games?", "sql": "SELECT AVG(Playtime) as AvgPlaytime FROM PlayerActivity INNER JOIN PlayerInfo ON PlayerActivity.PlayerID = PlayerInfo.PlayerID WHERE Gender = 'Female' AND GameType = 'RPG';", "sql_context": "CREATE TABLE PlayerInfo (PlayerID INT, Gender VARCHAR(50), GameType VARCHAR(50)); INSERT INTO PlayerInfo (PlayerID, Gender, GameType) VALUES (1, 'Female', 'RPG'), (2, 'Male', 'FPS'), (3, 'Female', 'RPG'), (4, 'Non-binary', 
'Simulation'); CREATE TABLE PlayerActivity (PlayerID INT, GameID INT, Playtime FLOAT); INSERT INTO PlayerActivity (PlayerID, GameID, Playtime) VALUES (1, 1, 50.5), (2, 2, 60.2), (3, 1, 75.1), (4, 3, 80.5);"} 97 | {"sql_prompt": "What is the total landfill capacity in cubic meters for each region in Africa in 2021?", "sql": "SELECT region, SUM(capacity) FROM landfill_capacity WHERE year = 2021 GROUP BY region;", "sql_context": "CREATE TABLE landfill_capacity (region VARCHAR(50), year INT, capacity FLOAT); INSERT INTO landfill_capacity (region, year, capacity) VALUES ('Northern Africa', 2021, 1200000.0), ('Western Africa', 2021, 1500000.0), ('Eastern Africa', 2021, 2000000.0), ('Central Africa', 2021, 1800000.0), ('Southern Africa', 2021, 1300000.0);"} 98 | {"sql_prompt": "What is the minimum cultural competency score of community health workers in each city in India?", "sql": "SELECT City, MIN(Score) as MinScore FROM CommunityHealthWorkers GROUP BY City;", "sql_context": "CREATE TABLE CommunityHealthWorkers (WorkerID INT, City VARCHAR(255), Score INT); INSERT INTO CommunityHealthWorkers (WorkerID, City, Score) VALUES (1, 'Mumbai', 80), (2, 'Delhi', 85), (3, 'Bangalore', 90), (4, 'Hyderabad', 70);"} 99 | {"sql_prompt": "What are the top 3 countries with the highest energy efficiency ratings?", "sql": "SELECT country, rating FROM energy_efficiency ORDER BY rating DESC LIMIT 3;", "sql_context": "CREATE TABLE energy_efficiency (country VARCHAR(50), rating FLOAT); INSERT INTO energy_efficiency (country, rating) VALUES ('SE', 3.4), ('NO', 3.3), ('FI', 3.1);"} 100 | {"sql_prompt": "What is the maximum budget for any rural infrastructure project in Colombia?", "sql": "SELECT MAX(budget) FROM rural_infrastructure WHERE location = 'Colombia';", "sql_context": "CREATE TABLE rural_infrastructure (id INT, name TEXT, location TEXT, budget FLOAT); INSERT INTO rural_infrastructure (id, name, location, budget) VALUES (1, 'Solar Power Plant', 'Colombia', 800000.00), (2, 'Irrigation 
System', 'Colombia', 500000.00), (3, 'Healthcare Center', 'Ecuador', 600000.00);"} 101 | -------------------------------------------------------------------------------- /eval_grpo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import time 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from collections import Counter 9 | from openai import OpenAI 10 | from unsloth import FastLanguageModel 11 | from peft import PeftModel 12 | 13 | EVAL_FILE = "cleaned_eval_queries.jsonl" 14 | NUM_SAMPLES = 50 15 | OPENAI_API_KEY = '' 16 | MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct" 17 | FINETUNED_PATH = "outputs/sql_grpo/final_lora" 18 | MAX_SEQ_LENGTH = 1024 19 | RESULT_FILE = "evaluation_results.json" 20 | CONCURRENT_REQUESTS = 5 21 | 22 | SYSTEM_PROMPT = """ 23 | You are an AI assistant that converts natural language questions into SQL queries compatible with PostgreSQL syntax. 24 | Given a database schema and a question, generate the correct PostgreSQL query. 25 | 26 | Here's an example of how you should respond: 27 | 28 | 29 | This database has a users table with columns for id, name, and age. 30 | The question asks for all users over 30, so I need to query the users table with a WHERE condition. 31 | 32 | 33 | SELECT * FROM users WHERE age > 30; 34 | 35 | 36 | Respond ONLY in the format above, including the and tags. 
37 | """ 38 | 39 | 40 | def extract_sql(text): 41 | match = re.search(r"(.*?)", text, re.IGNORECASE | re.DOTALL) 42 | return match.group(1).strip() if match else "" 43 | 44 | 45 | def extract_reasoning(text): 46 | match = re.search(r"(.*?)", 47 | text, re.IGNORECASE | re.DOTALL) 48 | return match.group(1).strip() if match else "" 49 | 50 | 51 | def parse_gpt_evaluation(evaluation_text): 52 | result = { 53 | "SQL_SCORE": 0, 54 | "REASONING_SCORE": 0, 55 | "FORMAT_SCORE": 0, 56 | "EDUCATIONAL_SCORE": 0, 57 | "OVERALL_SCORE": 0, 58 | "EXPLANATION": "", 59 | "ERROR_TYPE": "unknown" 60 | } 61 | 62 | sql_score = re.search(r"SQL_SCORE:\s*(\d+)", evaluation_text) 63 | reasoning_score = re.search(r"REASONING_SCORE:\s*(\d+)", evaluation_text) 64 | format_score = re.search(r"FORMAT_SCORE:\s*(\d+)", evaluation_text) 65 | educational_score = re.search( 66 | r"EDUCATIONAL_SCORE:\s*(\d+)", evaluation_text) 67 | overall_score = re.search(r"OVERALL_SCORE:\s*(\d+\.?\d*)", evaluation_text) 68 | error_type = re.search(r"ERROR_TYPE:\s*([^\n]+)", evaluation_text) 69 | 70 | explanation_match = re.search( 71 | r"EXPLANATION:\s*(.*?)(?=ERROR_TYPE:|$)", evaluation_text, re.DOTALL) 72 | 73 | if sql_score: 74 | result["SQL_SCORE"] = int(sql_score.group(1)) 75 | if reasoning_score: 76 | result["REASONING_SCORE"] = int(reasoning_score.group(1)) 77 | if format_score: 78 | result["FORMAT_SCORE"] = int(format_score.group(1)) 79 | if educational_score: 80 | result["EDUCATIONAL_SCORE"] = int(educational_score.group(1)) 81 | if overall_score: 82 | result["OVERALL_SCORE"] = float(overall_score.group(1)) 83 | if explanation_match: 84 | result["EXPLANATION"] = explanation_match.group(1).strip() 85 | if error_type: 86 | result["ERROR_TYPE"] = error_type.group(1).strip() 87 | 88 | return result 89 | 90 | 91 | def load_model(): 92 | print("Loading model...") 93 | try: 94 | model, tokenizer = FastLanguageModel.from_pretrained( 95 | model_name=MODEL_NAME, 96 | max_seq_length=MAX_SEQ_LENGTH, 97 | 
load_in_4bit=True, 98 | device_map="auto", 99 | ) 100 | 101 | model = PeftModel.from_pretrained(model, FINETUNED_PATH) 102 | model.eval() 103 | 104 | print(f"Model loaded successfully from {FINETUNED_PATH}") 105 | return model, tokenizer 106 | except Exception as e: 107 | print(f"Failed to load model: {e}") 108 | raise 109 | 110 | 111 | def generate_response(model, tokenizer, prompt): 112 | try: 113 | prompt_chat = [ 114 | {"role": "system", "content": SYSTEM_PROMPT}, 115 | {"role": "user", "content": prompt} 116 | ] 117 | text = tokenizer.apply_chat_template( 118 | prompt_chat, tokenize=False, add_generation_prompt=True) 119 | 120 | inputs = tokenizer(text, return_tensors="pt", truncation=True, 121 | max_length=MAX_SEQ_LENGTH).to(model.device) 122 | 123 | with torch.no_grad(): 124 | output_ids = model.generate( 125 | **inputs, 126 | max_new_tokens=500, 127 | temperature=0.1, 128 | top_p=0.95, 129 | do_sample=True, 130 | pad_token_id=tokenizer.eos_token_id 131 | ) 132 | 133 | input_length = inputs['input_ids'].shape[1] 134 | generated_tokens = output_ids[0][input_length:] 135 | output_text = tokenizer.decode( 136 | generated_tokens, skip_special_tokens=True) 137 | 138 | return output_text 139 | except Exception as e: 140 | print(f"Error generating response: {e}") 141 | return "[Error generating response]" 142 | 143 | 144 | def evaluate_with_gpt4o_mini(samples, client): 145 | results = [] 146 | 147 | for sample in tqdm(samples, desc="Evaluating with GPT-4o-mini"): 148 | eval_prompt = f""" 149 | As an SQL expert, evaluate this text-to-SQL conversion. Score each dimension from 1-5 (1=poor, 5=excellent). 150 | 151 | DATABASE SCHEMA: 152 | {sample['sql_context']} 153 | 154 | QUESTION: 155 | {sample['sql_prompt']} 156 | 157 | GOLD SQL (CORRECT): 158 | {sample['sql']} 159 | 160 | MODEL OUTPUT: 161 | {sample['model_output']} 162 | 163 | Provide scores in this exact format: 164 | SQL_SCORE: [1-5] - Does the SQL work and produce correct results? 
def format_results_summary(evaluation_results):
    """Build a human-readable text summary of the judged evaluation results.

    Args:
        evaluation_results: List of result dicts. Entries flagged with
            ``"evaluation_failed": True`` are excluded from the statistics.
            Each remaining entry is expected to carry an ``"evaluation"``
            dict with the five ``*_SCORE`` keys and ``"ERROR_TYPE"``.

    Returns:
        A multi-line summary (mean ± std per metric, score distribution,
        error-type counts, totals), or "No valid evaluation results." when
        no entry was evaluated successfully.

    A score that is missing or not numeric is skipped for that metric instead
    of raising, so one incomplete judge response cannot abort the summary.
    A metric with no usable scores is reported as "n/a".
    """
    from numbers import Real

    metrics = (
        "SQL_SCORE",
        "REASONING_SCORE",
        "FORMAT_SCORE",
        "EDUCATIONAL_SCORE",
        "OVERALL_SCORE",
    )

    valid_results = [
        r for r in evaluation_results if not r.get("evaluation_failed", False)
    ]
    if not valid_results:
        return "No valid evaluation results."

    evaluations = [(r.get("evaluation") or {}) for r in valid_results]

    def _numeric(value):
        # bool is a Real subclass but is never a meaningful score.
        return isinstance(value, Real) and not isinstance(value, bool)

    scores = {
        metric: [ev.get(metric) for ev in evaluations if _numeric(ev.get(metric))]
        for metric in metrics
    }
    error_types = Counter(ev.get("ERROR_TYPE") or "unknown" for ev in evaluations)
    valid_count = len(valid_results)

    # Collect lines and join once instead of repeated string concatenation.
    lines = ["=== EVALUATION SUMMARY ===", "", "AVERAGE SCORES:"]
    for metric, values in scores.items():
        if values:
            lines.append(
                f" {metric}: {np.mean(values):.2f} (±{np.std(values):.2f})")
        else:
            lines.append(f" {metric}: n/a")

    lines += ["", "SCORE DISTRIBUTION:"]
    for metric, values in scores.items():
        distribution = " | ".join(
            f"{score}={count}" for score, count in sorted(Counter(values).items()))
        lines.append(f" {metric}: {distribution}")

    lines += ["", "ERROR TYPES:"]
    for error_type, count in error_types.most_common():
        lines.append(
            f" {error_type}: {count} ({count/valid_count*100:.1f}%)")

    lines += [
        "",
        f"Total samples evaluated: {len(evaluation_results)}",
        f"Valid evaluations: {valid_count}",
    ]
    return "\n".join(lines) + "\n"
Please set it before running this script.") 260 | return 261 | 262 | client = OpenAI(api_key=OPENAI_API_KEY) 263 | 264 | model, tokenizer = load_model() 265 | 266 | print(f"Loading evaluation data from {EVAL_FILE}...") 267 | try: 268 | with open(EVAL_FILE, 'r') as f: 269 | eval_data = [json.loads(line) for line in f] 270 | 271 | eval_subset = eval_data[:NUM_SAMPLES] 272 | print(f"Loaded {len(eval_subset)} samples for evaluation") 273 | except Exception as e: 274 | print(f"Error loading evaluation data: {e}") 275 | return 276 | 277 | print("Generating model outputs...") 278 | for sample in tqdm(eval_subset, desc="Generating responses"): 279 | prompt = f"Database schema:\n{sample['sql_context']}\n\nQuestion: {sample['sql_prompt']}" 280 | sample["model_output"] = generate_response(model, tokenizer, prompt) 281 | sample["extracted_sql"] = extract_sql(sample["model_output"]) 282 | sample["extracted_reasoning"] = extract_reasoning( 283 | sample["model_output"]) 284 | 285 | print("Evaluating with GPT-4o-mini...") 286 | evaluation_results = evaluate_with_gpt4o_mini(eval_subset, client) 287 | 288 | print(f"Saving results to {RESULT_FILE}...") 289 | with open(RESULT_FILE, 'w') as f: 290 | json.dump(evaluation_results, f, indent=2) 291 | 292 | csv_data = [] 293 | for result in evaluation_results: 294 | if result.get("evaluation_failed", False): 295 | continue 296 | 297 | csv_data.append({ 298 | "id": result["sample_id"], 299 | "question": result["question"], 300 | "sql_score": result["evaluation"]["SQL_SCORE"], 301 | "reasoning_score": result["evaluation"]["REASONING_SCORE"], 302 | "format_score": result["evaluation"]["FORMAT_SCORE"], 303 | "educational_score": result["evaluation"]["EDUCATIONAL_SCORE"], 304 | "overall_score": result["evaluation"]["OVERALL_SCORE"], 305 | "error_type": result["evaluation"]["ERROR_TYPE"] 306 | }) 307 | 308 | pd.DataFrame(csv_data).to_csv("evaluation_summary.csv", index=False) 309 | 310 | summary = format_results_summary(evaluation_results) 311 | 
def print_gpu_memory(step="", device=0):
    """Print current, reserved and peak CUDA memory usage in GB.

    Args:
        step: Free-form label identifying the point in the script.
        device: CUDA device index to report on. Defaults to 0, the device the
            function previously always reported.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1e9
    allocated = torch.cuda.memory_allocated(device) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_gb
    max_allocated = torch.cuda.max_memory_allocated(device) / bytes_per_gb

    print(f"\n--- GPU Memory at {step} ---")
    print(f"Memory allocated: {allocated:.2f} GB")
    print(f"Memory reserved: {reserved:.2f} GB")
    print(f"Max memory allocated: {max_allocated:.2f} GB")
MAX_SEQ_LENGTH = 1024 51 | LORA_RANK = 32 52 | BATCH_SIZE = 4 53 | GRAD_ACCUMULATION = 2 54 | NUM_GENERATIONS = 8 55 | MAX_STEPS = 250 56 | USE_WANDB = True 57 | 58 | DATASET_NAME = "gretelai/synthetic_text_to_sql" 59 | NUM_EXAMPLES = 300 60 | DATASET_SPLIT = "train" 61 | 62 | REWARD_WEIGHTS = { 63 | "format": 1.0, 64 | "sql_correctness": 1.2, 65 | "complexity": 0.6, 66 | "reasoning": 0.7, 67 | } 68 | SYNTAX_PENALTY = -0.1 * REWARD_WEIGHTS["sql_correctness"] 69 | 70 | if USE_WANDB: 71 | try: 72 | import wandb 73 | except ImportError: 74 | print("Wandb not installed. Disabling W&B logging.") 75 | USE_WANDB = False 76 | 77 | print("\n=== Starting SQL-to-Text Training Script ===") 78 | print(f"Model: {MODEL_NAME}") 79 | print(f"Max Sequence Length: {MAX_SEQ_LENGTH}") 80 | print(f"LoRA Rank: {LORA_RANK}") 81 | print(f"Batch Size: {BATCH_SIZE}") 82 | print(f"Gradient Accumulation: {GRAD_ACCUMULATION}") 83 | print(f"Number of Generations: {NUM_GENERATIONS}") 84 | print(f"Max Steps: {MAX_STEPS}") 85 | print_gpu_memory("start") 86 | 87 | SYSTEM_PROMPT = """ 88 | You are an AI assistant that converts natural language questions into SQL queries compatible with PostgreSQL syntax. 89 | Given a database schema and a question, generate the correct PostgreSQL query. 90 | 91 | Think about the problem and provide your working out. 92 | Place it between and . 93 | Then, provide your solution between and . 94 | 95 | Here's an example of how you should respond: 96 | 97 | 98 | This database has a users table with columns for id, name, and age. 99 | The question asks for all users over 30, so I need to query the users table with a WHERE condition. 100 | 101 | 102 | SELECT * FROM users WHERE age > 30; 103 | 104 | 105 | Respond ONLY in the format above, including the and tags. 
def extract_sql(text: str, tag: str = "sql") -> str:
    """Extract the SQL statement from a model response.

    Primary path: return the content enclosed by the opening/closing tag pair
    named by *tag* (case-insensitive), with a leading and a trailing ``--``
    comment line removed.

    Fallback path (no complete tag pair): start at the earliest SQL keyword
    found in the text, cut at a stray closing tag if present, and keep
    everything up to and including the first semicolon.

    Args:
        text: Raw model output.
        tag: Name of the enclosing tag, without angle brackets.
            NOTE(review): the default "sql" is inferred from the prompt's
            "provide your solution between ... and ..." instruction; confirm
            against the system prompt's actual tag.

    Returns:
        The extracted SQL, or "" when nothing SQL-like is found.
    """
    if not text:
        return ""

    tag_name = re.escape(tag)
    closing_tag = rf"</{tag_name}>"

    # The pattern must include the tag delimiters: a bare "(.*?)" matches the
    # empty string at position 0 and makes every extraction return "".
    tagged = re.search(
        rf"<{tag_name}>(.*?){closing_tag}", text, re.IGNORECASE | re.DOTALL)
    if tagged:
        sql = tagged.group(1).strip()
        sql = re.sub(r"^\s*--.*?\n", "", sql)   # leading comment line
        sql = re.sub(r"\n--.*?\s*$", "", sql)   # trailing comment line
        return sql.strip()

    # Leftmost match of the alternation == earliest keyword occurrence.
    # Searching the original text avoids index drift from text.upper().
    keyword = re.search(
        r"(?:SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP|TRUNCATE|"
        r"GRANT|REVOKE|MERGE|EXEC|WITH) ",
        text,
        re.IGNORECASE,
    )
    if keyword is None:
        return ""

    candidate = text[keyword.start():]
    # Drop anything after an unmatched closing tag.
    candidate = re.split(closing_tag, candidate, maxsplit=1, flags=re.IGNORECASE)[0]
    if ";" in candidate:
        candidate = candidate.split(";", 1)[0] + ";"
    return candidate.strip()
# --- Load the training file and turn each record into a GRPO prompt ---------
# Any failure here is fatal: the script logs the error and exits with status 1.
try:
    logger.info("=== Loading Dataset ===")
    train_df = pd.read_json(TRAIN_DATA_FILE, lines=True)

    # Down-sample to NUM_EXAMPLES rows with a fixed seed. The sampled frame
    # must replace train_df itself; previously the sample was stored in a
    # variable that was overwritten on the next line, so the cap was ignored.
    if NUM_EXAMPLES and len(train_df) > NUM_EXAMPLES:
        train_df = train_df.sample(
            n=NUM_EXAMPLES, random_state=42).reset_index(drop=True)

    dataset = train_df.to_dict(orient='records')

    # The data comes from the local JSONL file, not from the hub dataset.
    logger.info(f"Loaded {len(dataset)} examples from {TRAIN_DATA_FILE}")
    print_gpu_memory("after dataset load")

    train_data = []
    logger.info("=== Preparing Dataset ===")

    for i, example in enumerate(tqdm(dataset, desc="Processing examples")):
        sql_prompt = example.get("sql_prompt", "")
        sql_context = example.get("sql_context", "")
        gold_sql = example.get("sql", "")

        # pandas fills absent fields with NaN, which is truthy; require real,
        # non-empty strings so such rows are skipped rather than rendered
        # as the text "nan" in the prompt.
        required_fields = (sql_prompt, sql_context, gold_sql)
        if not all(isinstance(v, str) and v for v in required_fields):
            logger.warning(
                f"Skipping example {i} due to missing data (prompt, context, or gold SQL).")
            continue

        filtered_context = filter_sql_context_for_training(sql_context)
        schema_for_prompt = extract_schema_from_context(filtered_context)

        prompt_chat = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': f"{schema_for_prompt}\n\nQuestion: {sql_prompt}"}
        ]
        prompt_string = tokenizer.apply_chat_template(
            prompt_chat, tokenize=False, add_generation_prompt=True
        )

        # 'references' carries the unfiltered context alongside the gold SQL.
        train_data.append({
            'prompt': prompt_string,
            'references': [{
                'gold_sql': gold_sql,
                'sql_context': sql_context,
                'sql_prompt': sql_prompt
            }],
        })

    logger.info(f"Prepared {len(train_data)} training examples")
    print_gpu_memory("after data preparation")

    if not train_data:
        logger.error(
            "No valid training data could be prepared. Check dataset format and content.")
        exit(1)

except Exception as e:
    logger.error(f"Error in data preparation: {e}", exc_info=True)
    exit(1)
per_device_train_batch_size=BATCH_SIZE, 315 | gradient_accumulation_steps=GRAD_ACCUMULATION, 316 | optim="adamw_8bit", 317 | max_steps=MAX_STEPS, 318 | warmup_ratio=0.1, 319 | lr_scheduler_type="cosine", 320 | logging_steps=5, 321 | save_steps=50, 322 | save_total_limit=2, 323 | save_strategy="steps", 324 | bf16=is_bfloat16_supported(), 325 | fp16=not is_bfloat16_supported(), 326 | gradient_checkpointing=True, 327 | gradient_checkpointing_kwargs={"use_reentrant": False}, 328 | max_prompt_length=effective_max_prompt_length, 329 | max_completion_length=effective_max_completion_length, 330 | num_generations=NUM_GENERATIONS, 331 | beta=0.1, 332 | use_vllm=True, 333 | report_to="wandb" if USE_WANDB else "none", 334 | remove_unused_columns=False, 335 | seed=42, 336 | dataloader_num_workers=2, 337 | max_grad_norm=1.0, 338 | ) 339 | 340 | logger.info("Initializing GRPOTrainer with improved reward functions...") 341 | trainer = GRPOTrainer( 342 | model=model, 343 | beta=training_args.beta, 344 | processing_class=tokenizer, 345 | args=training_args, 346 | train_dataset=train_data, 347 | reward_funcs=[ 348 | soft_format_reward_func, 349 | execute_query_reward_func, 350 | complexity_reward, 351 | reasoning_quality_reward, 352 | ], 353 | callbacks=[RewardLoggerCallback()] if not USE_WANDB else None, 354 | ) 355 | 356 | torch.cuda.empty_cache() 357 | print_gpu_memory("before training starts") 358 | 359 | logger.info("Starting GRPO training...") 360 | try: 361 | trainer.train() 362 | except Exception as e: 363 | logger.error(f"Training failed: {e}", exc_info=True) 364 | raise 365 | 366 | final_save_path = f"{OUTPUT_DIR}/final_lora" 367 | logger.info(f"Saving final LoRA adapters to {final_save_path}...") 368 | trainer.save_model(final_save_path) 369 | tokenizer.save_pretrained(final_save_path) 370 | logger.info("Model and tokenizer saved.") 371 | 372 | if USE_WANDB and wandb.run: 373 | try: 374 | logger.info("Logging final model artifacts to WandB...") 375 | 
wandb.save(f"{final_save_path}/*") 376 | wandb.finish() 377 | logger.info("WandB run finished.") 378 | except Exception as e: 379 | logger.error( 380 | f"Failed to finish WandB run or save artifacts: {e}", exc_info=True) 381 | 382 | print_gpu_memory("after training") 383 | return model, tokenizer 384 | 385 | 386 | def test_model(model, tokenizer): 387 | logger.info("\n=== Testing trained model with a sample query ===") 388 | 389 | EVAL_DATA_FILE = "cleaned_eval_queries.jsonl" 390 | 391 | try: 392 | eval_df = pd.read_json(EVAL_DATA_FILE, lines=True) 393 | if eval_df.empty: 394 | raise ValueError( 395 | f"Evaluation dataset '{EVAL_DATA_FILE}' is empty.") 396 | 397 | eval_sample = eval_df.sample(n=1, random_state=123).iloc[0] 398 | 399 | sql_prompt = eval_sample.get("sql_prompt", "N/A") 400 | sql_context = eval_sample.get("sql_context", "") 401 | gold_sql = eval_sample.get("sql", "N/A") 402 | 403 | except (ValueError, FileNotFoundError) as e: 404 | logger.warning( 405 | f"Could not load eval sample: {e}. Using a default sample.") 406 | sql_prompt = "List the names of departments with more than 10 employees." 
407 | sql_context = """ 408 | CREATE TABLE departments (department_id INT PRIMARY KEY, name TEXT); 409 | CREATE TABLE employees (employee_id INT PRIMARY KEY, name TEXT, department_id INT, FOREIGN KEY (department_id) REFERENCES departments(department_id)); 410 | """ 411 | gold_sql = """ 412 | SELECT T1.name FROM departments AS T1 JOIN employees AS T2 ON T1.department_id = T2.department_id GROUP BY T1.department_id HAVING count(*) > 10 413 | """ 414 | 415 | schema_for_prompt = extract_schema_from_context(sql_context) 416 | test_prompt_chat = [ 417 | {"role": "system", "content": SYSTEM_PROMPT}, 418 | {"role": "user", "content": f"{schema_for_prompt}\n\nQuestion: {sql_prompt}"} 419 | ] 420 | text = tokenizer.apply_chat_template( 421 | test_prompt_chat, tokenize=False, add_generation_prompt=True 422 | ) 423 | 424 | if torch.cuda.is_available(): 425 | model.cuda() 426 | model.eval() 427 | 428 | torch.cuda.empty_cache() 429 | print_gpu_memory("before test generation") 430 | 431 | logger.info("Generating test response...") 432 | output_text = "[Generation Failed]" 433 | try: 434 | with torch.no_grad(): 435 | inputs = tokenizer(text, return_tensors="pt", truncation=True, 436 | max_length=MAX_SEQ_LENGTH).to(model.device) 437 | 438 | output_ids = model.generate( 439 | **inputs, 440 | max_new_tokens=300, 441 | temperature=0.2, 442 | top_p=0.95, 443 | do_sample=True, 444 | pad_token_id=tokenizer.eos_token_id 445 | ) 446 | 447 | input_length = inputs['input_ids'].shape[1] 448 | generated_tokens = output_ids[0][input_length:] 449 | output_text = tokenizer.decode( 450 | generated_tokens, skip_special_tokens=True) 451 | 452 | except Exception as e: 453 | logger.error(f"Error during test generation: {e}", exc_info=True) 454 | 455 | print("\n--- Test Results ---") 456 | print(f"Question: {sql_prompt}") 457 | print("-" * 40) 458 | print(f"Gold SQL:\n{gold_sql}") 459 | print("-" * 40) 460 | generated_sql = extract_sql(output_text) 461 | print( 462 | f"Generated SQL:\n{generated_sql if 
generated_sql else '[No SQL Extracted]'}") 463 | print("-" * 40) 464 | print(f"Full Generated Output:\n{output_text}") 465 | print("-" * 40) 466 | 467 | print_gpu_memory("after test") 468 | 469 | 470 | if __name__ == "__main__": 471 | trained_model, trained_tokenizer = train_model() 472 | 473 | if trained_model and trained_tokenizer: 474 | test_model(trained_model, trained_tokenizer) 475 | else: 476 | logger.error( 477 | "Training did not return a valid model or tokenizer. Skipping test.") 478 | 479 | logger.info("\nGRPO Training Script Completed.") 480 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | unsloth 2 | vllm 3 | sqlparse 4 | wandb 5 | sqlglot 6 | datasets 7 | tqdm 8 | openai 9 | pandas -------------------------------------------------------------------------------- /sql_reward_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import tempfile 4 | import sqlite3 5 | import logging 6 | import sqlparse 7 | import torch 8 | import copy 9 | from sqlglot import parse, transpile, ParseError 10 | from sqlglot.expressions import Column, Table 11 | from typing import List, Dict, Tuple, Any, Optional, Set, Union 12 | 13 | log_level = logging.DEBUG if os.environ.get( 14 | "SQL_DEBUG_MODE") == "1" else logging.CRITICAL + 1 15 | logging.basicConfig( 16 | level=log_level, 17 | format='%(asctime)s - %(levelname)s - %(name)s - %(filename)s:%(lineno)d - %(message)s' 18 | ) 19 | logger = logging.getLogger(__name__) 20 | 21 | REWARD_WEIGHTS = { 22 | "format": 1.0, 23 | "sql_correctness": 1.2, 24 | "complexity": 0.6, 25 | "reasoning": 0.7, 26 | } 27 | 28 | DEBUG_MODE = os.environ.get("SQL_DEBUG_MODE") == "1" 29 | 30 | ERR_SYNTAX = "syntax_error" 31 | ERR_MISSING_TABLE = "missing_table" 32 | ERR_MISSING_COLUMN = "missing_column" 33 | ERR_AMBIGUOUS_COLUMN = 
"ambiguous_column" 34 | ERR_TYPE_MISMATCH = "type_mismatch" 35 | ERR_CONSTRAINT = "constraint_violation" 36 | ERR_FUNCTION = "function_error" 37 | ERR_RESOURCE = "resource_error" 38 | ERR_OTHER = "other_error" 39 | ERR_SCHEMA_SETUP = "schema_setup_error" 40 | ERR_CONVERSION = "sql_conversion_error" 41 | ERR_EXECUTION = "sql_execution_error" 42 | ERR_SCHEMA_VALIDATION = "schema_validation_error" 43 | 44 | 45 | def _get_response_text(completion: Any) -> str: 46 | response_text = "" 47 | if isinstance(completion, str): 48 | response_text = completion 49 | elif isinstance(completion, list) and completion: 50 | if isinstance(completion[0], dict): 51 | response_text = completion[0].get( 52 | 'content', completion[0].get('generated_text', '')) 53 | elif isinstance(completion[0], str): 54 | response_text = completion[0] 55 | elif isinstance(completion, dict): 56 | response_text = completion.get( 57 | 'content', completion.get('generated_text', '')) 58 | else: 59 | try: 60 | response_text = str(completion) 61 | except Exception: 62 | response_text = "" 63 | logger.debug( 64 | "Could not convert completion to string: %s", type(completion)) 65 | return response_text 66 | 67 | 68 | def extract_sql(text: str) -> str: 69 | if not text: 70 | return "" 71 | match = re.search(r"(.*?)", text, re.IGNORECASE | re.DOTALL) 72 | if match: 73 | sql = match.group(1).strip() 74 | sql = re.sub(r"^\s*--.*?\n", "", sql, flags=re.MULTILINE) 75 | sql = re.sub(r"\n--.*?\s*$", "", sql, flags=re.MULTILINE) 76 | sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) 77 | return sql.strip() 78 | else: 79 | sql_keywords = ["SELECT ", "INSERT ", "UPDATE ", 80 | "DELETE ", "CREATE ", "ALTER ", "DROP ", "WITH "] 81 | text_upper = text.upper() 82 | sql_start_index = -1 83 | for keyword in sql_keywords: 84 | idx = text_upper.find(keyword) 85 | if idx != -1 and (sql_start_index == -1 or idx < sql_start_index): 86 | sql_start_index = idx 87 | if sql_start_index != -1: 88 | potential_sql = 
def extract_reasoning(text: str, tag: str = "reasoning") -> str:
    """Return the reasoning section of a model response.

    The section is the text enclosed by the opening/closing tag pair named by
    *tag* (matched case-insensitively, may span lines), stripped of
    surrounding whitespace.

    Args:
        text: Raw model output.
        tag: Name of the enclosing tag, without angle brackets.
            NOTE(review): the default "reasoning" is inferred from this
            function's name; confirm against the system prompt's actual tag.

    Returns:
        The enclosed text, or "" when *text* is empty or has no complete pair.
    """
    if not text:
        return ""
    tag_name = re.escape(tag)
    # The delimiters are required: a bare "(.*?)" matches the empty string at
    # position 0, which made this function return "" for every input.
    match = re.search(rf"<{tag_name}>(.*?)</{tag_name}>",
                      text, re.IGNORECASE | re.DOTALL)
    return match.group(1).strip() if match else ""
def identify_sql_statement_type(sql: str) -> str:
    """Classify a SQL string by its leading statement keyword.

    ``--`` line comments and ``/* */`` block comments are removed first.

    Returns one of: "SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "ALTER",
    "CREATE_TABLE", "CREATE_VIEW", "CREATE_INDEX", "CREATE_OTHER",
    the statement keyword following a CTE for ``WITH`` queries (or
    "WITH_UNKNOWN" when none is found), "OTHER" for any other leading word,
    and "UNKNOWN" for empty or comment-only input.

    Compared with a plain first-word check, this also recognizes:
      * keywords directly followed by punctuation (``SELECT(1)``, ``SELECT*``);
      * ``CREATE`` with modifiers between the keyword and the object type
        (``OR REPLACE``, ``TEMP``/``TEMPORARY``, ``UNIQUE``, ``VIRTUAL``,
        ``MATERIALIZED``), which used to be reported as "CREATE_OTHER".
    """
    if not sql:
        return "UNKNOWN"

    clean_sql = re.sub(r'--.*?$', '', sql, flags=re.MULTILINE).strip()
    clean_sql = re.sub(r'/\*.*?\*/', '', clean_sql, flags=re.DOTALL).strip()
    if not clean_sql:
        return "UNKNOWN"

    # Leading identifier only, so "SELECT(1)" still yields "SELECT".
    leading = re.match(r'[A-Za-z_][A-Za-z0-9_]*', clean_sql)
    if leading is None:
        return "OTHER"
    first_word = leading.group(0).upper()

    if first_word in ("SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "ALTER"):
        return first_word

    if first_word == "CREATE":
        created = re.match(
            r'CREATE\s+(?:OR\s+REPLACE\s+)?'
            r'(?:(?:TEMP|TEMPORARY|UNIQUE|VIRTUAL|MATERIALIZED)\s+)*'
            r'(TABLE|VIEW|INDEX)\b',
            clean_sql,
            re.IGNORECASE,
        )
        return f"CREATE_{created.group(1).upper()}" if created else "CREATE_OTHER"

    if first_word == "WITH":
        # The main statement follows the closing parenthesis of a CTE body.
        main_stmt = re.search(r'\)\s*(SELECT|INSERT|UPDATE|DELETE)\b',
                              clean_sql, re.IGNORECASE | re.DOTALL)
        return main_stmt.group(1).upper() if main_stmt else "WITH_UNKNOWN"

    return "OTHER"
def check_table_exists(conn: sqlite3.Connection, table_name: str) -> Tuple[bool, bool, Optional[str]]:
    """Look up a table or view by name in ``sqlite_master``.

    An exact-name lookup is tried first, then a case-insensitive one.

    Returns:
        ``(exact_match, exists, actual_name)``:
          * ``(True, True, table_name)`` when the name matches exactly;
          * ``(False, True, stored_name)`` when it matches only ignoring case;
          * ``(False, False, None)`` when nothing matches or SQLite raises
            (the error is logged as a warning).
    """
    # (is_exact, query) probes, tried in order; the first hit wins.
    probes = (
        (True,
         "SELECT name FROM sqlite_master WHERE type IN ('table', 'view') AND name=?;"),
        (False,
         "SELECT name FROM sqlite_master WHERE type IN ('table', 'view') AND lower(name)=lower(?);"),
    )
    try:
        cur = conn.cursor()
        for is_exact, query in probes:
            cur.execute(query, (table_name,))
            row = cur.fetchone()
            if not row:
                continue
            if is_exact:
                return True, True, table_name
            return False, True, row[0]
        return False, False, None
    except sqlite3.Error as e:
        logger.warning(
            f"Error checking existence for table/view {table_name}: {e}")
        return False, False, None
def convert_sql_to_sqlite(sql: str, source_dialect: str = "mysql") -> Optional[str]:
    """Transpile *sql* from *source_dialect* to SQLite syntax using sqlglot.

    This is best-effort: on any failure the input (or a lightly patched copy)
    is returned rather than raising.

    Args:
        sql: SQL text to convert. Empty or whitespace-only input is returned
            unchanged.
        source_dialect: sqlglot dialect name of the input. "postgresql" (or
            "postgres") enables a PostgreSQL-specific retry chain.

    Returns:
        The first transpiled statement; the original *sql* when transpilation
        yields nothing or fails; or, on a parse error for SQL containing
        AUTO_INCREMENT, a copy rewritten to SQLite's AUTOINCREMENT form.

    NOTE(review): only the first element of sqlglot's result list is returned,
    so later statements of a multi-statement input are dropped. Confirm that
    callers pass one statement at a time.
    """
    if not sql or not sql.strip():
        return sql
    try:
        if DEBUG_MODE:
            logger.debug(
                f"Converting SQL from {source_dialect} to sqlite: {sql[:150]}...")

        if source_dialect in ("postgresql", "postgres"):
            try:
                converted = transpile(sql, read="postgres", write="sqlite")
            except ParseError as pg_err:
                logger.warning(
                    f"PostgreSQL parse error: {pg_err}, trying fallback conversion...")
                # Rewrite constructs that commonly trip the parser:
                # "expr::type" casts and a trailing RETURNING clause.
                modified_sql = re.sub(
                    r'(\w+)::\w+', r'CAST(\1 AS TEXT)', sql)
                modified_sql = re.sub(
                    r'\s+RETURNING\s+.*?$', '', modified_sql, flags=re.IGNORECASE)
                try:
                    converted = transpile(
                        modified_sql, read="postgres", write="sqlite")
                except ParseError:
                    logger.warning("Falling back to generic SQL parsing...")
                    # read=None selects sqlglot's base dialect. The string
                    # "generic" is not a registered dialect name and made this
                    # last-resort branch raise instead of parsing.
                    converted = transpile(
                        modified_sql, read=None, write="sqlite")
        else:
            converted = transpile(sql, read=source_dialect, write="sqlite")

        if converted and isinstance(converted, list) and converted[0]:
            if DEBUG_MODE:
                logger.debug(f"Converted SQL: {converted[0][:150]}...")
            return converted[0]

        logger.warning(
            f"sqlglot transpile returned empty result for: {sql[:100]}...")
        return sql
    except ParseError as e:
        logger.warning(
            f"sqlglot ParseError during conversion: {e}. SQL: {sql[:150]}...")
        if 'AUTO_INCREMENT' in sql.upper():
            # Manual MySQL -> SQLite rewrite: SQLite spells it AUTOINCREMENT
            # and only allows it on an INTEGER PRIMARY KEY column.
            modified_sql = re.sub(
                r'AUTO_INCREMENT', 'AUTOINCREMENT', sql, flags=re.IGNORECASE)
            modified_sql = re.sub(r'(\w+)\s+(?:INT|INTEGER)\s+PRIMARY\s+KEY\s+AUTOINCREMENT',
                                  r'\1 INTEGER PRIMARY KEY AUTOINCREMENT', modified_sql, flags=re.IGNORECASE)
            logger.debug("Applied manual AUTO_INCREMENT fix attempt.")
            return modified_sql
        return sql
    except Exception as e:
        logger.error(
            f"Unexpected error during sqlglot conversion: {e}", exc_info=DEBUG_MODE)
        return sql
| all_col_refs = parsed_exp.find_all(Column) 366 | 367 | for col_exp in all_col_refs: 368 | col_name = col_exp.name 369 | table_alias_or_name = col_exp.table 370 | 371 | target_table = None 372 | if table_alias_or_name: 373 | if table_alias_or_name.lower() in table_case_map: 374 | target_table = table_case_map[table_alias_or_name.lower()] 375 | else: 376 | pass 377 | 378 | if target_table: 379 | db_columns = get_column_names(conn, target_table) 380 | col_case_map = {c.lower(): c for c in db_columns} 381 | if col_name not in db_columns and col_name.lower() in col_case_map: 382 | correct_case_col = col_case_map[col_name.lower()] 383 | logger.debug( 384 | f"Case Fix: Replacing column '{table_alias_or_name}.{col_name}' with '{table_alias_or_name}.{correct_case_col}'") 385 | pattern = r'\b' + \ 386 | re.escape(table_alias_or_name) + \ 387 | r'\s*\.\s*' + re.escape(col_name) + r'\b' 388 | replacement = f"{table_alias_or_name}.{correct_case_col}" 389 | corrected_sql = re.sub( 390 | pattern, replacement, corrected_sql, flags=re.IGNORECASE) 391 | needs_column_fix = True 392 | 393 | elif not table_alias_or_name: 394 | possible_corrections = [] 395 | for ref_table_name_lower in current_referenced_tables: 396 | actual_ref_table = table_case_map.get( 397 | ref_table_name_lower.lower()) 398 | if actual_ref_table: 399 | db_columns = get_column_names(conn, actual_ref_table) 400 | col_case_map = {c.lower(): c for c in db_columns} 401 | if col_name not in db_columns and col_name.lower() in col_case_map: 402 | possible_corrections.append( 403 | col_case_map[col_name.lower()]) 404 | 405 | if len(possible_corrections) == 1: 406 | correct_case_col = possible_corrections[0] 407 | logger.debug( 408 | f"Case Fix: Replacing unqualified column '{col_name}' with '{correct_case_col}'") 409 | pattern = r'(? 1: 414 | logger.warning( 415 | f"Ambiguous case correction for unqualified column '{col_name}'. Found in multiple tables. 
Skipping.") 416 | 417 | except ParseError as e: 418 | logger.warning( 419 | f"sqlglot failed to parse for column case fixing: {e}. Column fix might be incomplete.") 420 | except Exception as e: 421 | logger.error( 422 | f"Unexpected error during column case fixing: {e}", exc_info=DEBUG_MODE) 423 | 424 | if needs_column_fix: 425 | logger.debug( 426 | f"SQL after column case correction: {corrected_sql[:150]}...") 427 | 428 | return corrected_sql 429 | 430 | 431 | def fix_ambiguous_columns(sql: str, conn: Optional[sqlite3.Connection] = None) -> str: 432 | if " JOIN " not in sql.upper(): 433 | return sql 434 | 435 | try: 436 | parsed_exp = parse(sql, read="sqlite") 437 | common_ambiguous = {'id', 'name', 'date', 'code', 'created_at', 'updated_at', 438 | 'description', 'status', 'type', 'price', 'quantity', 'amount'} 439 | first_table_alias = None 440 | 441 | if isinstance(parsed_exp, list): 442 | tables = [] 443 | for expr in parsed_exp: 444 | if hasattr(expr, 'find_all'): 445 | for node in expr.find_all(): 446 | if hasattr(node, 'name') and hasattr(node, 'is_table') and node.is_table: 447 | tables.append(node) 448 | else: 449 | tables = [node for node in parsed_exp.find_all() 450 | if hasattr(node, 'name') and hasattr(node, 'is_table') and node.is_table] 451 | 452 | if tables: 453 | first_table_alias = tables[0].alias_or_name 454 | 455 | if not first_table_alias: 456 | return sql 457 | 458 | fixed_sql = sql 459 | modified = False 460 | 461 | if isinstance(parsed_exp, list): 462 | all_col_refs = [] 463 | for expr in parsed_exp: 464 | if hasattr(expr, 'find_all'): 465 | all_col_refs.extend(expr.find_all(Column)) 466 | else: 467 | all_col_refs = parsed_exp.find_all(Column) 468 | 469 | for col_exp in all_col_refs: 470 | if not col_exp.table and col_exp.name.lower() in common_ambiguous: 471 | logger.debug( 472 | f"Ambiguity Fix: Qualifying '{col_exp.name}' with '{first_table_alias}'") 473 | pattern = r'(? 
def categorize_sql_error(error_msg: Optional[str]) -> Tuple[str, float]:
    """Classify a SQLite error message and pick its partial-credit factor.

    NOTE(review): the ``def`` line was unreadable in the reviewed copy and
    was restored from the call ``categorize_sql_error(error_msg)``; confirm
    the parameter annotation against the repository.

    Args:
        error_msg: Text of the raised error; may be empty or ``None``.

    Returns:
        ``(category, partial_credit)``. ``category`` is one of the module's
        ``ERR_*`` constants; ``partial_credit`` is the factor the caller
        multiplies by the correctness weight (0.0 means no credit).
    """
    if not error_msg:
        return ERR_OTHER, 0.0
    error_lower = error_msg.lower()
    if DEBUG_MODE:
        logger.debug(f"Categorizing SQL error: {error_msg}")

    # Ordered rules: the first rule with a matching fragment wins, exactly
    # like the if-chain this table replaces.
    rules = (
        (("syntax error",), ERR_SYNTAX, 0.0),
        (("no such table",), ERR_MISSING_TABLE, 0.0),
        (("no such column",), ERR_MISSING_COLUMN, 0.1),
        (("ambiguous column",), ERR_AMBIGUOUS_COLUMN, 0.2),
        (("datatype mismatch",), ERR_TYPE_MISMATCH, 0.15),
        (("constraint failed", "constraint violation"), ERR_CONSTRAINT, 0.1),
        (("no such function",), ERR_FUNCTION, 0.05),
        (("too many terms in compound select",), ERR_SYNTAX, 0.0),
        (("subquery returned more than 1 row",), ERR_EXECUTION, 0.1),
    )
    for fragments, category, credit in rules:
        if any(fragment in error_lower for fragment in fragments):
            return category, credit

    return ERR_OTHER, 0.0


def _format_reward(completions, pattern: str) -> list[float]:
    """Give the ``format`` weight to every completion that matches *pattern*.

    The pattern is searched case-insensitively with ``.`` matching newlines.
    """
    base_reward = REWARD_WEIGHTS.get("format", 1.0)
    compiled = re.compile(pattern, re.IGNORECASE | re.DOTALL)
    return [
        base_reward if compiled.search(_get_response_text(completion)) else 0.0
        for completion in completions
    ]


def strict_format_reward_func(prompts, completions, references=None, **kwargs) -> list[float]:
    """Reward completions whose text matches the strict layout pattern.

    ``prompts``, ``references`` and ``kwargs`` are accepted for the reward
    function signature but are not read.

    Returns:
        One float per completion: the ``format`` weight on a match, else 0.0.
    """
    # NOTE(review): this pattern has no literal delimiters and so matches any
    # text of two or more characters. It looks like tag markers were lost from
    # the reviewed copy -- confirm against the repository before relying on it.
    strict_pattern = r"(.+?)\s*(.+?)"
    return _format_reward(completions, strict_pattern)


def soft_format_reward_func(prompts, completions, references=None, **kwargs) -> list[float]:
    """Reward completions whose text matches the lenient layout pattern.

    ``prompts``, ``references`` and ``kwargs`` are accepted for the reward
    function signature but are not read.

    Returns:
        One float per completion: the ``format`` weight on a match, else 0.0.
    """
    # NOTE(review): as written this pattern also matches the empty string, so
    # every completion is rewarded. Tag markers were probably lost from the
    # reviewed copy -- confirm against the repository.
    soft_pattern = r"(.*?)\s*(.*?)"
    return _format_reward(completions, soft_pattern)
# Leading tokens that mark a table-level constraint rather than a column.
_TABLE_CONSTRAINT_RE = re.compile(
    r"(?:CONSTRAINT\b|PRIMARY\s+KEY\b|FOREIGN\s+KEY\b|UNIQUE\s*\(|CHECK\s*\()",
    re.IGNORECASE,
)


def _balanced_paren_body(text: str, open_idx: int) -> Optional[str]:
    """Return the text between the ``(`` at *open_idx* and its matching ``)``.

    Returns ``None`` when the parenthesis is never closed.
    """
    depth = 0
    for pos in range(open_idx, len(text)):
        char = text[pos]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return text[open_idx + 1:pos]
    return None


def _split_top_level_commas(text: str) -> list[str]:
    """Split *text* on commas that are not nested inside parentheses."""
    parts = []
    depth = 0
    start = 0
    for pos, char in enumerate(text):
        if char == "(":
            depth += 1
        elif char == ")":
            depth = max(0, depth - 1)
        elif char == "," and depth == 0:
            parts.append(text[start:pos])
            start = pos + 1
    parts.append(text[start:])
    return parts


def extract_tables_columns(sql_context: str) -> tuple[set[str], set[str]]:
    """Collect lower-cased table/view and column names from schema DDL.

    The context is split into statements with ``sqlparse.split``. For each
    ``CREATE TABLE`` the table name and the first identifier of every
    column definition are recorded; for each ``CREATE VIEW ... AS`` only
    the view name is recorded (it is added to the table set).

    Args:
        sql_context: DDL text; an empty value yields two empty sets.

    Returns:
        ``(tables, columns)`` as sets of lower-cased names. Parsing errors
        are logged as warnings and whatever was collected so far is returned.
    """
    tables: set[str] = set()
    columns: set[str] = set()
    if not sql_context:
        return tables, columns

    # The header regex stops at the opening parenthesis; the column list is
    # then found by parenthesis matching, so a statement does not have to end
    # in exactly ");" to be recognised.
    create_table_pattern = r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`\"\[]?(\w+)[`\"\]]?)\s*\("
    create_view_pattern = r"CREATE\s+VIEW\s+(?:[`\"\[]?(\w+)[`\"\]]?)\s+AS"
    column_pattern = r"^\s*([`\"\[]?\w+[`\"\]]?)"

    try:
        for stmt in sqlparse.split(sql_context):
            stmt_clean = stmt.strip()

            table_match = re.search(
                create_table_pattern, stmt_clean, re.IGNORECASE)
            cols_text = (
                _balanced_paren_body(stmt_clean, table_match.end() - 1)
                if table_match else None
            )
            if cols_text is not None:
                tables.add(table_match.group(1).lower())
                for part in _split_top_level_commas(cols_text):
                    part = part.strip()
                    # "PRIMARY KEY (a, b)" and friends are constraints; their
                    # first word must not be reported as a column name.
                    if _TABLE_CONSTRAINT_RE.match(part):
                        continue
                    col_match = re.match(column_pattern, part)
                    if col_match:
                        columns.add(col_match.group(1).strip('`"[]').lower())

            view_match = re.search(
                create_view_pattern, stmt_clean, re.IGNORECASE)
            if view_match:
                tables.add(view_match.group(1).lower())

    except Exception as e:
        logger.warning(f"Could not parse schema elements from context: {e}")

    return tables, columns
def reasoning_quality_reward(prompts, completions, references=None, **kwargs) -> list[float]:
    """Score the reasoning text of each completion with surface heuristics.

    Components (summed, capped at 1.0, then scaled by the ``reasoning``
    weight, default 0.7):

    * length    - 0.10 / 0.15 / 0.20 for at least 10 / 25 / 50 words
    * terms     - 0.03 per SQL term found as a substring, capped at 0.20
    * structure - 0.10 for two non-blank lines, 0.15 for three or more
    * steps     - 0.15 when both an opening and a follow-up step word occur
    * schema    - 0.05 per schema table/column mentioned, capped at 0.30

    Args:
        prompts: Not read.
        completions: Model outputs; text is taken via ``_get_response_text``.
        references: Optional per-completion lists whose first item is a
            mapping that may carry ``sql_context``.
        **kwargs: Not read.

    Returns:
        One reward per completion; 0.0 when no reasoning is extracted.
    """
    rewards = []
    # Keyed by the schema text itself so completions sharing one schema reuse
    # a single parse (a per-index key could never produce a cache hit).
    schema_cache: dict[str, tuple[set[str], set[str]]] = {}

    sql_terms = ["table", "column", "join", "select", "where",
                 "group by", "order by", "filter", "aggregate", "schema", "database"]

    for i, completion in enumerate(completions):
        response_text = _get_response_text(completion)
        reasoning = extract_reasoning(response_text)
        reward_components = {}

        if not reasoning:
            rewards.append(0.0)
            continue

        reasoning_lower = reasoning.lower()
        words = reasoning.split()
        lines = [line for line in reasoning.split("\n") if line.strip()]

        len_score = 0.0
        if len(words) >= 50:
            len_score = 0.20
        elif len(words) >= 25:
            len_score = 0.15
        elif len(words) >= 10:
            len_score = 0.10
        reward_components['length'] = len_score

        term_count = sum(1 for term in sql_terms if term in reasoning_lower)
        reward_components['terms'] = min(0.20, term_count * 0.03)

        structure_score = 0.0
        if len(lines) >= 3:
            structure_score = 0.15
        elif len(lines) >= 2:
            structure_score = 0.10
        reward_components['structure'] = structure_score

        step_score = 0.0
        if re.search(r'(step 1|first|start|initial|begin)', reasoning_lower) and \
                re.search(r'(step 2|next|then|second|final|last|subsequent)', reasoning_lower):
            step_score = 0.15
        reward_components['steps'] = step_score

        schema_mention_score = 0.0
        sql_context = None
        try:
            if references and i < len(references) and references[i] and isinstance(references[i], list) and references[i][0]:
                sql_context = references[i][0].get('sql_context')
        except IndexError:
            logger.warning(f"IndexError accessing references at index {i}")

        if sql_context:
            if isinstance(sql_context, str):
                if sql_context not in schema_cache:
                    schema_cache[sql_context] = extract_tables_columns(
                        sql_context)
                tables, columns = schema_cache[sql_context]
            else:
                tables, columns = set(), set()

            if tables or columns:
                # Whole-word matches only, so "id" does not match "width".
                total_mentions = sum(
                    1 for name in (*tables, *columns)
                    if re.search(r'\b' + re.escape(name) + r'\b', reasoning_lower))
                schema_mention_score = min(0.30, total_mentions * 0.05)
        reward_components['schema'] = schema_mention_score

        total_unscaled_reward = sum(reward_components.values())
        final_reward = min(1.0, total_unscaled_reward) * \
            REWARD_WEIGHTS.get("reasoning", 0.7)
        rewards.append(final_reward)
        if DEBUG_MODE:
            logger.debug(
                f"Reasoning Scores (Comp {i}): {reward_components} -> Total Raw: {total_unscaled_reward:.3f} -> Final: {final_reward:.3f}")

    return rewards
def complexity_reward(prompts, completions, references, **kwargs) -> list[float]:
    """Reward generated SQL whose complexity is close to the gold query's.

    Complexity comes from ``calculate_sql_complexity``. With a gold query
    the score is ``exp(-0.5 * ln(gen/gold)**2)``, with the ratio clamped to
    ``[1e-3, 1e3]``: 1.0 for equal complexity, falling off symmetrically in
    log space. A gold complexity below 0.1 scores 1.0 only if the generated
    complexity is also below 0.1. Without a gold query the score is 0.4
    when the generated complexity lies in ``[1.5, 8.0]`` and 0.1 otherwise.
    The score is multiplied by the ``complexity`` weight (default 0.6).

    Args:
        prompts: Not read.
        completions: Model outputs; SQL is taken via ``extract_sql``.
        references: Per-completion lists whose first item is a mapping that
            may carry ``gold_sql``.
        **kwargs: Not read.

    Returns:
        One non-negative reward per completion; 0.0 when no SQL is extracted
        or the complexity computation raises.
    """
    rewards = []
    base_weight = REWARD_WEIGHTS.get("complexity", 0.6)

    for i, completion in enumerate(completions):
        response_text = _get_response_text(completion)
        gen_sql = extract_sql(response_text)

        gold_sql = ""
        try:
            if references and i < len(references) and references[i] and isinstance(references[i], list) and references[i][0]:
                gold_sql = references[i][0].get('gold_sql', '')
        except IndexError:
            logger.warning(f"IndexError accessing references at index {i}")

        if not gen_sql:
            rewards.append(0.0)
            continue

        try:
            gen_complexity = calculate_sql_complexity(gen_sql)

            if not gold_sql:
                reward = (0.4 if 1.5 <= gen_complexity <=
                          8.0 else 0.1) * base_weight
                if DEBUG_MODE:
                    logger.debug(
                        f"Complexity (Comp {i}): No Gold SQL. Gen={gen_complexity:.2f}. Reward={reward:.3f}")
            else:
                gold_complexity = calculate_sql_complexity(gold_sql)
                # Stays None on the near-zero-gold path, where no ratio is
                # computed; the debug log below must not assume it exists.
                ratio = None
                if gold_complexity < 0.1:
                    rel_score = 1.0 if gen_complexity < 0.1 else 0.0
                else:
                    ratio = max(
                        1e-3, min(gen_complexity / gold_complexity, 1e3))
                    log_ratio = torch.log(torch.tensor(ratio))
                    rel_score = torch.exp(-0.5 * (log_ratio**2)).item()
                reward = rel_score * base_weight
                if DEBUG_MODE:
                    ratio_text = "n/a" if ratio is None else f"{ratio:.2f}"
                    logger.debug(
                        f"Complexity (Comp {i}): Gen={gen_complexity:.2f}, Gold={gold_complexity:.2f}, Ratio={ratio_text}, Score={rel_score:.3f}, Reward={reward:.3f}")

            rewards.append(max(0.0, reward))

        except Exception as e:
            logger.warning(f"Error in complexity reward calculation: {e}")
            rewards.append(0.0)

    return rewards
def dump_database_schema(conn: sqlite3.Connection) -> dict:
    """Describe every table of *conn* for debug logging.

    Args:
        conn: Open SQLite connection.

    Returns:
        A dict mapping each table name to a list of column descriptions
        (``"name (TYPE) [PRIMARY KEY] [NOT NULL] [DEFAULT x]"``). Tables that
        have indexes get an extra ``"<table>_indexes"`` entry listing
        ``"index_name (col, ...)"``. On any failure a warning is logged and
        ``{"error": <message>}`` is returned instead.
    """
    def quote_identifier(identifier) -> str:
        # Double-quote and escape so names containing spaces, quotes or
        # keywords are still valid PRAGMA arguments.
        return '"' + str(identifier).replace('"', '""') + '"'

    try:
        cursor = conn.cursor()
        schema_info = {}

        for table in list_all_tables(conn):
            cursor.execute(f"PRAGMA table_info({quote_identifier(table)})")

            column_info = []
            for _cid, name, col_type, not_null, default_val, is_pk in cursor.fetchall():
                col_desc = f"{name} ({col_type})"
                if is_pk:
                    col_desc += " PRIMARY KEY"
                if not_null:
                    col_desc += " NOT NULL"
                if default_val is not None:
                    col_desc += f" DEFAULT {default_val}"
                column_info.append(col_desc)
            schema_info[table] = column_info

            cursor.execute(f"PRAGMA index_list({quote_identifier(table)})")
            indexes = cursor.fetchall()
            if indexes:
                index_descriptions = []
                for idx in indexes:
                    idx_name = idx[1]
                    cursor.execute(
                        f"PRAGMA index_info({quote_identifier(idx_name)})")
                    # The column name is NULL for rowid/expression entries;
                    # substitute a placeholder so join() cannot fail.
                    idx_cols = [
                        info[2] if info[2] is not None else "<expr>"
                        for info in cursor.fetchall()
                    ]
                    index_descriptions.append(
                        f"{idx_name} ({', '.join(idx_cols)})")
                schema_info[f"{table}_indexes"] = index_descriptions

        return schema_info
    except Exception as e:
        logger.warning(f"Error dumping database schema: {e}")
        return {"error": str(e)}
# Upper bound on rows fetched from each result set before comparison.
_RESULT_ROW_LIMIT = 1000

# Regexes that pull the object name out of a CREATE statement, by kind.
_CREATE_NAME_PATTERNS = {
    "TABLE": r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?([^\s(]+)',
    "VIEW": r'CREATE\s+VIEW\s+([^\s(]+)',
}

# Regexes that pull the target table out of a DML statement, by type.
_DML_TABLE_PATTERNS = {
    "INSERT": r'INSERT\s+INTO\s+([^\s(]+)',
    "UPDATE": r'UPDATE\s+([^\s(]+)',
    "DELETE": r'DELETE\s+FROM\s+([^\s(]+)',
}


def _run_create_statements(conn, cursor, statements, kind):
    """Execute CREATE TABLE/VIEW statements, logging failures instead of raising."""
    label = kind.capitalize()
    for stmt in statements:
        # Reset per statement so the error log never shows values left over
        # from a previous statement.
        object_name = "unknown"
        converted_stmt = None
        try:
            name_match = re.search(
                _CREATE_NAME_PATTERNS[kind], stmt, re.IGNORECASE)
            if name_match:
                object_name = name_match.group(1).strip('`"[]')

            converted_stmt = convert_sql_to_sqlite(stmt)
            if DEBUG_MODE:
                logger.debug(
                    f"Creating {kind.lower()} {object_name} with statement: {converted_stmt[:100]}...")

            cursor.execute(converted_stmt)

            exists_exact, _, _ = check_table_exists(conn, object_name)
            if not exists_exact:
                logger.warning(
                    f"{label} {object_name} creation failed silently")
            elif DEBUG_MODE:
                logger.debug(f"{label} {object_name} created successfully")
                if kind == "TABLE":
                    schema = get_table_schema(conn, object_name)
                    logger.debug(f"Schema for {object_name}: {schema}")

        except sqlite3.Error as e:
            logger.warning(f"Error in CREATE {kind} statement: {e}")
            logger.warning(f"{label} name: {object_name}")
            logger.warning(f"Original statement: {stmt[:200]}...")
            logger.warning(
                f"Converted statement: {converted_stmt[:200]}..."
                if converted_stmt is not None else "conversion failed")


def _populate_context_database(conn, cursor, sql_context):
    """Run the schema/context script: tables first, then views, then the rest."""
    create_tables, create_views, others = [], [], []
    for stmt in sqlparse.split(sql_context):
        stmt = stmt.strip()
        if not stmt:
            continue
        stmt_upper = stmt.upper()
        if stmt_upper.startswith('CREATE TABLE'):
            create_tables.append(stmt)
        elif stmt_upper.startswith('CREATE VIEW'):
            create_views.append(stmt)
        else:
            others.append(stmt)

    if DEBUG_MODE:
        logger.debug(f"Found {len(create_tables)} CREATE TABLE statements, "
                     f"{len(create_views)} CREATE VIEW statements, and "
                     f"{len(others)} other statements")

    _run_create_statements(conn, cursor, create_tables, "TABLE")
    _run_create_statements(conn, cursor, create_views, "VIEW")

    for stmt in others:
        try:
            stmt_upper = stmt.upper()
            is_insert_like = stmt_upper.startswith(
                "INSERT") or "INSERT INTO" in stmt_upper
            converted_stmt = convert_sql_to_sqlite(stmt)
            if DEBUG_MODE and is_insert_like:
                logger.debug(
                    f"Executing insert-like statement: {converted_stmt[:100]}...")
            cursor.execute(converted_stmt)
        except sqlite3.Error as e:
            logger.warning(f"Error in non-CREATE statement: {e}")
            logger.warning(f"Statement causing error: {stmt[:200]}...")


def _align_table_case(conn, gen_sql):
    """Rewrite table names in *gen_sql* that differ from the database only by case."""
    existing_tables = list_all_tables(conn)
    logger.debug(f"All tables in database: {existing_tables}")

    referenced_tables = extract_tables_from_query(gen_sql)
    if DEBUG_MODE:
        logger.debug(
            f"Tables referenced in generated query: {referenced_tables}")

    existing_exact = set(existing_tables)
    case_mapping = {t.lower(): t for t in existing_tables}
    missing_tables = []
    case_mismatch_tables = []
    for table in referenced_tables:
        if table in existing_exact:
            logger.debug(
                f"Table '{table}' referenced in query exists exactly as specified")
        elif table.lower() in case_mapping:
            case_mismatch_tables.append(table)
            logger.debug(
                f"Table '{table}' exists but with different case: '{case_mapping[table.lower()]}'")
        else:
            missing_tables.append(table)
            logger.debug(f"Table '{table}' does not exist in any case form")

    if case_mismatch_tables:
        logger.warning(
            f"Case-mismatch in table references: {case_mismatch_tables}")
        for wrong_case in case_mismatch_tables:
            correct_case = case_mapping[wrong_case.lower()]
            logger.debug(f"Fixing case: '{wrong_case}' → '{correct_case}'")
            # A callable replacement keeps any backslash in the real table
            # name from being read as a regex group reference.
            gen_sql = re.sub(r'\b' + re.escape(wrong_case) + r'\b',
                             lambda _m, name=correct_case: name,
                             gen_sql,
                             flags=re.IGNORECASE)
        logger.debug(f"Adjusted SQL with correct case: {gen_sql[:200]}...")

    if missing_tables:
        logger.warning(
            f"Tables genuinely missing (not just case mismatch): {missing_tables}")

    return gen_sql


def _score_select_results(gold_columns, gold_result, gen_columns, gen_result, weight):
    """Compare two fetched result sets and return the correctness reward."""
    base_reward = 0.3 * weight
    reward = base_reward

    # Rows are compared as sets: order and duplicates are ignored.
    gold_rows = {tuple(row) for row in gold_result}
    gen_rows = {tuple(row) for row in gen_result}

    if gold_rows == gen_rows and gold_columns == gen_columns:
        reward = weight
        logger.debug("Results and columns match exactly!")
    elif gold_rows and gen_rows:
        # Defaults cover the "no column in common" case, where nothing can
        # be compared cell by cell.
        overlap = 0
        jaccard = 0.0

        if gold_columns == gen_columns:
            overlap = len(gold_rows & gen_rows)
            jaccard = overlap / len(gold_rows | gen_rows)
        else:
            gen_cols_lower = [c.lower() for c in gen_columns]
            common = [
                (gold_idx, gen_cols_lower.index(col.lower()))
                for gold_idx, col in enumerate(gold_columns)
                if col.lower() in gen_cols_lower
            ]
            if common:
                # Both projections list the shared columns in gold order, so
                # equal values match even when the generated query returns
                # its columns in different positions.
                gold_proj_rows = {
                    tuple(row[g] for g, _ in common) for row in gold_result}
                gen_proj_rows = {
                    tuple(row[n] for _, n in common) for row in gen_result}
                overlap = len(gold_proj_rows & gen_proj_rows)
                jaccard = overlap / len(gold_proj_rows | gen_proj_rows)
                if DEBUG_MODE:
                    logger.debug(
                        f"Similarity calculated on {len(common)} common columns")

        row_count_ratio = min(len(gen_rows), len(gold_rows)) / \
            max(len(gen_rows), len(gold_rows))

        col_similarity = 0.0
        if gold_columns and gen_columns:
            gold_cols_set = {c.lower() for c in gold_columns}
            gen_cols_set = {c.lower() for c in gen_columns}
            col_similarity = len(gold_cols_set & gen_cols_set) / \
                len(gold_cols_set | gen_cols_set)

        data_accuracy = len(gold_rows & gen_rows) / len(gold_rows)

        content_similarity = (
            0.40 * jaccard +
            0.20 * row_count_ratio +
            0.25 * col_similarity +
            0.15 * data_accuracy
        )
        reward = weight * content_similarity

        if DEBUG_MODE:
            logger.debug(f"Reward calculation: jaccard={jaccard:.3f}, row_ratio={row_count_ratio:.3f}, " +
                         f"col_sim={col_similarity:.3f}, data_acc={data_accuracy:.3f}, " +
                         f"content_sim={content_similarity:.3f}, final_reward={reward:.3f}")

        if overlap > 0 and reward < base_reward:
            reward = base_reward

    # Anything that executed earns at least 0.2 of the weight.
    if reward <= base_reward and gen_result is not None:
        reward = max(reward, 0.2 * weight)

    return reward


def _reward_for_select(conn, cursor, gold_sql, gen_sql, reward, weight):
    """Run gold and generated SELECTs and score them; *reward* is kept on failure."""
    converted_gen_sql = None
    try:
        fixed_gen_sql = fix_ambiguous_columns(gen_sql)
        if fixed_gen_sql != gen_sql:
            logger.debug("Fixed ambiguous columns in generated SQL")
            logger.debug(f"Original SQL: {gen_sql[:200]}...")
            logger.debug(f"Fixed SQL: {fixed_gen_sql[:200]}...")
            gen_sql = fixed_gen_sql

        converted_gold_sql = convert_sql_to_sqlite(gold_sql)
        logger.debug(f"Executing gold SQL: {converted_gold_sql[:200]}...")
        cursor.execute(converted_gold_sql)
        gold_columns = [
            desc[0] for desc in cursor.description] if cursor.description else []
        gold_result = cursor.fetchmany(_RESULT_ROW_LIMIT)
        logger.debug(
            f"Gold SQL execution successful, returned {len(gold_result)} rows")
        if gold_result:
            logger.debug(f"First row of gold result: {gold_result[0]}")

        gen_sql_fixed = fix_case_sensitivity_in_sql(conn, gen_sql)
        if gen_sql_fixed != gen_sql:
            logger.debug("Fixed case sensitivity issues in generated SQL")
            gen_sql = gen_sql_fixed

        converted_gen_sql = convert_sql_to_sqlite(gen_sql)
        logger.debug(f"Executing generated SQL: {converted_gen_sql[:200]}...")
        cursor.execute(converted_gen_sql)
        gen_columns = [
            desc[0] for desc in cursor.description] if cursor.description else []
        gen_result = cursor.fetchmany(_RESULT_ROW_LIMIT)
        logger.debug(
            f"Generated SQL execution successful, returned {len(gen_result)} rows")
        if gen_result:
            logger.debug(f"First row of generated result: {gen_result[0]}")

        return _score_select_results(
            gold_columns, gold_result, gen_columns, gen_result, weight)

    except sqlite3.Error as e:
        error_msg = str(e)
        error_type, partial_credit = categorize_sql_error(error_msg)
        if partial_credit > 0:
            reward = partial_credit * weight

        logger.warning(
            f"Error executing SELECT statement ({error_type}): {error_msg}")
        logger.warning(f"Generated SQL: {gen_sql[:200]}...")
        if converted_gen_sql is not None:
            logger.warning(f"Converted SQL: {converted_gen_sql[:200]}...")
        return reward


def _reward_for_dml(conn, cursor, gen_sql, gen_type, reward, weight):
    """Execute a generated INSERT/UPDATE/DELETE; *reward* is kept on failure."""
    try:
        if "JOIN" in gen_sql.upper():
            logger.warning(
                f"JOIN detected in {gen_type} statement - may cause issues")

        table_match = re.search(
            _DML_TABLE_PATTERNS[gen_type], gen_sql, re.IGNORECASE)
        if table_match:
            main_table = table_match.group(1)
            logger.debug(f"Main table for {gen_type}: {main_table}")
            exists = check_table_exists(conn, main_table)
            logger.debug(f"Main table '{main_table}' exists: {exists}")

        gen_sql_fixed = fix_case_sensitivity_in_sql(conn, gen_sql)
        if gen_sql_fixed != gen_sql:
            logger.debug("Fixed case sensitivity issues in DML statement")
            gen_sql = gen_sql_fixed

        converted_gen_sql = convert_sql_to_sqlite(gen_sql)
        logger.debug(f"Executing DML statement: {converted_gen_sql[:200]}...")
        cursor.execute(converted_gen_sql)
        return 0.5 * weight

    except sqlite3.Error as e:
        error_msg = str(e)
        logger.warning(f"Error executing DML statement: {error_msg}")
        logger.warning(f"Generated SQL: {gen_sql[:200]}...")

        if "no such table" not in error_msg.lower():
            return reward
        table_match = re.search(
            r"no such table: (\w+)", error_msg, re.IGNORECASE)
        if not table_match:
            return reward

        missing_table = table_match.group(1)
        logger.debug(f"Missing table: {missing_table}")
        all_tables = list_all_tables(conn)
        logger.debug(f"Available tables: {all_tables}")

        case_mapping = {t.lower(): t for t in all_tables}
        correct_case = case_mapping.get(missing_table.lower())
        if correct_case is None:
            return reward

        logger.debug(
            f"Case mismatch detected! '{missing_table}' vs '{correct_case}'")
        corrected_sql = re.sub(r'\b' + re.escape(missing_table) + r'\b',
                               lambda _m: correct_case,
                               gen_sql,
                               flags=re.IGNORECASE)
        logger.debug(f"Corrected SQL: {corrected_sql[:200]}...")

        converted_corrected = convert_sql_to_sqlite(corrected_sql)
        try:
            cursor.execute(converted_corrected)
        except sqlite3.Error as e2:
            logger.warning(f"Still failed after case correction: {e2}")
            logger.debug(f"New error after case correction: {e2}")
            logger.debug(
                f"Converted corrected SQL: {converted_corrected[:200]}...")
            return reward

        logger.debug("Execution successful after case correction!")
        return 0.4 * weight


def execute_query_reward_func(prompts, completions, references, **kwargs) -> list[float]:
    """Reward generated SQL by executing it against a scratch SQLite database.

    For each completion a temporary database is built from the reference's
    ``sql_context``. With ``W = REWARD_WEIGHTS["sql_correctness"]``:

    * 0.0 when the generated SQL, gold SQL or context is missing.
    * ``0.1 * W`` when the gold and generated statement types match.
    * SELECT vs SELECT: ``W`` for identical row sets and column names;
      otherwise ``W`` times a similarity blend (0.40 row Jaccard, 0.20
      row-count ratio, 0.25 column-name Jaccard, 0.15 exact-row recall),
      raised to ``0.3 * W`` when any rows overlap and never below
      ``0.2 * W`` once both queries ran. If execution fails with a SQLite
      error, a positive ``categorize_sql_error`` credit times ``W``
      replaces the reward.
    * INSERT/UPDATE/DELETE: ``0.5 * W`` when the statement runs, or
      ``0.4 * W`` when it only runs after a table-name case correction.

    Row sets are compared without regard to order, and at most the first
    ``_RESULT_ROW_LIMIT`` rows of each result are considered.

    Args:
        prompts: Not read.
        completions: Model outputs; SQL is taken via ``extract_sql``.
        references: Per-completion lists whose first item is a mapping with
            ``gold_sql`` and ``sql_context``.
        **kwargs: Not read.

    Returns:
        One reward per completion, in order.
    """
    rewards = []

    for i, completion in enumerate(completions):
        response_text = _get_response_text(completion)
        gen_sql = extract_sql(response_text)

        gold_sql = ""
        sql_context = ""
        try:
            if references and i < len(references) and references[i] and isinstance(references[i], list) and references[i][0]:
                gold_sql = references[i][0].get('gold_sql', '')
                sql_context = references[i][0].get('sql_context', '')

                if DEBUG_MODE:
                    logger.debug(
                        f"Reference {i}: Gold SQL = {gold_sql[:100]}...")
                    logger.debug(
                        f"Reference {i}: Context SQL = {sql_context}")
        except IndexError:
            logger.warning(f"IndexError accessing references at index {i}")

        reward = 0.0

        if not gen_sql or not gold_sql or not sql_context:
            logger.warning(
                f"Missing SQL data for completion {i}: gen_sql={bool(gen_sql)}, gold_sql={bool(gold_sql)}, sql_context={bool(sql_context)}")
            rewards.append(reward)
            continue

        weight = REWARD_WEIGHTS["sql_correctness"]
        gold_type = identify_sql_statement_type(gold_sql)
        gen_type = identify_sql_statement_type(gen_sql)

        if gold_type == gen_type:
            reward += 0.1 * weight

        if DEBUG_MODE:
            logger.debug(f"Gold SQL type: {gold_type}")
            logger.debug(f"Generated SQL type: {gen_type}")

        conn = None
        temp_db_file = None
        try:
            # mkstemp hands back an open descriptor; close it right away so
            # only SQLite holds the file and it can be unlinked afterwards.
            fd, temp_db_file = tempfile.mkstemp()
            os.close(fd)
            conn = sqlite3.connect(temp_db_file, timeout=5)
            conn.isolation_level = None
            cursor = conn.cursor()

            _populate_context_database(conn, cursor, sql_context)

            if DEBUG_MODE:
                schema_info = dump_database_schema(conn)
                logger.debug(f"Database schema after setup: {schema_info}")

            gen_sql = _align_table_case(conn, gen_sql)

            if gold_type == "SELECT" and gen_type == "SELECT":
                reward = _reward_for_select(
                    conn, cursor, gold_sql, gen_sql, reward, weight)
            elif gen_type in ("INSERT", "UPDATE", "DELETE"):
                reward = _reward_for_dml(
                    conn, cursor, gen_sql, gen_type, reward, weight)

            rewards.append(reward)

        except Exception as e:
            # Best effort: keep whatever reward was earned before the failure.
            logger.warning(
                f"Error in execution reward calculation: {e}", exc_info=True)
            rewards.append(reward)
        finally:
            logger.debug(f"Final reward for completion {i}: {reward}")

            if conn is not None:
                try:
                    conn.close()
                except sqlite3.Error as close_err:
                    logger.debug(
                        f"Ignoring error while closing scratch database: {close_err}")

            if temp_db_file:
                try:
                    os.unlink(temp_db_file)
                except OSError as unlink_err:
                    logger.debug(
                        f"Could not remove scratch database {temp_db_file}: {unlink_err}")

    return rewards
if __name__ == "__main__":
    # Manual smoke run: turn on verbose logging, then push one toy example
    # through each reward function and print the scores.
    DEBUG_MODE = True
    log_level = logging.DEBUG
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    for handler in root_logger.handlers:
        handler.setLevel(log_level)
    logger.info("Running example usage...")

    prompts_example = ["Show names of dogs older than 5 years."]
    completions_example = [
        "Find dogs table. Filter by age > 5. Select name column.\nSELECT name FROM dogs WHERE age > 5;"
    ]
    example_reference = {
        "gold_sql": "SELECT name FROM dogs WHERE age > 5 ORDER BY dog_id;",
        "sql_context": """
            CREATE TABLE dogs (dog_id INTEGER PRIMARY KEY, name TEXT, age INTEGER);
            INSERT INTO dogs (name, age) VALUES ('Buddy', 7);
            INSERT INTO dogs (name, age) VALUES ('Lucy', 4);
            INSERT INTO dogs (name, age) VALUES ('Max', 8);
            """,
        "question": prompts_example[0]
    }
    references_example = [[example_reference]]

    # Keyword arguments shared by both execution runs.
    shared_exec_kwargs = {
        "source_dialect_dataset": "mysql",
        "source_dialect_generated": "postgresql",
        "validate_schema": True,
    }

    print("\n--- Testing execute_query_reward_func (Order Ignored) ---")
    exec_rewards_order_ignored = execute_query_reward_func(
        prompts_example, completions_example, references_example,
        order_matters=False, **shared_exec_kwargs)
    print(f"Execution Rewards (Order Ignored): {exec_rewards_order_ignored}")

    print("\n--- Testing execute_query_reward_func (Order Matters) ---")
    exec_rewards_order_matters = execute_query_reward_func(
        prompts_example, completions_example, references_example,
        order_matters=True, **shared_exec_kwargs)
    print(f"Execution Rewards (Order Matters): {exec_rewards_order_matters}")

    print("\n--- Testing Format Rewards ---")
    strict_format_rewards = strict_format_reward_func(
        prompts_example, completions_example)
    soft_format_rewards = soft_format_reward_func(
        prompts_example, completions_example)
    print(f"Strict Format Rewards: {strict_format_rewards}")
    print(f"Soft Format Rewards: {soft_format_rewards}")

    print("\n--- Testing Complexity Reward ---")
    complexity_rewards = complexity_reward(
        prompts_example, completions_example, references_example)
    print(f"Complexity Rewards: {complexity_rewards}")

    print("\n--- Testing Reasoning Quality Reward ---")
    reasoning_rewards = reasoning_quality_reward(
        prompts_example, completions_example, references_example)
    print(f"Reasoning Quality Rewards: {reasoning_rewards}")