├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── datatype ├── models-cli ├── Cargo.toml ├── README.md └── src │ ├── bin │ └── models.rs │ ├── database.rs │ ├── generate.rs │ ├── lib.rs │ ├── migrate.rs │ ├── migration.rs │ └── opt.rs ├── models-parser ├── Cargo.toml ├── HEADER ├── LICENSE.TXT ├── README.md ├── src │ ├── ast │ │ ├── data_type.rs │ │ ├── ddl.rs │ │ ├── expression │ │ │ ├── display.rs │ │ │ └── mod.rs │ │ ├── mod.rs │ │ ├── operator.rs │ │ ├── query.rs │ │ ├── statement │ │ │ ├── display.rs │ │ │ └── mod.rs │ │ └── value.rs │ ├── dialect │ │ ├── ansi.rs │ │ ├── generic.rs │ │ ├── hive.rs │ │ ├── keywords.rs │ │ ├── mod.rs │ │ ├── mssql.rs │ │ ├── mysql.rs │ │ ├── postgresql.rs │ │ ├── snowflake.rs │ │ └── sqlite.rs │ ├── lib.rs │ ├── parser.rs │ ├── test_utils.rs │ └── tokenizer.rs └── tests │ ├── common.rs │ ├── hive.rs │ ├── mssql.rs │ ├── mysql.rs │ ├── postgres.rs │ ├── queries │ └── tpch │ │ ├── 1.sql │ │ ├── 10.sql │ │ ├── 11.sql │ │ ├── 12.sql │ │ ├── 13.sql │ │ ├── 14.sql │ │ ├── 15.sql │ │ ├── 16.sql │ │ ├── 17.sql │ │ ├── 18.sql │ │ ├── 19.sql │ │ ├── 2.sql │ │ ├── 20.sql │ │ ├── 21.sql │ │ ├── 22.sql │ │ ├── 3.sql │ │ ├── 4.sql │ │ ├── 5.sql │ │ ├── 6.sql │ │ ├── 7.sql │ │ ├── 8.sql │ │ └── 9.sql │ ├── regression.rs │ ├── snowflake.rs │ ├── sqlite.rs │ └── test_utils │ └── mod.rs ├── models-proc-macro ├── Cargo.toml └── src │ ├── getters.rs │ ├── lib.rs │ ├── migration_generation.rs │ ├── model │ ├── column │ │ ├── default.rs │ │ └── mod.rs │ ├── constraint.rs │ └── mod.rs │ └── prelude.rs └── models ├── Cargo.toml └── src ├── dialect.rs ├── error.rs ├── lib.rs ├── postgres.rs ├── prelude.rs ├── private ├── mod.rs └── scheduler │ ├── driver │ ├── actions │ │ ├── action │ │ │ ├── mod.rs │ │ │ └── temp_move.rs │ │ ├── compare.rs │ │ ├── crud.rs │ │ ├── inner.rs │ │ └── mod.rs │ ├── migration.rs │ ├── mod.rs │ ├── queue │ │ ├── mod.rs │ │ └── sorter.rs │ ├── report.rs │ └── schema.rs │ ├── mod.rs │ └── table │ ├── column.rs │ ├── 
constraint.rs │ └── mod.rs ├── rusqlite.rs ├── sqlx.rs ├── tests └── mod.rs ├── tokio_postgres.rs └── types ├── bytes.rs ├── chrono_impl.rs ├── json.rs ├── mod.rs ├── serial.rs ├── time.rs ├── var_binary.rs └── var_char.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | */.* 4 | models-tests/ 5 | *.DS_Store -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | 3 | members = [ 4 | "models", 5 | "models-cli", 6 | "models-parser", 7 | "models-tests", 8 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | Models is a SQL migration management tool. It supports PostgreSQL, MySQL, and SQLite. 3 | 4 | 5 | # Quick Start 6 | 7 | Install the CLI by running the following command: 8 | ``` 9 | $ cargo install models-cli 10 | ``` 11 | 12 | Now run the following command to create an environment file with the `DATABASE_URL` variable set: 13 | ``` 14 | $ echo "DATABASE_URL=sqlite://database.db" > .env 15 | ``` 16 | Alternatively it can be set as an environment variable with the following command: 17 | ``` 18 | $ export DATABASE_URL=sqlite://database.db 19 | ``` 20 | We can now create the database by running the following command: 21 | ``` 22 | $ models database create 23 | ``` 24 | This command will have created an SQLite file called `database.db`. 25 | You can now derive the `Model` trait on your structures, 26 | and `models` will manage the migrations for you. 
For example, write at `src/main.rs`: 27 | ```rust 28 | #![allow(dead_code)] 29 | use models::Model; 30 | 31 | #[derive(Model)] 32 | struct Profile { 33 | #[primary_key] 34 | id: i32, 35 | #[unique] 36 | email: String, 37 | password: String, 38 | is_admin: bool, 39 | } 40 | 41 | #[derive(Model)] 42 | struct Post { 43 | #[primary_key] 44 | id: i32, 45 | #[foreign_key(Profile.id)] 46 | author: i32, 47 | #[default("")] 48 | title: String, 49 | content: String, 50 | } 51 | 52 | #[derive(Model)] 53 | struct PostLike { 54 | #[foreign_key(Profile.id, on_delete="cascade")] 55 | #[primary_key(post_id)] 56 | profile_id: i32, 57 | #[foreign_key(Post.id, on_delete="cascade")] 58 | post_id: i32, 59 | } 60 | 61 | #[derive(Model)] 62 | struct CommentLike { 63 | #[foreign_key(Profile.id)] 64 | #[primary_key(comment_id)] 65 | profile_id: i32, 66 | #[foreign_key(Comment.id)] 67 | comment_id: i32, 68 | is_dislike: bool, 69 | } 70 | 71 | #[derive(Model)] 72 | struct Comment { 73 | #[primary_key] 74 | id: i32, 75 | #[foreign_key(Profile.id)] 76 | author: i32, 77 | #[foreign_key(Post.id)] 78 | post: i32, 79 | } 80 | fn main() {} 81 | ``` 82 | 83 | If you now run the following command, your migrations should be automatically created. 84 | ``` 85 | $ models generate 86 | ``` 87 | The output should look like this: 88 | ``` 89 | Generated: migrations/1632280793452 profile 90 | Generated: migrations/1632280793459 post 91 | Generated: migrations/1632280793465 postlike 92 | Generated: migrations/1632280793471 comment 93 | Generated: migrations/1632280793476 commentlike 94 | ``` 95 | You can check out the generated migrations at the `migrations/` folder. 
96 | To execute these migrations you can execute the following command: 97 | ``` 98 | models migrate run 99 | ``` 100 | The output should look like this: 101 | ``` 102 | Applied 1631716729974/migrate profile (342.208µs) 103 | Applied 1631716729980/migrate post (255.958µs) 104 | Applied 1631716729986/migrate comment (287.792µs) 105 | Applied 1631716729993/migrate postlike (349.834µs) 106 | Applied 1631716729998/migrate commentlike (374.625µs) 107 | ``` 108 | If we later modify those structures in our application, we can generate new migrations to update the tables. 109 | 110 | ## Reverting migration 111 | Models can generate down migrations with the `-r` flag. Note that simple and reversible migrations cannot be mixed: 112 | ``` 113 | $ models generate -r 114 | ``` 115 | In order to revert the last migration executed you can run: 116 | ``` 117 | $ models migrate revert 118 | ``` 119 | If you later want to see which migrations are yet to be applied you can also execute: 120 | ``` 121 | $ models migrate info 122 | ``` 123 | Applied migrations need to be reverted before they can be deleted. 124 | ## Available Attributes 125 | ### primary_key 126 | It's used to mark the primary key of the table. 127 | ```rust 128 | #[primary_key] 129 | id: i32, 130 | ``` 131 | for tables with multicolumn primary keys, the following syntax is used: 132 | ```rust 133 | #[primary_key(second_id)] 134 | first_id: i32, 135 | second_id: i32, 136 | ``` 137 | This is equivalent to: 138 | ```sql 139 | PRIMARY KEY (first_id, second_id), 140 | ``` 141 | 142 | ### foreign_key 143 | It is used to mark a foreign key constraint. 
144 | ```rust 145 | #[foreign_key(Profile.id)] 146 | profile: i32, 147 | ``` 148 | It can also specify `on_delete` and `on_update` constraints: 149 | ```rust 150 | #[foreign_key(Profile.id, on_delete="cascade")] 151 | profile_id: i32, 152 | ``` 153 | This is equivalent to: 154 | ```sql 155 | FOREIGN KEY (profile_id) REFERENCES profile (id) ON DELETE CASCADE, 156 | ``` 157 | ### default 158 | It can be used to set a default value for a column. 159 | ```rust 160 | #[default(false)] // when using SQLite use 0 or 1 161 | is_admin: bool, 162 | #[default("")] 163 | text: String, 164 | #[default(0)] 165 | number: i32, 166 | ``` 167 | 168 | ### unique 169 | It is used to mark a unique constraint. 170 | ```rust 171 | #[unique] 172 | email: String, 173 | ``` 174 | For multicolumn unique constraints the following syntax is used: 175 | ```rust 176 | #[unique(post_id)] 177 | profile_id: String, 178 | post_id: i32, 179 | ``` 180 | This is equivalent to: 181 | ```sql 182 | UNIQUE (profile_id, post_id), 183 | ``` 184 | ## CLI Short cuts 185 | The CLI includes the following shortcuts: 186 | * `models database` -> `models db` 187 | * `models generate` -> `models gen` 188 | * `models migrate` -> `models mig` 189 | -------------------------------------------------------------------------------- /datatype: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tvallotton/models/87092ddd62492e8c5aa6be5a07f9bcfbc1b9ed84/datatype -------------------------------------------------------------------------------- /models-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "models-cli" 3 | version = "0.1.1" 4 | description = "Command-line utility for SQLx, the Rust SQL toolkit." 
5 | edition = "2018" 6 | readme = "README.md" 7 | 8 | 9 | keywords = ["database", "postgres", "database-management", "migration"] 10 | categories = ["database", "command-line-utilities"] 11 | license = "MIT OR Apache-2.0" 12 | default-run = "models" 13 | authors = [ 14 | "Jesper Axelsson ", 15 | "Austin Bonander ", 16 | ] 17 | 18 | [[bin]] 19 | name = "models" 20 | path = "src/bin/models.rs" 21 | 22 | 23 | 24 | [dependencies] 25 | dotenv = "0.15" 26 | tokio = { version = "1.0.1", features = ["macros", "rt", "rt-multi-thread", "fs", "process", "io-std"] } 27 | sqlx = { version = "0.5.9", default-features = false, features = [ 28 | "runtime-async-std-native-tls", 29 | "migrate", 30 | "any", 31 | "offline", 32 | ] } 33 | futures = "0.3" 34 | # FIXME: we need to fix both of these versions until Clap 3.0 proper is released, then we can drop `clap_derive` 35 | # https://github.com/launchbadge/sqlx/issues/1378 36 | # https://github.com/clap-rs/clap/issues/2705 37 | chrono = "0.4" 38 | anyhow = "1.0" 39 | url = { version = "2.1.1", default-features = false } 40 | async-trait = "0.1.30" 41 | console = "0.14.1" 42 | promptly = "0.3.0" 43 | serde_json = "1.0.68" 44 | serde = { version = "1.0.130", features = ["derive"] } 45 | glob = "0.3.0" 46 | openssl = { version = "0.10.30", optional = true } 47 | # workaround for https://github.com/rust-lang/rust/issues/29497 48 | remove_dir_all = "0.7.0" 49 | regex = "1.5.4" 50 | structopt = "0.3.23" 51 | clap = "2.33.3" 52 | 53 | [features] 54 | default = ["postgres", "sqlite", "mysql"] 55 | 56 | # databases 57 | mysql = ["sqlx/mysql"] 58 | postgres = ["sqlx/postgres"] 59 | sqlite = ["sqlx/sqlite"] 60 | 61 | # workaround for musl + openssl issues 62 | openssl-vendored = ["openssl/vendored"] 63 | -------------------------------------------------------------------------------- /models-cli/README.md: -------------------------------------------------------------------------------- 1 | # Models CLI 2 | 3 | ## Installation 4 | To install the 
CLI use the following command: 5 | ``` 6 | $ cargo install models-cli 7 | ``` 8 | 9 | ## Usage 10 | There are three main commands: `database`, `generate` and `migrate`. 11 | 12 | ### database 13 | It can be abbreviated as `db`. It includes the subcommands: 14 | * `create`: Creates the database specified in your DATABASE_URL. 15 | * `drop`: Drops the database specified in your DATABASE_URL. 16 | * `reset`: Drops the database specified in your DATABASE_URL, re-creates it, and runs any pending migrations. 17 | * `setup`: Creates the database specified in your DATABASE_URL and runs any pending migrations. 18 | 19 | ### generate 20 | It is used to generate migrations. It can be used to generate down migrations as well if the `-r` flag is enabled. 21 | The `--source` variable can be used to specify the migrations directory. 22 | The `--table` variable can be used to filter the names of the tables to target in the generation. 23 | 24 | ### migrate 25 | * `add`: Create a new migration with the given description, and the current time as the version. 26 | * `info`: List all available migrations and their status. 27 | * `revert`: Revert the latest migration with a down file. 28 | * `run`: Run all pending migrations. 
-------------------------------------------------------------------------------- /models-cli/src/bin/models.rs: -------------------------------------------------------------------------------- 1 | use console::style; 2 | use dotenv::dotenv; 3 | use models_cli::Opt; 4 | use structopt::StructOpt; 5 | 6 | #[tokio::main] 7 | async fn main() { 8 | dotenv().ok(); 9 | 10 | // no special handling here 11 | if let Err(error) = models_cli::run(Opt::from_args()).await { 12 | println!("{}: {}", style("error").bold().red(), error); 13 | std::process::exit(1); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /models-cli/src/database.rs: -------------------------------------------------------------------------------- 1 | use crate::migrate; 2 | use console::style; 3 | use promptly::{prompt, ReadlineError}; 4 | use sqlx::any::Any; 5 | use sqlx::migrate::MigrateDatabase; 6 | 7 | pub async fn create(uri: &str) -> anyhow::Result<()> { 8 | if !Any::database_exists(uri).await? { 9 | Any::create_database(uri).await?; 10 | } 11 | 12 | Ok(()) 13 | } 14 | 15 | pub async fn drop(uri: &str, confirm: bool) -> anyhow::Result<()> { 16 | if confirm && !ask_to_continue(uri) { 17 | return Ok(()); 18 | } 19 | 20 | if Any::database_exists(uri).await? { 21 | Any::drop_database(uri).await?; 22 | } 23 | 24 | Ok(()) 25 | } 26 | 27 | pub async fn reset(migration_source: &str, uri: &str, confirm: bool) -> anyhow::Result<()> { 28 | drop(uri, confirm).await?; 29 | setup(migration_source, uri).await 30 | } 31 | 32 | pub async fn setup(migration_source: &str, uri: &str) -> anyhow::Result<()> { 33 | create(uri).await?; 34 | migrate::run(migration_source, uri, false, false).await 35 | } 36 | 37 | fn ask_to_continue(uri: &str) -> bool { 38 | loop { 39 | let r: Result = 40 | prompt(format!("Drop database at {}? 
(y/n)", style(uri).cyan())); 41 | match r { 42 | Ok(response) => { 43 | if response == "n" || response == "N" { 44 | return false; 45 | } else if response == "y" || response == "Y" { 46 | return true; 47 | } else { 48 | println!( 49 | "Response not recognized: {}\nPlease type 'y' or 'n' and press enter.", 50 | response 51 | ); 52 | } 53 | } 54 | Err(e) => { 55 | println!("{}", e); 56 | return false; 57 | } 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /models-cli/src/generate.rs: -------------------------------------------------------------------------------- 1 | use super::opt::GenerateOpt; 2 | use anyhow::{Error, Result}; 3 | use console::style; 4 | use serde::*; 5 | use serde_json::from_str; 6 | 7 | #[derive(Serialize, Deserialize)] 8 | 9 | struct MigrationError { 10 | kind: String, 11 | message: String, 12 | } 13 | #[derive(Serialize, Deserialize)] 14 | struct Output { 15 | success: Vec<(i64, String)>, 16 | error: Option, 17 | } 18 | 19 | impl Output { 20 | fn print(self, source: &str) -> Result<()> { 21 | for (num, name) in self.success { 22 | println!( 23 | "{}: {}/{}{}{}", 24 | style("Generated").bold().green(), 25 | style(source), 26 | style(num).cyan(), 27 | style("_").dim(), 28 | style(name) 29 | ) 30 | } 31 | 32 | if let Some(err) = self.error { 33 | Err(Error::msg(err.message)) 34 | } else { 35 | Ok(()) 36 | } 37 | } 38 | } 39 | 40 | pub async fn generate(opt: GenerateOpt) -> Result<()> { 41 | use anyhow::*; 42 | std::fs::create_dir_all(&opt.source).context("Unable to create migrations directory")?; 43 | opt.validate().await?; 44 | touch_any().await.ok(); 45 | 46 | if !builds(&opt.database_url, &opt.source).await { 47 | return Err(Error::msg( 48 | "could not compile project. 
No migrations were generated.", 49 | )); 50 | } 51 | let filter_tests = format!( 52 | "__models_generate_migration_{}", 53 | opt.table.as_deref().unwrap_or("") 54 | ); 55 | let output = tokio::process::Command::new("cargo") 56 | .arg("test") 57 | .arg("--") 58 | .arg("--nocapture") 59 | .arg(&filter_tests) 60 | .env("MODELS_GENERATE_MIGRATIONS", "true") 61 | .env("MIGRATIONS_DIR", &opt.source) 62 | .env("DATABASE_URL", &opt.database_url) 63 | .env("MODELS_GENERATE_DOWN", opt.reversible.to_string()) 64 | .output() 65 | .await 66 | .unwrap() 67 | .stdout; 68 | let output = String::from_utf8(output).unwrap(); 69 | let regex = regex::Regex::new("(.+)").unwrap(); 70 | 71 | if output.contains("running 0 tests") { 72 | if let Some(table) = &opt.table { 73 | println!("No models named {}.", table) 74 | } else { 75 | println!("No models in the application") 76 | } 77 | return Ok(()); 78 | } 79 | let x = regex.captures(&output).expect(&output); 80 | 81 | if let Some(json) = x.get(1) { 82 | from_str::(json.as_str()) 83 | .expect(json.as_str()) 84 | .print(&opt.source)?; 85 | } else { 86 | println!("Everything is up to date."); 87 | } 88 | touch_any().await.ok(); 89 | Ok(()) 90 | } 91 | 92 | async fn builds(database_url: &str, source: &str) -> bool { 93 | tokio::process::Command::new("cargo") 94 | .arg("build") 95 | .arg("--tests") 96 | .env("MODELS_GENERATE_MIGRATIONS", "true") 97 | .env("MIGRATIONS_DIR", database_url) 98 | .env("DATABASE_URL", source) 99 | .spawn() 100 | .unwrap() 101 | .wait() 102 | .await 103 | .unwrap() 104 | .success() 105 | } 106 | 107 | pub async fn touch_any() -> Result<()> { 108 | let mut listdir = tokio::fs::read_dir("src/").await?; 109 | while let Some(entry) = listdir.next_entry().await? 
{ 110 | let file_name = entry.file_name(); 111 | let regex = regex::Regex::new(r".+\.rs")?; 112 | if regex.is_match(file_name.to_str().unwrap()) { 113 | // println!("{}", format!("src/{}", file_name.to_str().unwrap())); 114 | let success = tokio::process::Command::new("touch") 115 | .arg(&format!("src/{}", file_name.to_str().unwrap())) 116 | .spawn()? 117 | .wait() 118 | .await? 119 | .success(); 120 | if success { 121 | break; 122 | } 123 | } 124 | } 125 | Ok(()) 126 | } 127 | -------------------------------------------------------------------------------- /models-cli/src/lib.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use generate::generate; 3 | 4 | use crate::opt::{Command, DatabaseCommand, MigrateCommand}; 5 | 6 | mod database; 7 | 8 | mod generate; 9 | mod migrate; 10 | mod opt; 11 | 12 | pub use crate::opt::Opt; 13 | 14 | pub async fn run(opt: Opt) -> Result<()> { 15 | match opt.command { 16 | Command::Migrate(migrate) => match migrate.command { 17 | MigrateCommand::Add { 18 | description, 19 | reversible, 20 | } => migrate::add(&migrate.source, &description, reversible).await?, 21 | MigrateCommand::Run { 22 | dry_run, 23 | ignore_missing, 24 | database_url, 25 | } => migrate::run(&migrate.source, &database_url, dry_run, ignore_missing).await?, 26 | MigrateCommand::Revert { 27 | dry_run, 28 | ignore_missing, 29 | database_url, 30 | } => migrate::revert(&migrate.source, &database_url, dry_run, ignore_missing).await?, 31 | MigrateCommand::Info { database_url } => { 32 | migrate::info(&migrate.source, &database_url).await? 
33 | } 34 | MigrateCommand::BuildScript { force } => migrate::build_script(&migrate.source, force)?, 35 | }, 36 | Command::Generate(gen_opt) => generate(gen_opt).await?, 37 | Command::Database(database) => match database.command { 38 | DatabaseCommand::Create { database_url } => database::create(&database_url).await?, 39 | DatabaseCommand::Drop { yes, database_url } => { 40 | database::drop(&database_url, !yes).await? 41 | } 42 | DatabaseCommand::Reset { 43 | yes, 44 | source, 45 | database_url, 46 | } => database::reset(&source, &database_url, !yes).await?, 47 | DatabaseCommand::Setup { 48 | source, 49 | database_url, 50 | } => database::setup(&source, &database_url).await?, 51 | }, 52 | }; 53 | 54 | Ok(()) 55 | } 56 | -------------------------------------------------------------------------------- /models-cli/src/migration.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{bail, Context}; 2 | use console::style; 3 | use std::fs::{self, File}; 4 | use std::io::{Read, Write}; 5 | 6 | const MIGRATION_FOLDER: &str = "migrations"; 7 | 8 | pub struct Migration { 9 | pub name: String, 10 | pub sql: String, 11 | } 12 | 13 | pub fn add_file(name: &str) -> anyhow::Result<()> { 14 | use chrono::prelude::*; 15 | use std::path::PathBuf; 16 | 17 | fs::create_dir_all(MIGRATION_FOLDER).context("Unable to create migrations directory")?; 18 | 19 | let dt = Utc::now(); 20 | let mut file_name = dt.format("%Y-%m-%d_%H-%M-%S").to_string(); 21 | file_name.push_str("_"); 22 | file_name.push_str(name); 23 | file_name.push_str(".sql"); 24 | 25 | let mut path = PathBuf::new(); 26 | path.push(MIGRATION_FOLDER); 27 | path.push(&file_name); 28 | 29 | let mut file = File::create(path).context("Failed to create file")?; 30 | file.write_all(b"-- Add migration script here") 31 | .context("Could not write to file")?; 32 | 33 | println!("Created migration: '{}'", file_name); 34 | Ok(()) 35 | } 36 | 37 | pub async fn run() -> anyhow::Result<()> { 38 
| let migrator = crate::migrator::get()?; 39 | 40 | if !migrator.can_migrate_database() { 41 | bail!( 42 | "Database migrations not supported for {}", 43 | migrator.database_type() 44 | ); 45 | } 46 | 47 | migrator.create_migration_table().await?; 48 | 49 | let migrations = load_migrations()?; 50 | 51 | for mig in migrations.iter() { 52 | let mut tx = migrator.begin_migration().await?; 53 | 54 | if tx.check_if_applied(&mig.name).await? { 55 | println!("Already applied migration: '{}'", mig.name); 56 | continue; 57 | } 58 | println!("Applying migration: '{}'", mig.name); 59 | 60 | tx.execute_migration(&mig.sql) 61 | .await 62 | .with_context(|| format!("Failed to run migration {:?}", &mig.name))?; 63 | 64 | tx.save_applied_migration(&mig.name) 65 | .await 66 | .context("Failed to insert migration")?; 67 | 68 | tx.commit().await.context("Failed")?; 69 | } 70 | 71 | Ok(()) 72 | } 73 | 74 | pub async fn list() -> anyhow::Result<()> { 75 | let migrator = crate::migrator::get()?; 76 | 77 | if !migrator.can_migrate_database() { 78 | bail!( 79 | "Database migrations not supported for {}", 80 | migrator.database_type() 81 | ); 82 | } 83 | 84 | let file_migrations = load_migrations()?; 85 | 86 | if migrator 87 | .check_if_database_exists(&migrator.get_database_name()?) 88 | .await? 
89 | { 90 | let applied_migrations = migrator.get_migrations().await.unwrap_or_else(|_| { 91 | println!("Could not retrive data from migration table"); 92 | Vec::new() 93 | }); 94 | 95 | let mut width = 0; 96 | for mig in file_migrations.iter() { 97 | width = std::cmp::max(width, mig.name.len()); 98 | } 99 | for mig in file_migrations.iter() { 100 | let status = if applied_migrations 101 | .iter() 102 | .find(|&m| mig.name == *m) 103 | .is_some() 104 | { 105 | style("Applied").green() 106 | } else { 107 | style("Not Applied").yellow() 108 | }; 109 | 110 | println!("{:width$}\t{}", mig.name, status, width = width); 111 | } 112 | 113 | let orphans = check_for_orphans(file_migrations, applied_migrations); 114 | 115 | if let Some(orphans) = orphans { 116 | println!("\nFound migrations applied in the database that does not have a corresponding migration file:"); 117 | for name in orphans { 118 | println!("{:width$}\t{}", name, style("Orphan").red(), width = width); 119 | } 120 | } 121 | } else { 122 | println!("No database found, listing migrations"); 123 | 124 | for mig in file_migrations { 125 | println!("{}", mig.name); 126 | } 127 | } 128 | 129 | Ok(()) 130 | } 131 | 132 | fn load_migrations() -> anyhow::Result> { 133 | let entries = fs::read_dir(&MIGRATION_FOLDER).context("Could not find 'migrations' dir")?; 134 | 135 | let mut migrations = Vec::new(); 136 | 137 | for e in entries { 138 | if let Ok(e) = e { 139 | if let Ok(meta) = e.metadata() { 140 | if !meta.is_file() { 141 | continue; 142 | } 143 | 144 | if let Some(ext) = e.path().extension() { 145 | if ext != "sql" { 146 | println!("Wrong ext: {:?}", ext); 147 | continue; 148 | } 149 | } else { 150 | continue; 151 | } 152 | 153 | let mut file = File::open(e.path()) 154 | .with_context(|| format!("Failed to open: '{:?}'", e.file_name()))?; 155 | let mut contents = String::new(); 156 | file.read_to_string(&mut contents) 157 | .with_context(|| format!("Failed to read: '{:?}'", e.file_name()))?; 158 | 159 | 
migrations.push(Migration { 160 | name: e.file_name().to_str().unwrap().to_string(), 161 | sql: contents, 162 | }); 163 | } 164 | } 165 | } 166 | 167 | migrations.sort_by(|a, b| a.name.partial_cmp(&b.name).unwrap()); 168 | 169 | Ok(migrations) 170 | } 171 | 172 | fn check_for_orphans( 173 | file_migrations: Vec, 174 | applied_migrations: Vec, 175 | ) -> Option> { 176 | let orphans: Vec = applied_migrations 177 | .iter() 178 | .filter(|m| !file_migrations.iter().any(|fm| fm.name == **m)) 179 | .cloned() 180 | .collect(); 181 | 182 | if orphans.len() > 0 { 183 | Some(orphans) 184 | } else { 185 | None 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /models-cli/src/opt.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use sqlx::migrate::{MigrateError, Migrator}; 3 | use std::path::Path; 4 | use structopt::StructOpt; 5 | #[derive(StructOpt, Debug)] 6 | pub struct Opt { 7 | #[structopt(subcommand)] 8 | pub command: Command, 9 | } 10 | 11 | #[derive(StructOpt, Debug)] 12 | pub enum Command { 13 | #[structopt(alias = "db")] 14 | Database(DatabaseOpt), 15 | 16 | #[structopt(alias = "mig")] 17 | Migrate(MigrateOpt), 18 | 19 | #[structopt(alias = "gen")] 20 | Generate(GenerateOpt), 21 | } 22 | 23 | /// Group of commands for creating and dropping your database. 24 | #[derive(StructOpt, Debug)] 25 | pub struct DatabaseOpt { 26 | #[structopt(subcommand)] 27 | pub command: DatabaseCommand, 28 | } 29 | 30 | #[derive(StructOpt, Debug)] 31 | pub enum DatabaseCommand { 32 | /// Creates the database specified in your DATABASE_URL. 33 | Create { 34 | /// Location of the DB, by default will be read from the DATABASE_URL env var 35 | #[structopt(long, short = "D", env)] 36 | database_url: String, 37 | }, 38 | 39 | /// Drops the database specified in your DATABASE_URL. 40 | Drop { 41 | /// Automatic confirmation. 
Without this option, you will be prompted before dropping 42 | /// your database. 43 | #[structopt(short)] 44 | yes: bool, 45 | 46 | /// Location of the DB, by default will be read from the DATABASE_URL env var 47 | #[structopt(long, short = "D", env)] 48 | database_url: String, 49 | }, 50 | 51 | /// Drops the database specified in your DATABASE_URL, re-creates it, and runs any pending migrations. 52 | Reset { 53 | /// Automatic confirmation. Without this option, you will be prompted before dropping 54 | /// your database. 55 | #[structopt(short)] 56 | yes: bool, 57 | 58 | /// Path to folder containing migrations. 59 | #[structopt(long, default_value = "migrations")] 60 | source: String, 61 | 62 | /// Location of the DB, by default will be read from the DATABASE_URL env var 63 | #[structopt(long, short = "D", env)] 64 | database_url: String, 65 | }, 66 | 67 | /// Creates the database specified in your DATABASE_URL and runs any pending migrations. 68 | Setup { 69 | /// Path to folder containing migrations. 70 | #[structopt(long, default_value = "migrations")] 71 | source: String, 72 | 73 | /// Location of the DB, by default will be read from the DATABASE_URL env var 74 | #[structopt(long, short = "D", env)] 75 | database_url: String, 76 | }, 77 | } 78 | 79 | /// Group of commands for creating and running migrations. 80 | #[derive(StructOpt, Debug)] 81 | pub struct MigrateOpt { 82 | /// Path to folder containing migrations. 83 | #[structopt(long, default_value = "migrations")] 84 | pub source: String, 85 | 86 | #[structopt(subcommand)] 87 | pub command: MigrateCommand, 88 | } 89 | /// Commands related to automatic migration generation. 90 | #[derive(StructOpt, Debug)] 91 | pub struct GenerateOpt { 92 | /// Location of the DB, by default will be read from the DATABASE_URL env var 93 | #[structopt(long, short = "D", env)] 94 | pub database_url: String, 95 | /// Path to folder containing migrations. 
96 | #[structopt(long, default_value = "migrations")] 97 | pub source: String, 98 | /// Used to filter through the models to execute. 99 | #[structopt(long)] 100 | pub table: Option, 101 | /// Used to generate a down migrations along with up migrations. 102 | #[structopt(short)] 103 | pub reversible: bool, 104 | } 105 | 106 | impl GenerateOpt { 107 | pub async fn validate(&self) -> Result<()> { 108 | url::Url::parse(&self.database_url)?; 109 | let migrator = Migrator::new(Path::new(&self.source)).await?; 110 | for migration in migrator.iter() { 111 | if migration.migration_type.is_reversible() != self.reversible { 112 | Err(MigrateError::InvalidMixReversibleAndSimple)? 113 | } 114 | } 115 | 116 | Ok(()) 117 | } 118 | } 119 | 120 | #[derive(StructOpt, Debug)] 121 | pub enum MigrateCommand { 122 | /// Create a new migration with the given description, 123 | /// and the current time as the version. 124 | Add { 125 | description: String, 126 | 127 | /// If true, creates a pair of up and down migration files with same version 128 | /// else creates a single sql file 129 | #[structopt(short)] 130 | reversible: bool, 131 | }, 132 | 133 | /// Run all pending migrations. 134 | Run { 135 | /// List all the migrations to be run without applying 136 | #[structopt(long)] 137 | dry_run: bool, 138 | 139 | /// Ignore applied migrations that missing in the resolved migrations 140 | #[structopt(long)] 141 | ignore_missing: bool, 142 | 143 | /// Location of the DB, by default will be read from the DATABASE_URL env var 144 | #[structopt(long, short = "D", env)] 145 | database_url: String, 146 | }, 147 | 148 | /// Revert the latest migration with a down file. 
149 | Revert { 150 | /// List the migration to be reverted without applying 151 | #[structopt(long)] 152 | dry_run: bool, 153 | 154 | /// Ignore applied migrations that missing in the resolved migrations 155 | #[structopt(long)] 156 | ignore_missing: bool, 157 | 158 | /// Location of the DB, by default will be read from the DATABASE_URL env var 159 | #[structopt(long, short = "D", env)] 160 | database_url: String, 161 | }, 162 | 163 | /// List all available migrations. 164 | Info { 165 | /// Location of the DB, by default will be read from the DATABASE_URL env var 166 | #[structopt(long, env)] 167 | database_url: String, 168 | }, 169 | 170 | /// Generate a `build.rs` to trigger recompilation when a new migration is added. 171 | /// 172 | /// Must be run in a Cargo project root. 173 | BuildScript { 174 | /// Overwrite the build script if it already exists. 175 | #[structopt(long)] 176 | force: bool, 177 | }, 178 | } 179 | -------------------------------------------------------------------------------- /models-parser/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "models-parser" 3 | description = "Helper crate for models" 4 | version = "0.2.0" 5 | authors = ["Andy Grove "] 6 | repository = "https://github.com/tvallotton/sqlx-models" 7 | license = "Apache-2.0" 8 | edition = "2018" 9 | 10 | 11 | 12 | [features] 13 | default = ["std"] 14 | std = [] 15 | # Enable JSON output in the `cli` example: 16 | json_example = ["serde_json", "serde"] 17 | 18 | [dependencies] 19 | bigdecimal = { version = "0.3", features = ["serde"], optional = true } 20 | log = "0.4" 21 | serde = { version = "1.0", features = ["derive"], optional = true } 22 | # serde_json is only used in examples/cli, but we have to put it outside 23 | # of dev-dependencies because of 24 | # https://github.com/rust-lang/cargo/issues/1596 25 | serde_json = { version = "1.0", optional = true } 26 | 27 | [dev-dependencies] 28 | simple_logger = 
"1.9" 29 | matches = "0.1" 30 | 31 | [package.metadata.release] 32 | # Instruct `cargo release` to not run `cargo publish` locally: 33 | # https://github.com/sunng87/cargo-release/blob/master/docs/reference.md#config-fields 34 | # See docs/releasing.md for details. 35 | disable-publish = true 36 | -------------------------------------------------------------------------------- /models-parser/HEADER: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. -------------------------------------------------------------------------------- /models-parser/README.md: -------------------------------------------------------------------------------- 1 | This is a helper crate for `models`. Don't directly depend on it. -------------------------------------------------------------------------------- /models-parser/src/ast/data_type.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | #[cfg(not(feature = "std"))] 14 | use alloc::boxed::Box; 15 | use core::fmt; 16 | 17 | #[cfg(feature = "serde")] 18 | use serde::{Deserialize, Serialize}; 19 | 20 | use crate::ast::ObjectName; 21 | 22 | /// SQL data types 23 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 24 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 25 | pub enum DataType { 26 | /// Fixed-length character type e.g. CHAR(10) 27 | Char(Option), 28 | /// Variable-length character type e.g. VARCHAR(10) 29 | Varchar(Option), 30 | /// Uuid type 31 | Uuid, 32 | /// Large character object e.g. CLOB(1000) 33 | Clob(u64), 34 | /// Fixed-length binary type e.g. BINARY(10) 35 | Binary(u64), 36 | /// Variable-length binary type e.g. VARBINARY(10) 37 | Varbinary(Option), 38 | /// Large binary object e.g. BLOB(1000) 39 | Blob(Option), 40 | /// Decimal type with optional precision and scale e.g. DECIMAL(10,2) 41 | Decimal(Option, Option), 42 | /// Floating point with optional precision e.g. FLOAT(8) 43 | Float(Option), 44 | /// Tiny integer with optional display width e.g. TINYINT or TINYINT(3) 45 | TinyInt(Option), 46 | /// Small integer with optional display width e.g. SMALLINT or SMALLINT(5) 47 | SmallInt(Option), 48 | /// INT with optional display width e.g. INT or INT(11) 49 | Int(Option), 50 | /// Big integer with optional display width e.g. BIGINT or BIGINT(20) 51 | BigInt(Option), 52 | /// Floating point e.g. REAL 53 | Real, 54 | /// Double e.g. 
DOUBLE PRECISION 55 | Double, 56 | /// Boolean 57 | Boolean, 58 | /// Date 59 | Date, 60 | /// Time 61 | Time, 62 | /// Timestamp 63 | Timestamp, 64 | /// Interval 65 | Interval, 66 | /// Regclass used in postgresql serial 67 | Regclass, 68 | /// Text 69 | Text, 70 | /// String 71 | String, 72 | /// Bytea 73 | Bytea, 74 | /// Custom type such as enums 75 | Custom(ObjectName), 76 | /// Arrays 77 | Array(Box), 78 | /// JSON 79 | Json, 80 | /// Serial PostgeSQL type 81 | Serial, 82 | } 83 | 84 | impl fmt::Display for DataType { 85 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 86 | match self { 87 | DataType::Serial => write!(f, "SERIAL"), 88 | DataType::Json => write!(f, "JSON"), 89 | DataType::Char(size) => format_type_with_optional_length(f, "CHAR", size), 90 | DataType::Varchar(size) => format_type_with_optional_length(f, "VARCHAR", size), 91 | DataType::Uuid => write!(f, "UUID"), 92 | DataType::Clob(size) => write!(f, "CLOB({})", size), 93 | DataType::Binary(size) => write!(f, "BINARY({})", size), 94 | DataType::Varbinary(size) => format_type_with_optional_length(f, "VARBINARY", size), 95 | DataType::Blob(size) => { 96 | if let Some(size) = size { 97 | write!(f, "BLOB({})", size) 98 | } else { 99 | write!(f, "BLOB") 100 | } 101 | } 102 | DataType::Decimal(precision, scale) => { 103 | if let Some(scale) = scale { 104 | write!(f, "NUMERIC({},{})", precision.unwrap(), scale) 105 | } else { 106 | format_type_with_optional_length(f, "NUMERIC", precision) 107 | } 108 | } 109 | DataType::Float(size) => format_type_with_optional_length(f, "FLOAT", size), 110 | DataType::TinyInt(zerofill) => format_type_with_optional_length(f, "TINYINT", zerofill), 111 | DataType::SmallInt(zerofill) => { 112 | format_type_with_optional_length(f, "SMALLINT", zerofill) 113 | } 114 | DataType::Int(zerofill) => { 115 | if let Some(len) = zerofill { 116 | write!(f, "INT({})", len) 117 | } else { 118 | write!(f, "INTEGER") 119 | } 120 | } 121 | DataType::BigInt(zerofill) => 
format_type_with_optional_length(f, "BIGINT", zerofill), 122 | DataType::Real => write!(f, "REAL"), 123 | DataType::Double => write!(f, "DOUBLE PRECISION"), 124 | DataType::Boolean => write!(f, "BOOLEAN"), 125 | DataType::Date => write!(f, "DATE"), 126 | DataType::Time => write!(f, "TIME"), 127 | DataType::Timestamp => write!(f, "TIMESTAMP"), 128 | DataType::Interval => write!(f, "INTERVAL"), 129 | DataType::Regclass => write!(f, "REGCLASS"), 130 | DataType::Text => write!(f, "TEXT"), 131 | DataType::String => write!(f, "STRING"), 132 | DataType::Bytea => write!(f, "BYTEA"), 133 | DataType::Array(ty) => write!(f, "{}[]", ty), 134 | DataType::Custom(ty) => write!(f, "{}", ty), 135 | } 136 | } 137 | } 138 | impl DataType { 139 | pub fn custom(custom: &str) -> Self { 140 | Self::Custom(ObjectName(vec![super::Ident::new(custom)])) 141 | } 142 | } 143 | 144 | fn format_type_with_optional_length( 145 | f: &mut fmt::Formatter, 146 | sql_type: &'static str, 147 | len: &Option, 148 | ) -> fmt::Result { 149 | write!(f, "{}", sql_type)?; 150 | if let Some(len) = len { 151 | write!(f, "({})", len)?; 152 | } 153 | Ok(()) 154 | } 155 | -------------------------------------------------------------------------------- /models-parser/src/ast/expression/display.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | 3 | impl fmt::Display for Expr { 4 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 5 | match self { 6 | Expr::Identifier(s) => write!(f, "{}", s), 7 | Expr::MapAccess(x) => write!(f, "{}", x), 8 | Expr::Wildcard => f.write_str("*"), 9 | Expr::QualifiedWildcard(q) => write!(f, "{}.*", display_separated(q, ".")), 10 | Expr::CompoundIdentifier(s) => write!(f, "{}", display_separated(s, ".")), 11 | Expr::IsNull(ast) => write!(f, "{} IS NULL", ast), 12 | Expr::IsNotNull(ast) => write!(f, "{} IS NOT NULL", ast), 13 | Expr::InList(x) => write!(f, "{}", x), 14 | Expr::InSubquery(x) => write!(f, "{}", x), 15 | 
Expr::Between(x) => write!(f, "{}", x), 16 | Expr::BinaryOp(x) => write!(f, "{}", x), 17 | Expr::UnaryOp(x) => write!(f, "{}", x), 18 | Expr::Cast(x) => write!(f, "{}", x), 19 | Expr::TryCast(x) => write!(f, "{}", x), 20 | Expr::Extract(x) => write!(f, "{}", x), 21 | Expr::Collate(x) => write!(f, "{}", x), 22 | Expr::Nested(ast) => write!(f, "({})", ast), 23 | Expr::Value(v) => write!(f, "{}", v), 24 | Expr::TypedString(x) => write!(f, "{}", x), 25 | Expr::Function(fun) => write!(f, "{}", fun), 26 | Expr::Case(x) => write!(f, "{}", x), 27 | Expr::Exists(s) => write!(f, "EXISTS ({})", s), 28 | Expr::Subquery(s) => write!(f, "({})", s), 29 | Expr::ListAgg(listagg) => write!(f, "{}", listagg), 30 | Expr::Substring(x) => write!(f, "{}", x), 31 | Expr::Trim(x) => write!(f, "{}", x), 32 | } 33 | } 34 | } 35 | 36 | impl fmt::Display for Trim { 37 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 38 | write!(f, "TRIM(")?; 39 | if let Some((ref ident, ref trim_char)) = self.trim_where { 40 | write!(f, "{} {} FROM {}", ident, trim_char, self.expr)?; 41 | } else { 42 | write!(f, "{}", self.expr)?; 43 | } 44 | write!(f, ")") 45 | } 46 | } 47 | 48 | impl fmt::Display for Substring { 49 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 50 | write!(f, "SUBSTRING({}", self.expr)?; 51 | if let Some(ref from_part) = self.substring_from { 52 | write!(f, " FROM {}", from_part)?; 53 | } 54 | if let Some(ref from_part) = self.substring_for { 55 | write!(f, " FOR {}", from_part)?; 56 | } 57 | 58 | write!(f, ")") 59 | } 60 | } 61 | 62 | impl fmt::Display for Case { 63 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 64 | write!(f, "CASE")?; 65 | if let Some(ref operand) = self.operand { 66 | write!(f, " {}", operand)?; 67 | } 68 | for (c, r) in self.conditions.iter().zip(&self.results) { 69 | write!(f, " WHEN {} THEN {}", c, r)?; 70 | } 71 | 72 | if let Some(ref else_result) = self.else_result { 73 | write!(f, " ELSE {}", else_result)?; 74 | } 75 | 
write!(f, " END") 76 | } 77 | } 78 | 79 | impl fmt::Display for TypedString { 80 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 81 | write!(f, "{}", self.data_type)?; 82 | write!(f, " '{}'", &value::escape_single_quote_string(&self.value)) 83 | } 84 | } 85 | 86 | impl fmt::Display for Collate { 87 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 88 | write!(f, "{} COLLATE {}", self.expr, self.collation) 89 | } 90 | } 91 | impl fmt::Display for UnaryOp { 92 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 93 | if self.op == UnaryOperator::PGPostfixFactorial { 94 | write!(f, "{}{}", self.expr, self.op) 95 | } else { 96 | write!(f, "{} {}", self.op, self.expr) 97 | } 98 | } 99 | } 100 | impl fmt::Display for InSubquery { 101 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 102 | write!( 103 | f, 104 | "{} {}IN ({})", 105 | self.expr, 106 | if self.negated { "NOT " } else { "" }, 107 | self.subquery 108 | ) 109 | } 110 | } 111 | impl fmt::Display for Between { 112 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 113 | write!( 114 | f, 115 | "{} {}BETWEEN {} AND {}", 116 | self.expr, 117 | if self.negated { "NOT " } else { "" }, 118 | self.low, 119 | self.high 120 | ) 121 | } 122 | } 123 | impl fmt::Display for InList { 124 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 125 | write!( 126 | f, 127 | "{} {}IN ({})", 128 | self.expr, 129 | if self.negated { "NOT " } else { "" }, 130 | display_comma_separated(&self.list) 131 | ) 132 | } 133 | } 134 | impl fmt::Display for BinaryOp { 135 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 136 | write!(f, "{} {} {}", self.left, self.op, self.right) 137 | } 138 | } 139 | impl fmt::Display for MapAccess { 140 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 141 | write!(f, "{}[\"{}\"]", self.column, self.key) 142 | } 143 | } 144 | 145 | impl fmt::Display for Cast { 146 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 147 
| write!(f, "CAST({} AS {})", self.expr, self.data_type) 148 | } 149 | } 150 | 151 | impl fmt::Display for TryCast { 152 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 153 | write!(f, "TRY_CAST({} AS {})", self.expr, self.data_type) 154 | } 155 | } 156 | 157 | impl fmt::Display for Extract { 158 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 159 | write!(f, "EXTRACT({} FROM {})", self.field, self.expr) 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /models-parser/src/ast/operator.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | use core::fmt; 14 | 15 | #[cfg(feature = "serde")] 16 | use serde::{Deserialize, Serialize}; 17 | 18 | /// Unary operators 19 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 20 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 21 | pub enum UnaryOperator { 22 | Plus, 23 | Minus, 24 | Not, 25 | /// Bitwise Not, e.g. `~9` (PostgreSQL-specific) 26 | PGBitwiseNot, 27 | /// Square root, e.g. `|/9` (PostgreSQL-specific) 28 | PGSquareRoot, 29 | /// Cube root, e.g. `||/27` (PostgreSQL-specific) 30 | PGCubeRoot, 31 | /// Factorial, e.g. `9!` (PostgreSQL-specific) 32 | PGPostfixFactorial, 33 | /// Factorial, e.g. `!!9` (PostgreSQL-specific) 34 | PGPrefixFactorial, 35 | /// Absolute value, e.g. 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::fmt;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Unary operators
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnaryOperator {
    Plus,
    Minus,
    Not,
    /// Bitwise Not, e.g. `~9` (PostgreSQL-specific)
    PGBitwiseNot,
    /// Square root, e.g. `|/9` (PostgreSQL-specific)
    PGSquareRoot,
    /// Cube root, e.g. `||/27` (PostgreSQL-specific)
    PGCubeRoot,
    /// Factorial, e.g. `9!` (PostgreSQL-specific)
    PGPostfixFactorial,
    /// Factorial, e.g. `!!9` (PostgreSQL-specific)
    PGPrefixFactorial,
    /// Absolute value, e.g. `@ -9` (PostgreSQL-specific)
    PGAbs,
}

impl fmt::Display for UnaryOperator {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Each operator maps to a fixed token; no per-operand state.
        let token = match self {
            UnaryOperator::Plus => "+",
            UnaryOperator::Minus => "-",
            UnaryOperator::Not => "NOT",
            UnaryOperator::PGBitwiseNot => "~",
            UnaryOperator::PGSquareRoot => "|/",
            UnaryOperator::PGCubeRoot => "||/",
            UnaryOperator::PGPostfixFactorial => "!",
            UnaryOperator::PGPrefixFactorial => "!!",
            UnaryOperator::PGAbs => "@",
        };
        f.write_str(token)
    }
}

/// Binary operators
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BinaryOperator {
    Plus,
    Minus,
    Multiply,
    Divide,
    Modulo,
    StringConcat,
    Gt,
    Lt,
    GtEq,
    LtEq,
    Spaceship,
    Eq,
    NotEq,
    And,
    Or,
    Like,
    NotLike,
    ILike,
    NotILike,
    BitwiseOr,
    BitwiseAnd,
    BitwiseXor,
    PGBitwiseXor,
    PGBitwiseShiftLeft,
    PGBitwiseShiftRight,
    PGRegexMatch,
    PGRegexIMatch,
    PGRegexNotMatch,
    PGRegexNotIMatch,
}

impl fmt::Display for BinaryOperator {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let token = match self {
            BinaryOperator::Plus => "+",
            BinaryOperator::Minus => "-",
            BinaryOperator::Multiply => "*",
            BinaryOperator::Divide => "/",
            BinaryOperator::Modulo => "%",
            BinaryOperator::StringConcat => "||",
            BinaryOperator::Gt => ">",
            BinaryOperator::Lt => "<",
            BinaryOperator::GtEq => ">=",
            BinaryOperator::LtEq => "<=",
            BinaryOperator::Spaceship => "<=>",
            BinaryOperator::Eq => "=",
            BinaryOperator::NotEq => "<>",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Like => "LIKE",
            BinaryOperator::NotLike => "NOT LIKE",
            BinaryOperator::ILike => "ILIKE",
            BinaryOperator::NotILike => "NOT ILIKE",
            BinaryOperator::BitwiseOr => "|",
            BinaryOperator::BitwiseAnd => "&",
            BinaryOperator::BitwiseXor => "^",
            BinaryOperator::PGBitwiseXor => "#",
            BinaryOperator::PGBitwiseShiftLeft => "<<",
            BinaryOperator::PGBitwiseShiftRight => ">>",
            BinaryOperator::PGRegexMatch => "~",
            BinaryOperator::PGRegexIMatch => "~*",
            BinaryOperator::PGRegexNotMatch => "!~",
            BinaryOperator::PGRegexNotIMatch => "!~*",
        };
        f.write_str(token)
    }
}
12 | 13 | #[cfg(not(feature = "std"))] 14 | use alloc::string::String; 15 | use core::fmt; 16 | 17 | #[cfg(feature = "bigdecimal")] 18 | use bigdecimal::BigDecimal; 19 | #[cfg(feature = "serde")] 20 | use serde::{Deserialize, Serialize}; 21 | 22 | /// Primitive SQL values such as number and string 23 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 24 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 25 | pub enum Value { 26 | /// Numeric literal 27 | #[cfg(not(feature = "bigdecimal"))] 28 | Number(String, bool), 29 | #[cfg(feature = "bigdecimal")] 30 | Number(BigDecimal, bool), 31 | /// 'string value' 32 | SingleQuotedString(String), 33 | /// N'string value' 34 | NationalStringLiteral(String), 35 | /// X'hex value' 36 | HexStringLiteral(String), 37 | 38 | DoubleQuotedString(String), 39 | /// Boolean value true or false 40 | Boolean(bool), 41 | /// INTERVAL literals, roughly in the following format: 42 | /// `INTERVAL '' [ [ () ] ] 43 | /// [ TO [ () ] ]`, 44 | /// e.g. `INTERVAL '123:45.67' MINUTE(3) TO SECOND(2)`. 45 | /// 46 | /// The parser does not validate the ``, nor does it ensure 47 | /// that the `` units >= the units in ``, 48 | /// so the user will have to reject intervals like `HOUR TO YEAR`. 49 | Interval { 50 | value: String, 51 | leading_field: Option, 52 | leading_precision: Option, 53 | last_field: Option, 54 | /// The seconds precision can be specified in SQL source as 55 | /// `INTERVAL '__' SECOND(_, x)` (in which case the `leading_field` 56 | /// will be `Second` and the `last_field` will be `None`), 57 | /// or as `__ TO SECOND(x)`. 
58 | fractional_seconds_precision: Option, 59 | }, 60 | /// `NULL` value 61 | Null, 62 | } 63 | 64 | impl fmt::Display for Value { 65 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 66 | match self { 67 | Value::Number(v, l) => write!(f, "{}{long}", v, long = if *l { "L" } else { "" }), 68 | Value::DoubleQuotedString(v) => write!(f, "\"{}\"", v), 69 | Value::SingleQuotedString(v) => write!(f, "'{}'", escape_single_quote_string(v)), 70 | Value::NationalStringLiteral(v) => write!(f, "N'{}'", v), 71 | Value::HexStringLiteral(v) => write!(f, "X'{}'", v), 72 | Value::Boolean(v) => write!(f, "{}", v), 73 | Value::Interval { 74 | value, 75 | leading_field: Some(DateTimeField::Second), 76 | leading_precision: Some(leading_precision), 77 | last_field, 78 | fractional_seconds_precision: Some(fractional_seconds_precision), 79 | } => { 80 | // When the leading field is SECOND, the parser guarantees that 81 | // the last field is None. 82 | assert!(last_field.is_none()); 83 | write!( 84 | f, 85 | "INTERVAL '{}' SECOND ({}, {})", 86 | escape_single_quote_string(value), 87 | leading_precision, 88 | fractional_seconds_precision 89 | ) 90 | } 91 | Value::Interval { 92 | value, 93 | leading_field, 94 | leading_precision, 95 | last_field, 96 | fractional_seconds_precision, 97 | } => { 98 | write!(f, "INTERVAL '{}'", escape_single_quote_string(value))?; 99 | if let Some(leading_field) = leading_field { 100 | write!(f, " {}", leading_field)?; 101 | } 102 | if let Some(leading_precision) = leading_precision { 103 | write!(f, " ({})", leading_precision)?; 104 | } 105 | if let Some(last_field) = last_field { 106 | write!(f, " TO {}", last_field)?; 107 | } 108 | if let Some(fractional_seconds_precision) = fractional_seconds_precision { 109 | write!(f, " ({})", fractional_seconds_precision)?; 110 | } 111 | Ok(()) 112 | } 113 | Value::Null => write!(f, "NULL"), 114 | } 115 | } 116 | } 117 | 118 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 119 | #[cfg_attr(feature = "serde", 
derive(Serialize, Deserialize))] 120 | pub enum DateTimeField { 121 | Year, 122 | Month, 123 | Day, 124 | Hour, 125 | Minute, 126 | Second, 127 | } 128 | 129 | impl fmt::Display for DateTimeField { 130 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 131 | f.write_str(match self { 132 | DateTimeField::Year => "YEAR", 133 | DateTimeField::Month => "MONTH", 134 | DateTimeField::Day => "DAY", 135 | DateTimeField::Hour => "HOUR", 136 | DateTimeField::Minute => "MINUTE", 137 | DateTimeField::Second => "SECOND", 138 | }) 139 | } 140 | } 141 | 142 | pub struct EscapeSingleQuoteString<'a>(&'a str); 143 | 144 | impl<'a> fmt::Display for EscapeSingleQuoteString<'a> { 145 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 146 | for c in self.0.chars() { 147 | if c == '\'' { 148 | write!(f, "\'\'")?; 149 | } else { 150 | write!(f, "{}", c)?; 151 | } 152 | } 153 | Ok(()) 154 | } 155 | } 156 | 157 | pub fn escape_single_quote_string(s: &str) -> EscapeSingleQuoteString<'_> { 158 | EscapeSingleQuoteString(s) 159 | } 160 | 161 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 162 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 163 | pub enum TrimWhereField { 164 | Both, 165 | Leading, 166 | Trailing, 167 | } 168 | 169 | impl fmt::Display for TrimWhereField { 170 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 171 | use TrimWhereField::*; 172 | f.write_str(match self { 173 | Both => "BOTH", 174 | Leading => "LEADING", 175 | Trailing => "TRAILING", 176 | }) 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /models-parser/src/dialect/ansi.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug)] 16 | pub struct AnsiDialect {} 17 | 18 | impl Dialect for AnsiDialect { 19 | fn is_identifier_start(&self, ch: char) -> bool { 20 | ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) 21 | } 22 | 23 | fn is_identifier_part(&self, ch: char) -> bool { 24 | ('a'..='z').contains(&ch) 25 | || ('A'..='Z').contains(&ch) 26 | || ('0'..='9').contains(&ch) 27 | || ch == '_' 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /models-parser/src/dialect/generic.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug, Default)] 16 | pub struct GenericDialect; 17 | 18 | impl Dialect for GenericDialect { 19 | fn is_identifier_start(&self, ch: char) -> bool { 20 | ('a'..='z').contains(&ch) 21 | || ('A'..='Z').contains(&ch) 22 | || ch == '_' 23 | || ch == '#' 24 | || ch == '@' 25 | } 26 | 27 | fn is_identifier_part(&self, ch: char) -> bool { 28 | ('a'..='z').contains(&ch) 29 | || ('A'..='Z').contains(&ch) 30 | || ('0'..='9').contains(&ch) 31 | || ch == '@' 32 | || ch == '$' 33 | || ch == '#' 34 | || ch == '_' 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /models-parser/src/dialect/hive.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug)] 16 | pub struct HiveDialect {} 17 | 18 | impl Dialect for HiveDialect { 19 | fn is_delimited_identifier_start(&self, ch: char) -> bool { 20 | (ch == '"') || (ch == '`') 21 | } 22 | 23 | fn is_identifier_start(&self, ch: char) -> bool { 24 | ('a'..='z').contains(&ch) 25 | || ('A'..='Z').contains(&ch) 26 | || ('0'..='9').contains(&ch) 27 | || ch == '$' 28 | } 29 | 30 | fn is_identifier_part(&self, ch: char) -> bool { 31 | ('a'..='z').contains(&ch) 32 | || ('A'..='Z').contains(&ch) 33 | || ('0'..='9').contains(&ch) 34 | || ch == '_' 35 | || ch == '$' 36 | || ch == '{' 37 | || ch == '}' 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /models-parser/src/dialect/mod.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | 13 | mod ansi; 14 | mod generic; 15 | mod hive; 16 | pub mod keywords; 17 | mod mssql; 18 | mod mysql; 19 | mod postgresql; 20 | mod snowflake; 21 | mod sqlite; 22 | 23 | use core::any::{Any, TypeId}; 24 | use core::fmt::Debug; 25 | 26 | pub use self::ansi::AnsiDialect; 27 | pub use self::generic::GenericDialect; 28 | pub use self::hive::HiveDialect; 29 | pub use self::mssql::MsSqlDialect; 30 | pub use self::mysql::MySqlDialect; 31 | pub use self::postgresql::PostgreSqlDialect; 32 | pub use self::snowflake::SnowflakeDialect; 33 | pub use self::sqlite::SQLiteDialect; 34 | 35 | /// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates 36 | /// to `true` iff `parser.dialect` is one of the `Dialect`s specified. 37 | macro_rules! dialect_of { 38 | ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => { 39 | ($($parsed_dialect.dialect.is::<$dialect_type>())||+) 40 | }; 41 | } 42 | 43 | pub trait Dialect: Debug + Any { 44 | /// Determine if a character starts a quoted identifier. The default 45 | /// implementation, accepting "double quoted" ids is both ANSI-compliant 46 | /// and appropriate for most dialects (with the notable exception of 47 | /// MySQL, MS SQL, and sqlite). 
You can accept one of characters listed 48 | /// in `Word::matching_end_quote` here 49 | fn is_delimited_identifier_start(&self, ch: char) -> bool { 50 | ch == '"' 51 | } 52 | /// Determine if a character is a valid start character for an unquoted identifier 53 | fn is_identifier_start(&self, ch: char) -> bool; 54 | /// Determine if a character is a valid unquoted identifier character 55 | fn is_identifier_part(&self, ch: char) -> bool; 56 | } 57 | 58 | impl dyn Dialect { 59 | #[inline] 60 | pub fn is(&self) -> bool { 61 | // borrowed from `Any` implementation 62 | TypeId::of::() == self.type_id() 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod tests { 68 | use super::ansi::AnsiDialect; 69 | use super::generic::GenericDialect; 70 | use super::*; 71 | 72 | struct DialectHolder<'a> { 73 | dialect: &'a dyn Dialect, 74 | } 75 | 76 | #[test] 77 | fn test_is_dialect() { 78 | let generic_dialect: &dyn Dialect = &GenericDialect {}; 79 | let ansi_dialect: &dyn Dialect = &AnsiDialect {}; 80 | 81 | let generic_holder = DialectHolder { 82 | dialect: generic_dialect, 83 | }; 84 | let ansi_holder = DialectHolder { 85 | dialect: ansi_dialect, 86 | }; 87 | 88 | assert!(dialect_of!(generic_holder is GenericDialect | AnsiDialect),); 89 | assert!(!dialect_of!(generic_holder is AnsiDialect)); 90 | 91 | assert!(dialect_of!(ansi_holder is AnsiDialect)); 92 | assert!(dialect_of!(ansi_holder is GenericDialect | AnsiDialect),); 93 | assert!(!dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /models-parser/src/dialect/mssql.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug)] 16 | pub struct MsSqlDialect {} 17 | 18 | impl Dialect for MsSqlDialect { 19 | fn is_delimited_identifier_start(&self, ch: char) -> bool { 20 | ch == '"' || ch == '[' 21 | } 22 | 23 | fn is_identifier_start(&self, ch: char) -> bool { 24 | // See https://docs.microsoft.com/en-us/sql/relational-databases/databases/database-identifiers?view=sql-server-2017#rules-for-regular-identifiers 25 | // We don't support non-latin "letters" currently. 26 | ('a'..='z').contains(&ch) 27 | || ('A'..='Z').contains(&ch) 28 | || ch == '_' 29 | || ch == '#' 30 | || ch == '@' 31 | } 32 | 33 | fn is_identifier_part(&self, ch: char) -> bool { 34 | ('a'..='z').contains(&ch) 35 | || ('A'..='Z').contains(&ch) 36 | || ('0'..='9').contains(&ch) 37 | || ch == '@' 38 | || ch == '$' 39 | || ch == '#' 40 | || ch == '_' 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /models-parser/src/dialect/mysql.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug)] 16 | pub struct MySqlDialect {} 17 | 18 | impl Dialect for MySqlDialect { 19 | fn is_identifier_start(&self, ch: char) -> bool { 20 | // See https://dev.mysql.com/doc/refman/8.0/en/identifiers.html. 21 | // We don't yet support identifiers beginning with numbers, as that 22 | // makes it hard to distinguish numeric literals. 23 | ('a'..='z').contains(&ch) 24 | || ('A'..='Z').contains(&ch) 25 | || ch == '_' 26 | || ch == '$' 27 | || ('\u{0080}'..='\u{ffff}').contains(&ch) 28 | } 29 | 30 | fn is_identifier_part(&self, ch: char) -> bool { 31 | self.is_identifier_start(ch) || ('0'..='9').contains(&ch) 32 | } 33 | 34 | fn is_delimited_identifier_start(&self, ch: char) -> bool { 35 | ch == '`' 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /models-parser/src/dialect/postgresql.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug)] 16 | pub struct PostgreSqlDialect {} 17 | 18 | impl Dialect for PostgreSqlDialect { 19 | fn is_identifier_start(&self, ch: char) -> bool { 20 | // See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS 21 | // We don't yet support identifiers beginning with "letters with 22 | // diacritical marks and non-Latin letters" 23 | ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_' 24 | } 25 | 26 | fn is_identifier_part(&self, ch: char) -> bool { 27 | ('a'..='z').contains(&ch) 28 | || ('A'..='Z').contains(&ch) 29 | || ('0'..='9').contains(&ch) 30 | || ch == '$' 31 | || ch == '_' 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /models-parser/src/dialect/snowflake.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug, Default)] 16 | pub struct SnowflakeDialect; 17 | 18 | impl Dialect for SnowflakeDialect { 19 | // see https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html 20 | fn is_identifier_start(&self, ch: char) -> bool { 21 | ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_' 22 | } 23 | 24 | fn is_identifier_part(&self, ch: char) -> bool { 25 | ('a'..='z').contains(&ch) 26 | || ('A'..='Z').contains(&ch) 27 | || ('0'..='9').contains(&ch) 28 | || ch == '$' 29 | || ch == '_' 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /models-parser/src/dialect/sqlite.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | use crate::dialect::Dialect; 14 | 15 | #[derive(Debug)] 16 | pub struct SQLiteDialect {} 17 | 18 | impl Dialect for SQLiteDialect { 19 | // see https://www.sqlite.org/lang_keywords.html 20 | // parse `...`, [...] and "..." as identifier 21 | // TODO: support depending on the context tread '...' as identifier too. 
22 | fn is_delimited_identifier_start(&self, ch: char) -> bool { 23 | ch == '`' || ch == '"' || ch == '[' 24 | } 25 | 26 | fn is_identifier_start(&self, ch: char) -> bool { 27 | // See https://www.sqlite.org/draft/tokenreq.html 28 | ('a'..='z').contains(&ch) 29 | || ('A'..='Z').contains(&ch) 30 | || ch == '_' 31 | || ch == '$' 32 | || ('\u{007f}'..='\u{ffff}').contains(&ch) 33 | } 34 | 35 | fn is_identifier_part(&self, ch: char) -> bool { 36 | self.is_identifier_start(ch) || ('0'..='9').contains(&ch) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /models-parser/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(feature = "std"), no_std)] 2 | #![allow(clippy::upper_case_acronyms)] 3 | 4 | #[cfg(not(feature = "std"))] 5 | extern crate alloc; 6 | 7 | pub mod ast; 8 | #[macro_use] 9 | pub mod dialect; 10 | pub mod parser; 11 | pub mod tokenizer; 12 | 13 | #[doc(hidden)] 14 | // This is required to make utilities accessible by both the crate-internal 15 | // unit-tests and by the integration tests 16 | // External users are not supposed to rely on this module. 17 | pub mod test_utils; 18 | -------------------------------------------------------------------------------- /models-parser/src/test_utils.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | /// This module contains internal utilities used for testing the library. 14 | /// While technically public, the library's users are not supposed to rely 15 | /// on this module, as it will change without notice. 16 | // 17 | // Integration tests (i.e. everything under `tests/`) import this 18 | // via `tests/test_utils/mod.rs`. 19 | 20 | #[cfg(not(feature = "std"))] 21 | use alloc::{ 22 | boxed::Box, 23 | string::{String, ToString}, 24 | vec, 25 | vec::Vec, 26 | }; 27 | use core::fmt::Debug; 28 | 29 | use crate::ast::*; 30 | use crate::dialect::*; 31 | use crate::parser::{Parser, ParserError}; 32 | use crate::tokenizer::Tokenizer; 33 | 34 | /// Tests use the methods on this struct to invoke the parser on one or 35 | /// multiple dialects. 36 | pub struct TestedDialects { 37 | pub dialects: Vec>, 38 | } 39 | 40 | impl TestedDialects { 41 | /// Run the given function for all of `self.dialects`, assert that they 42 | /// return the same result, and return that result. 
43 | pub fn one_of_identical_results(&self, f: F) -> T 44 | where 45 | F: Fn(&dyn Dialect) -> T, 46 | { 47 | let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect))); 48 | parse_results 49 | .fold(None, |s, (dialect, parsed)| { 50 | if let Some((prev_dialect, prev_parsed)) = s { 51 | assert_eq!( 52 | prev_parsed, parsed, 53 | "Parse results with {:?} are different from {:?}", 54 | prev_dialect, dialect 55 | ); 56 | } 57 | Some((dialect, parsed)) 58 | }) 59 | .unwrap() 60 | .1 61 | } 62 | 63 | pub fn run_parser_method(&self, sql: &str, f: F) -> T 64 | where 65 | F: Fn(&mut Parser) -> T, 66 | { 67 | self.one_of_identical_results(|dialect| { 68 | let mut tokenizer = Tokenizer::new(dialect, sql); 69 | let tokens = tokenizer.tokenize().unwrap(); 70 | f(&mut Parser::new(tokens, dialect)) 71 | }) 72 | } 73 | 74 | pub fn parse_sql_statements(&self, sql: &str) -> Result, ParserError> { 75 | self.one_of_identical_results(|dialect| Parser::parse_sql(dialect, sql)) 76 | // To fail the `ensure_multiple_dialects_are_tested` test: 77 | // Parser::parse_sql(&**self.dialects.first().unwrap(), sql) 78 | } 79 | 80 | /// Ensures that `sql` parses as a single statement and returns it. 81 | /// If non-empty `canonical` SQL representation is provided, 82 | /// additionally asserts that parsing `sql` results in the same parse 83 | /// tree as parsing `canonical`, and that serializing it back to string 84 | /// results in the `canonical` representation. 
85 | pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement { 86 | let mut statements = self.parse_sql_statements(sql).unwrap(); 87 | assert_eq!(statements.len(), 1); 88 | 89 | if !canonical.is_empty() && sql != canonical { 90 | assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements); 91 | } 92 | 93 | let only_statement = statements.pop().unwrap(); 94 | if !canonical.is_empty() { 95 | assert_eq!(canonical, only_statement.to_string()) 96 | } 97 | only_statement 98 | } 99 | 100 | /// Ensures that `sql` parses as a single [Statement], and is not modified 101 | /// after a serialization round-trip. 102 | pub fn verified_stmt(&self, query: &str) -> Statement { 103 | self.one_statement_parses_to(query, query) 104 | } 105 | 106 | /// Ensures that `sql` parses as a single [Query], and is not modified 107 | /// after a serialization round-trip. 108 | pub fn verified_query(&self, sql: &str) -> Query { 109 | match self.verified_stmt(sql) { 110 | Statement::Query(query) => *query, 111 | _ => panic!("Expected Query"), 112 | } 113 | } 114 | 115 | /// Ensures that `sql` parses as a single [Select], and is not modified 116 | /// after a serialization round-trip. 117 | pub fn verified_only_select(&self, query: &str) -> Select { 118 | match self.verified_query(query).body { 119 | SetExpr::Select(s) => *s, 120 | _ => panic!("Expected SetExpr::Select"), 121 | } 122 | } 123 | 124 | /// Ensures that `sql` parses as an expression, and is not modified 125 | /// after a serialization round-trip. 
126 | pub fn verified_expr(&self, sql: &str) -> Expr { 127 | let ast = self 128 | .run_parser_method(sql, |parser| parser.parse_expr()) 129 | .unwrap(); 130 | assert_eq!(sql, &ast.to_string(), "round-tripping without changes"); 131 | ast 132 | } 133 | } 134 | 135 | pub fn all_dialects() -> TestedDialects { 136 | TestedDialects { 137 | dialects: vec![ 138 | Box::new(GenericDialect {}), 139 | Box::new(PostgreSqlDialect {}), 140 | Box::new(MsSqlDialect {}), 141 | Box::new(AnsiDialect {}), 142 | Box::new(SnowflakeDialect {}), 143 | Box::new(HiveDialect {}), 144 | ], 145 | } 146 | } 147 | 148 | pub fn only(v: impl IntoIterator) -> T { 149 | let mut iter = v.into_iter(); 150 | if let (Some(item), None) = (iter.next(), iter.next()) { 151 | item 152 | } else { 153 | panic!("only called on collection without exactly one item") 154 | } 155 | } 156 | 157 | pub fn expr_from_projection(item: &SelectItem) -> &Expr { 158 | match item { 159 | SelectItem::UnnamedExpr(expr) => expr, 160 | _ => panic!("Expected UnnamedExpr"), 161 | } 162 | } 163 | 164 | pub fn number(n: &'static str) -> Value { 165 | Value::Number(n.parse().unwrap(), false) 166 | } 167 | 168 | pub fn table_alias(name: impl Into) -> Option { 169 | Some(TableAlias { 170 | name: Ident::new(name), 171 | columns: vec![], 172 | }) 173 | } 174 | 175 | pub fn table(name: impl Into) -> TableFactor { 176 | TableFactor::Table { 177 | name: ObjectName(vec![Ident::new(name.into())]), 178 | alias: None, 179 | args: vec![], 180 | with_hints: vec![], 181 | } 182 | } 183 | 184 | pub fn join(relation: TableFactor) -> Join { 185 | Join { 186 | relation, 187 | join_operator: JoinOperator::Inner(JoinConstraint::Natural), 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /models-parser/tests/hive.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file 
#[test]
fn parse_table_create() {
    // Hive CREATE TABLE with PARTITIONED BY / STORED AS / TBLPROPERTIES,
    // in both the ORC shorthand and explicit INPUTFORMAT/OUTPUTFORMAT forms.
    let stored_as_orc = r#"CREATE TABLE IF NOT EXISTS db.table (a BIGINT, b STRING, c TIMESTAMP) PARTITIONED BY (d STRING, e TIMESTAMP) STORED AS ORC LOCATION 's3://...' TBLPROPERTIES ("prop" = "2", "asdf" = '1234', 'asdf' = "1234", "asdf" = 2)"#;
    let input_output_format = r#"CREATE TABLE IF NOT EXISTS db.table (a BIGINT, b STRING, c TIMESTAMP) PARTITIONED BY (d STRING, e TIMESTAMP) STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.orc.OrcInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat' LOCATION 's3://...'"#;

    hive().verified_stmt(stored_as_orc);
    hive().verified_stmt(input_output_format);
}

#[test]
fn parse_insert_overwrite() {
    // INSERT OVERWRITE with a mixed static/dynamic partition spec.
    hive().verified_stmt(
        r#"INSERT OVERWRITE TABLE db.new_table PARTITION (a = '1', b) SELECT a, b, c FROM db.table"#,
    );
}

#[test]
fn test_truncate() {
    hive().verified_stmt(r#"TRUNCATE TABLE db.table"#);
}

#[test]
fn parse_analyze() {
    // ANALYZE ... COMPUTE STATISTICS with NOSCAN and CACHE METADATA options.
    hive().verified_stmt(
        r#"ANALYZE TABLE db.table_name PARTITION (a = '1234', b) COMPUTE STATISTICS NOSCAN CACHE METADATA"#,
    );
}
#[test]
fn parse_analyze_for_columns() {
    hive().verified_stmt(
        r#"ANALYZE TABLE db.table_name PARTITION (a = '1234', b) COMPUTE STATISTICS FOR COLUMNS"#,
    );
}

#[test]
fn parse_msck() {
    // MSCK REPAIR both with and without ADD PARTITIONS.
    let with_add = r#"MSCK REPAIR TABLE db.table_name ADD PARTITIONS"#;
    let without_add = r#"MSCK REPAIR TABLE db.table_name"#;
    hive().verified_stmt(with_add);
    hive().verified_stmt(without_add);
}

#[test]
fn parse_set() {
    hive().verified_stmt("SET HIVEVAR:name = a, b, c_d");
}

#[test]
fn test_spaceship() {
    // Hive's NULL-safe equality operator.
    hive().verified_stmt("SELECT * FROM db.table WHERE a <=> b");
}

#[test]
fn parse_with_cte() {
    hive().verified_stmt(
        "WITH a AS (SELECT * FROM b) INSERT INTO TABLE db.table_table PARTITION (a) SELECT * FROM b",
    );
}

#[test]
fn drop_table_purge() {
    hive().verified_stmt("DROP TABLE db.table_name PURGE");
}

#[test]
fn create_table_like() {
    hive().verified_stmt("CREATE TABLE db.table_name LIKE db.other_table");
}

// Turning off this test until we can parse identifiers starting with numbers :(
#[test]
fn test_identifier() {
    hive().verified_stmt("SELECT a AS 3_barrr_asdf FROM db.table_name");
}

#[test]
fn test_alter_partition() {
    hive().verified_stmt("ALTER TABLE db.table PARTITION (a = 2) RENAME TO PARTITION (a = 1)");
}

#[test]
fn test_add_partition() {
    hive().verified_stmt("ALTER TABLE db.table ADD IF NOT EXISTS PARTITION (a = 'asdf', b = 2)");
}

#[test]
fn test_drop_partition() {
    hive().verified_stmt("ALTER TABLE db.table DROP PARTITION (a = 1)");
}
#[test]
fn test_drop_if_exists() {
    hive().verified_stmt("ALTER TABLE db.table DROP IF EXISTS PARTITION (a = 'b', c = 'd')");
}

#[test]
fn test_cluster_by() {
    hive().verified_stmt("SELECT a FROM db.table CLUSTER BY a, b");
}

#[test]
fn test_distribute_by() {
    hive().verified_stmt("SELECT a FROM db.table DISTRIBUTE BY a, b");
}

#[test]
fn no_join_condition() {
    // Hive allows a JOIN without an ON/USING clause.
    hive().verified_stmt("SELECT a, b FROM db.table_name JOIN a");
}

#[test]
fn columns_after_partition() {
    hive().verified_stmt(
        "INSERT INTO db.table_name PARTITION (a, b) (c, d) SELECT a, b, c, d FROM db.table",
    );
}

#[test]
fn long_numerics() {
    // `1L` is a Hive BIGINT literal suffix.
    hive().verified_stmt(r#"SELECT MIN(MIN(10, 5), 1L) AS a"#);
}

#[test]
fn decimal_precision() {
    // DECIMAL is normalized to NUMERIC on serialization.
    let input = "SELECT CAST(a AS DECIMAL(18,2)) FROM db.table";
    let canonical = "SELECT CAST(a AS NUMERIC(18,2)) FROM db.table";
    hive().one_statement_parses_to(input, canonical);
}

#[test]
fn create_temp_table() {
    // TEMP is normalized to TEMPORARY on serialization.
    let canonical = "CREATE TEMPORARY TABLE db.table (a INTEGER NOT NULL)";
    let shorthand = "CREATE TEMP TABLE db.table (a INTEGER NOT NULL)";

    hive().verified_stmt(canonical);
    hive().one_statement_parses_to(shorthand, canonical);
}

#[test]
fn create_local_directory() {
    hive().verified_stmt(
        "INSERT OVERWRITE LOCAL DIRECTORY '/home/blah' STORED AS TEXTFILE SELECT * FROM db.table",
    );
}

#[test]
fn lateral_view() {
    hive().verified_stmt("SELECT a FROM db.table LATERAL VIEW explode(a) t AS j, P LATERAL VIEW OUTER explode(a) t AS a, b WHERE a = 1");
}
sort_by = "SELECT * FROM db.table SORT BY a"; 186 | hive().verified_stmt(sort_by); 187 | } 188 | 189 | #[test] 190 | fn rename_table() { 191 | let rename = "ALTER TABLE db.table_name RENAME TO db.table_2"; 192 | hive().verified_stmt(rename); 193 | } 194 | 195 | #[test] 196 | fn map_access() { 197 | let rename = "SELECT a.b[\"asdf\"] FROM db.table WHERE a = 2"; 198 | hive().verified_stmt(rename); 199 | } 200 | 201 | #[test] 202 | fn from_cte() { 203 | let rename = 204 | "WITH cte AS (SELECT * FROM a.b) FROM cte INSERT INTO TABLE a.b PARTITION (a) SELECT *"; 205 | println!("{}", hive().verified_stmt(rename)); 206 | } 207 | 208 | fn hive() -> TestedDialects { 209 | TestedDialects { 210 | dialects: vec![Box::new(HiveDialect {})], 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /models-parser/tests/mssql.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | #![warn(clippy::all)] 14 | //! Test SQL syntax specific to Microsoft's T-SQL. The parser based on the 15 | //! generic dialect is also tested (on the inputs it can handle). 
16 | 17 | #[macro_use] 18 | mod test_utils; 19 | use test_utils::*; 20 | 21 | use models_parser::ast::*; 22 | use models_parser::dialect::{GenericDialect, MsSqlDialect}; 23 | 24 | #[test] 25 | fn parse_mssql_identifiers() { 26 | let sql = "SELECT @@version, _foo$123 FROM ##temp"; 27 | let select = ms_and_generic().verified_only_select(sql); 28 | assert_eq!( 29 | &Expr::Identifier(Ident::new("@@version")), 30 | expr_from_projection(&select.projection[0]), 31 | ); 32 | assert_eq!( 33 | &Expr::Identifier(Ident::new("_foo$123")), 34 | expr_from_projection(&select.projection[1]), 35 | ); 36 | assert_eq!(2, select.projection.len()); 37 | match &only(&select.from).relation { 38 | TableFactor::Table { name, .. } => { 39 | assert_eq!("##temp".to_string(), name.to_string()); 40 | } 41 | _ => unreachable!(), 42 | }; 43 | } 44 | 45 | #[test] 46 | fn parse_mssql_single_quoted_aliases() { 47 | let _ = ms_and_generic().one_statement_parses_to("SELECT foo 'alias'", "SELECT foo AS 'alias'"); 48 | } 49 | 50 | #[test] 51 | fn parse_mssql_delimited_identifiers() { 52 | let _ = ms().one_statement_parses_to( 53 | "SELECT [a.b!] [FROM] FROM foo [WHERE]", 54 | "SELECT [a.b!] 
#[test]
fn parse_mssql_delimited_identifiers() {
    // Bracket-delimited identifiers; bare aliases gain an explicit AS.
    let _ = ms().one_statement_parses_to(
        "SELECT [a.b!] [FROM] FROM foo [WHERE]",
        "SELECT [a.b!] AS [FROM] FROM foo AS [WHERE]",
    );
}

#[test]
fn parse_mssql_apply_join() {
    // CROSS APPLY / OUTER APPLY against both table functions and derived tables.
    let _ = ms_and_generic().verified_only_select(
        "SELECT * FROM sys.dm_exec_query_stats AS deqs \
         CROSS APPLY sys.dm_exec_query_plan(deqs.plan_handle)",
    );
    let _ = ms_and_generic().verified_only_select(
        "SELECT * FROM sys.dm_exec_query_stats AS deqs \
         OUTER APPLY sys.dm_exec_query_plan(deqs.plan_handle)",
    );
    let _ = ms_and_generic().verified_only_select(
        "SELECT * FROM foo \
         OUTER APPLY (SELECT foo.x + 1) AS bar",
    );
}

#[test]
fn parse_mssql_top_paren() {
    let select = ms_and_generic().verified_only_select("SELECT TOP (5) * FROM foo");
    let top = select.top.unwrap();
    assert_eq!(top.quantity, Some(Expr::Value(number("5"))));
    assert!(!top.percent);
}

#[test]
fn parse_mssql_top_percent() {
    let select = ms_and_generic().verified_only_select("SELECT TOP (5) PERCENT * FROM foo");
    let top = select.top.unwrap();
    assert_eq!(top.quantity, Some(Expr::Value(number("5"))));
    assert!(top.percent);
}

#[test]
fn parse_mssql_top_with_ties() {
    let select = ms_and_generic().verified_only_select("SELECT TOP (5) WITH TIES * FROM foo");
    let top = select.top.unwrap();
    assert_eq!(top.quantity, Some(Expr::Value(number("5"))));
    assert!(top.with_ties);
}

#[test]
fn parse_mssql_top_percent_with_ties() {
    let select =
        ms_and_generic().verified_only_select("SELECT TOP (10) PERCENT WITH TIES * FROM foo");
    let top = select.top.unwrap();
    assert_eq!(top.quantity, Some(Expr::Value(number("10"))));
    assert!(top.percent);
}
baz FROM foo"); 114 | } 115 | 116 | #[test] 117 | fn parse_mssql_bin_literal() { 118 | let _ = ms_and_generic().one_statement_parses_to("SELECT 0xdeadBEEF", "SELECT X'deadBEEF'"); 119 | } 120 | 121 | fn ms() -> TestedDialects { 122 | TestedDialects { 123 | dialects: vec![Box::new(MsSqlDialect {})], 124 | } 125 | } 126 | fn ms_and_generic() -> TestedDialects { 127 | TestedDialects { 128 | dialects: vec![Box::new(MsSqlDialect {}), Box::new(GenericDialect {})], 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/1.sql: -------------------------------------------------------------------------------- 1 | select 2 | l_returnflag, 3 | l_linestatus, 4 | sum(l_quantity) as sum_qty, 5 | sum(l_extendedprice) as sum_base_price, 6 | sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, 7 | sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, 8 | avg(l_quantity) as avg_qty, 9 | avg(l_extendedprice) as avg_price, 10 | avg(l_discount) as avg_disc, 11 | count(*) as count_order 12 | from 13 | lineitem 14 | where 15 | l_shipdate <= date '1998-12-01' - interval '90' day (3) 16 | group by 17 | l_returnflag, 18 | l_linestatus 19 | order by 20 | l_returnflag, 21 | l_linestatus; 22 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/10.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | c_custkey, 6 | c_name, 7 | sum(l_extendedprice * (1 - l_discount)) as revenue, 8 | c_acctbal, 9 | n_name, 10 | c_address, 11 | c_phone, 12 | c_comment 13 | from 14 | customer, 15 | orders, 16 | lineitem, 17 | nation 18 | where 19 | c_custkey = o_custkey 20 | and l_orderkey = o_orderkey 21 | and o_orderdate >= date '1993-10-01' 22 | and o_orderdate < date '1993-10-01' + interval '3' month 23 | and l_returnflag = 'R' 24 | and c_nationkey = 
n_nationkey 25 | group by 26 | c_custkey, 27 | c_name, 28 | c_acctbal, 29 | c_phone, 30 | n_name, 31 | c_address, 32 | c_comment 33 | order by 34 | revenue desc; 35 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/11.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | ps_partkey, 6 | sum(ps_supplycost * ps_availqty) as value 7 | from 8 | partsupp, 9 | supplier, 10 | nation 11 | where 12 | ps_suppkey = s_suppkey 13 | and s_nationkey = n_nationkey 14 | and n_name = 'GERMANY' 15 | group by 16 | ps_partkey having 17 | sum(ps_supplycost * ps_availqty) > ( 18 | select 19 | sum(ps_supplycost * ps_availqty) * 0.0001000000 20 | from 21 | partsupp, 22 | supplier, 23 | nation 24 | where 25 | ps_suppkey = s_suppkey 26 | and s_nationkey = n_nationkey 27 | and n_name = 'GERMANY' 28 | ) 29 | order by 30 | value desc; 31 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/12.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | l_shipmode, 6 | sum(case 7 | when o_orderpriority = '1-URGENT' 8 | or o_orderpriority = '2-HIGH' 9 | then 1 10 | else 0 11 | end) as high_line_count, 12 | sum(case 13 | when o_orderpriority <> '1-URGENT' 14 | and o_orderpriority <> '2-HIGH' 15 | then 1 16 | else 0 17 | end) as low_line_count 18 | from 19 | orders, 20 | lineitem 21 | where 22 | o_orderkey = l_orderkey 23 | and l_shipmode in ('MAIL', 'SHIP') 24 | and l_commitdate < l_receiptdate 25 | and l_shipdate < l_commitdate 26 | and l_receiptdate >= date '1994-01-01' 27 | and l_receiptdate < date '1994-01-01' + interval '1' year 28 | group by 29 | l_shipmode 30 | order by 31 | l_shipmode; 32 | -------------------------------------------------------------------------------- 
/models-parser/tests/queries/tpch/13.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | c_count, 6 | count(*) as custdist 7 | from 8 | ( 9 | select 10 | c_custkey, 11 | count(o_orderkey) 12 | from 13 | customer left outer join orders on 14 | c_custkey = o_custkey 15 | and o_comment not like '%special%requests%' 16 | group by 17 | c_custkey 18 | ) as c_orders (c_custkey, c_count) 19 | group by 20 | c_count 21 | order by 22 | custdist desc, 23 | c_count desc; 24 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/14.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | 100.00 * sum(case 6 | when p_type like 'PROMO%' 7 | then l_extendedprice * (1 - l_discount) 8 | else 0 9 | end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue 10 | from 11 | lineitem, 12 | part 13 | where 14 | l_partkey = p_partkey 15 | and l_shipdate >= date '1995-09-01' 16 | and l_shipdate < date '1995-09-01' + interval '1' month; 17 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/15.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | create view revenue0 (supplier_no, total_revenue) as 4 | select 5 | l_suppkey, 6 | sum(l_extendedprice * (1 - l_discount)) 7 | from 8 | lineitem 9 | where 10 | l_shipdate >= date '1996-01-01' 11 | and l_shipdate < date '1996-01-01' + interval '3' month 12 | group by 13 | l_suppkey; 14 | 15 | 16 | select 17 | s_suppkey, 18 | s_name, 19 | s_address, 20 | s_phone, 21 | total_revenue 22 | from 23 | supplier, 24 | revenue0 25 | where 26 | s_suppkey = supplier_no 27 | and total_revenue = ( 28 | select 29 | max(total_revenue) 30 | from 31 | revenue0 32 | ) 33 | order by 34 | 
s_suppkey; 35 | 36 | drop view revenue0; 37 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/16.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | p_brand, 6 | p_type, 7 | p_size, 8 | count(distinct ps_suppkey) as supplier_cnt 9 | from 10 | partsupp, 11 | part 12 | where 13 | p_partkey = ps_partkey 14 | and p_brand <> 'Brand#45' 15 | and p_type not like 'MEDIUM POLISHED%' 16 | and p_size in (49, 14, 23, 45, 19, 3, 36, 9) 17 | and ps_suppkey not in ( 18 | select 19 | s_suppkey 20 | from 21 | supplier 22 | where 23 | s_comment like '%Customer%Complaints%' 24 | ) 25 | group by 26 | p_brand, 27 | p_type, 28 | p_size 29 | order by 30 | supplier_cnt desc, 31 | p_brand, 32 | p_type, 33 | p_size; 34 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/17.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | sum(l_extendedprice) / 7.0 as avg_yearly 6 | from 7 | lineitem, 8 | part 9 | where 10 | p_partkey = l_partkey 11 | and p_brand = 'Brand#23' 12 | and p_container = 'MED BOX' 13 | and l_quantity < ( 14 | select 15 | 0.2 * avg(l_quantity) 16 | from 17 | lineitem 18 | where 19 | l_partkey = p_partkey 20 | ); 21 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/18.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | c_name, 6 | c_custkey, 7 | o_orderkey, 8 | o_orderdate, 9 | o_totalprice, 10 | sum(l_quantity) 11 | from 12 | customer, 13 | orders, 14 | lineitem 15 | where 16 | o_orderkey in ( 17 | select 18 | l_orderkey 19 | from 20 | lineitem 21 | group by 22 | l_orderkey having 23 | sum(l_quantity) > 300 24 | 
) 25 | and c_custkey = o_custkey 26 | and o_orderkey = l_orderkey 27 | group by 28 | c_name, 29 | c_custkey, 30 | o_orderkey, 31 | o_orderdate, 32 | o_totalprice 33 | order by 34 | o_totalprice desc, 35 | o_orderdate; 36 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/19.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | sum(l_extendedprice* (1 - l_discount)) as revenue 6 | from 7 | lineitem, 8 | part 9 | where 10 | ( 11 | p_partkey = l_partkey 12 | and p_brand = 'Brand#12' 13 | and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') 14 | and l_quantity >= 1 and l_quantity <= 1 + 10 15 | and p_size between 1 and 5 16 | and l_shipmode in ('AIR', 'AIR REG') 17 | and l_shipinstruct = 'DELIVER IN PERSON' 18 | ) 19 | or 20 | ( 21 | p_partkey = l_partkey 22 | and p_brand = 'Brand#23' 23 | and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') 24 | and l_quantity >= 10 and l_quantity <= 10 + 10 25 | and p_size between 1 and 10 26 | and l_shipmode in ('AIR', 'AIR REG') 27 | and l_shipinstruct = 'DELIVER IN PERSON' 28 | ) 29 | or 30 | ( 31 | p_partkey = l_partkey 32 | and p_brand = 'Brand#34' 33 | and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') 34 | and l_quantity >= 20 and l_quantity <= 20 + 10 35 | and p_size between 1 and 15 36 | and l_shipmode in ('AIR', 'AIR REG') 37 | and l_shipinstruct = 'DELIVER IN PERSON' 38 | ); 39 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/2.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | s_acctbal, 6 | s_name, 7 | n_name, 8 | p_partkey, 9 | p_mfgr, 10 | s_address, 11 | s_phone, 12 | s_comment 13 | from 14 | part, 15 | supplier, 16 | partsupp, 17 | nation, 18 | region 19 | where 20 | 
p_partkey = ps_partkey 21 | and s_suppkey = ps_suppkey 22 | and p_size = 15 23 | and p_type like '%BRASS' 24 | and s_nationkey = n_nationkey 25 | and n_regionkey = r_regionkey 26 | and r_name = 'EUROPE' 27 | and ps_supplycost = ( 28 | select 29 | min(ps_supplycost) 30 | from 31 | partsupp, 32 | supplier, 33 | nation, 34 | region 35 | where 36 | p_partkey = ps_partkey 37 | and s_suppkey = ps_suppkey 38 | and s_nationkey = n_nationkey 39 | and n_regionkey = r_regionkey 40 | and r_name = 'EUROPE' 41 | ) 42 | order by 43 | s_acctbal desc, 44 | n_name, 45 | s_name, 46 | p_partkey; 47 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/20.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | s_name, 6 | s_address 7 | from 8 | supplier, 9 | nation 10 | where 11 | s_suppkey in ( 12 | select 13 | ps_suppkey 14 | from 15 | partsupp 16 | where 17 | ps_partkey in ( 18 | select 19 | p_partkey 20 | from 21 | part 22 | where 23 | p_name like 'forest%' 24 | ) 25 | and ps_availqty > ( 26 | select 27 | 0.5 * sum(l_quantity) 28 | from 29 | lineitem 30 | where 31 | l_partkey = ps_partkey 32 | and l_suppkey = ps_suppkey 33 | and l_shipdate >= date '1994-01-01' 34 | and l_shipdate < date '1994-01-01' + interval '1' year 35 | ) 36 | ) 37 | and s_nationkey = n_nationkey 38 | and n_name = 'CANADA' 39 | order by 40 | s_name; 41 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/21.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | s_name, 6 | count(*) as numwait 7 | from 8 | supplier, 9 | lineitem l1, 10 | orders, 11 | nation 12 | where 13 | s_suppkey = l1.l_suppkey 14 | and o_orderkey = l1.l_orderkey 15 | and o_orderstatus = 'F' 16 | and l1.l_receiptdate > l1.l_commitdate 17 | and 
exists ( 18 | select 19 | * 20 | from 21 | lineitem l2 22 | where 23 | l2.l_orderkey = l1.l_orderkey 24 | and l2.l_suppkey <> l1.l_suppkey 25 | ) 26 | and not exists ( 27 | select 28 | * 29 | from 30 | lineitem l3 31 | where 32 | l3.l_orderkey = l1.l_orderkey 33 | and l3.l_suppkey <> l1.l_suppkey 34 | and l3.l_receiptdate > l3.l_commitdate 35 | ) 36 | and s_nationkey = n_nationkey 37 | and n_name = 'SAUDI ARABIA' 38 | group by 39 | s_name 40 | order by 41 | numwait desc, 42 | s_name; 43 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/22.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | cntrycode, 6 | count(*) as numcust, 7 | sum(c_acctbal) as totacctbal 8 | from 9 | ( 10 | select 11 | substring(c_phone from 1 for 2) as cntrycode, 12 | c_acctbal 13 | from 14 | customer 15 | where 16 | substring(c_phone from 1 for 2) in 17 | ('13', '31', '23', '29', '30', '18', '17') 18 | and c_acctbal > ( 19 | select 20 | avg(c_acctbal) 21 | from 22 | customer 23 | where 24 | c_acctbal > 0.00 25 | and substring(c_phone from 1 for 2) in 26 | ('13', '31', '23', '29', '30', '18', '17') 27 | ) 28 | and not exists ( 29 | select 30 | * 31 | from 32 | orders 33 | where 34 | o_custkey = c_custkey 35 | ) 36 | ) as custsale 37 | group by 38 | cntrycode 39 | order by 40 | cntrycode; 41 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/3.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | l_orderkey, 6 | sum(l_extendedprice * (1 - l_discount)) as revenue, 7 | o_orderdate, 8 | o_shippriority 9 | from 10 | customer, 11 | orders, 12 | lineitem 13 | where 14 | c_mktsegment = 'BUILDING' 15 | and c_custkey = o_custkey 16 | and l_orderkey = o_orderkey 17 | and o_orderdate < date 
'1995-03-15' 18 | and l_shipdate > date '1995-03-15' 19 | group by 20 | l_orderkey, 21 | o_orderdate, 22 | o_shippriority 23 | order by 24 | revenue desc, 25 | o_orderdate; 26 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/4.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | o_orderpriority, 6 | count(*) as order_count 7 | from 8 | orders 9 | where 10 | o_orderdate >= date '1993-07-01' 11 | and o_orderdate < date '1993-07-01' + interval '3' month 12 | and exists ( 13 | select 14 | * 15 | from 16 | lineitem 17 | where 18 | l_orderkey = o_orderkey 19 | and l_commitdate < l_receiptdate 20 | ) 21 | group by 22 | o_orderpriority 23 | order by 24 | o_orderpriority; 25 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/5.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | n_name, 6 | sum(l_extendedprice * (1 - l_discount)) as revenue 7 | from 8 | customer, 9 | orders, 10 | lineitem, 11 | supplier, 12 | nation, 13 | region 14 | where 15 | c_custkey = o_custkey 16 | and l_orderkey = o_orderkey 17 | and l_suppkey = s_suppkey 18 | and c_nationkey = s_nationkey 19 | and s_nationkey = n_nationkey 20 | and n_regionkey = r_regionkey 21 | and r_name = 'ASIA' 22 | and o_orderdate >= date '1994-01-01' 23 | and o_orderdate < date '1994-01-01' + interval '1' year 24 | group by 25 | n_name 26 | order by 27 | revenue desc; 28 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/6.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | sum(l_extendedprice * l_discount) as revenue 6 | from 7 | lineitem 8 | where 9 | 
l_shipdate >= date '1994-01-01' 10 | and l_shipdate < date '1994-01-01' + interval '1' year 11 | and l_discount between .06 - 0.01 and .06 + 0.01 12 | and l_quantity < 24; 13 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/7.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | supp_nation, 6 | cust_nation, 7 | l_year, 8 | sum(volume) as revenue 9 | from 10 | ( 11 | select 12 | n1.n_name as supp_nation, 13 | n2.n_name as cust_nation, 14 | extract(year from l_shipdate) as l_year, 15 | l_extendedprice * (1 - l_discount) as volume 16 | from 17 | supplier, 18 | lineitem, 19 | orders, 20 | customer, 21 | nation n1, 22 | nation n2 23 | where 24 | s_suppkey = l_suppkey 25 | and o_orderkey = l_orderkey 26 | and c_custkey = o_custkey 27 | and s_nationkey = n1.n_nationkey 28 | and c_nationkey = n2.n_nationkey 29 | and ( 30 | (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') 31 | or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') 32 | ) 33 | and l_shipdate between date '1995-01-01' and date '1996-12-31' 34 | ) as shipping 35 | group by 36 | supp_nation, 37 | cust_nation, 38 | l_year 39 | order by 40 | supp_nation, 41 | cust_nation, 42 | l_year; 43 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/8.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | o_year, 6 | sum(case 7 | when nation = 'BRAZIL' then volume 8 | else 0 9 | end) / sum(volume) as mkt_share 10 | from 11 | ( 12 | select 13 | extract(year from o_orderdate) as o_year, 14 | l_extendedprice * (1 - l_discount) as volume, 15 | n2.n_name as nation 16 | from 17 | part, 18 | supplier, 19 | lineitem, 20 | orders, 21 | customer, 22 | nation n1, 23 | nation n2, 24 | region 25 | where 26 | p_partkey = l_partkey 27 
| and s_suppkey = l_suppkey 28 | and l_orderkey = o_orderkey 29 | and o_custkey = c_custkey 30 | and c_nationkey = n1.n_nationkey 31 | and n1.n_regionkey = r_regionkey 32 | and r_name = 'AMERICA' 33 | and s_nationkey = n2.n_nationkey 34 | and o_orderdate between date '1995-01-01' and date '1996-12-31' 35 | and p_type = 'ECONOMY ANODIZED STEEL' 36 | ) as all_nations 37 | group by 38 | o_year 39 | order by 40 | o_year; 41 | -------------------------------------------------------------------------------- /models-parser/tests/queries/tpch/9.sql: -------------------------------------------------------------------------------- 1 | -- using default substitutions 2 | 3 | 4 | select 5 | nation, 6 | o_year, 7 | sum(amount) as sum_profit 8 | from 9 | ( 10 | select 11 | n_name as nation, 12 | extract(year from o_orderdate) as o_year, 13 | l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount 14 | from 15 | part, 16 | supplier, 17 | lineitem, 18 | partsupp, 19 | orders, 20 | nation 21 | where 22 | s_suppkey = l_suppkey 23 | and ps_suppkey = l_suppkey 24 | and ps_partkey = l_partkey 25 | and p_partkey = l_partkey 26 | and o_orderkey = l_orderkey 27 | and s_nationkey = n_nationkey 28 | and p_name like '%green%' 29 | ) as profit 30 | group by 31 | nation, 32 | o_year 33 | order by 34 | nation, 35 | o_year desc; 36 | -------------------------------------------------------------------------------- /models-parser/tests/regression.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | #![warn(clippy::all)] 14 | 15 | use models_parser::dialect::GenericDialect; 16 | use models_parser::parser::Parser; 17 | 18 | macro_rules! tpch_tests { 19 | ($($name:ident: $value:expr,)*) => { 20 | const QUERIES: &[&str] = &[ 21 | $(include_str!(concat!("queries/tpch/", $value, ".sql"))),* 22 | ]; 23 | $( 24 | 25 | #[test] 26 | fn $name() { 27 | let dialect = GenericDialect {}; 28 | let res = Parser::parse_sql(&dialect, QUERIES[$value -1]); 29 | assert!(res.is_ok()); 30 | } 31 | )* 32 | } 33 | } 34 | 35 | tpch_tests! { 36 | tpch_1: 1, 37 | tpch_2: 2, 38 | tpch_3: 3, 39 | tpch_4: 4, 40 | tpch_5: 5, 41 | tpch_6: 6, 42 | tpch_7: 7, 43 | tpch_8: 8, 44 | tpch_9: 9, 45 | tpch_10: 10, 46 | tpch_11: 11, 47 | tpch_12: 12, 48 | tpch_13: 13, 49 | tpch_14: 14, 50 | tpch_15: 15, 51 | tpch_16: 16, 52 | tpch_17: 17, 53 | tpch_18: 18, 54 | tpch_19: 19, 55 | tpch_20: 20, 56 | tpch_21: 21, 57 | tpch_22: 22, 58 | } 59 | -------------------------------------------------------------------------------- /models-parser/tests/snowflake.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | #![warn(clippy::all)] 14 | //! Test SQL syntax specific to Snowflake. The parser based on the 15 | //! generic dialect is also tested (on the inputs it can handle). 16 | 17 | #[macro_use] 18 | mod test_utils; 19 | use test_utils::*; 20 | 21 | use models_parser::ast::*; 22 | use models_parser::dialect::{GenericDialect, SnowflakeDialect}; 23 | use models_parser::parser::ParserError; 24 | use models_parser::tokenizer::*; 25 | 26 | #[test] 27 | fn test_snowflake_create_table() { 28 | let sql = "CREATE TABLE _my_$table (am00unt number)"; 29 | match snowflake_and_generic().verified_stmt(sql) { 30 | Statement::CreateTable(table) => { 31 | assert_eq!("_my_$table", table.name.to_string()); 32 | } 33 | _ => unreachable!(), 34 | } 35 | } 36 | 37 | #[test] 38 | fn test_snowflake_single_line_tokenize() { 39 | let sql = "CREATE TABLE# this is a comment \ntable_1"; 40 | let dialect = SnowflakeDialect {}; 41 | let mut tokenizer = Tokenizer::new(&dialect, sql); 42 | let tokens = tokenizer.tokenize().unwrap(); 43 | 44 | let expected = vec![ 45 | Token::make_keyword("CREATE"), 46 | Token::Whitespace(Whitespace::Space), 47 | Token::make_keyword("TABLE"), 48 | Token::Whitespace(Whitespace::SingleLineComment { 49 | prefix: "#".to_string(), 50 | comment: " this is a comment \n".to_string(), 51 | }), 52 | Token::make_word("table_1", None), 53 | ]; 54 | 55 | assert_eq!(expected, tokens); 56 | 57 | let sql = "CREATE TABLE// this is a comment \ntable_1"; 58 | let mut tokenizer = Tokenizer::new(&dialect, sql); 59 | let tokens = 
tokenizer.tokenize().unwrap(); 60 | 61 | let expected = vec![ 62 | Token::make_keyword("CREATE"), 63 | Token::Whitespace(Whitespace::Space), 64 | Token::make_keyword("TABLE"), 65 | Token::Whitespace(Whitespace::SingleLineComment { 66 | prefix: "//".to_string(), 67 | comment: " this is a comment \n".to_string(), 68 | }), 69 | Token::make_word("table_1", None), 70 | ]; 71 | 72 | assert_eq!(expected, tokens); 73 | } 74 | 75 | #[test] 76 | fn test_sf_derived_table_in_parenthesis() { 77 | // Nesting a subquery in an extra set of parentheses is non-standard, 78 | // but supported in Snowflake SQL 79 | snowflake_and_generic().one_statement_parses_to( 80 | "SELECT * FROM ((SELECT 1) AS t)", 81 | "SELECT * FROM (SELECT 1) AS t", 82 | ); 83 | snowflake_and_generic().one_statement_parses_to( 84 | "SELECT * FROM (((SELECT 1) AS t))", 85 | "SELECT * FROM (SELECT 1) AS t", 86 | ); 87 | } 88 | 89 | #[test] 90 | fn test_single_table_in_parenthesis() { 91 | // Parenthesized table names are non-standard, but supported in Snowflake SQL 92 | snowflake_and_generic().one_statement_parses_to( 93 | "SELECT * FROM (a NATURAL JOIN (b))", 94 | "SELECT * FROM (a NATURAL JOIN b)", 95 | ); 96 | snowflake_and_generic().one_statement_parses_to( 97 | "SELECT * FROM (a NATURAL JOIN ((b)))", 98 | "SELECT * FROM (a NATURAL JOIN b)", 99 | ); 100 | } 101 | 102 | #[test] 103 | fn test_single_table_in_parenthesis_with_alias() { 104 | snowflake_and_generic().one_statement_parses_to( 105 | "SELECT * FROM (a NATURAL JOIN (b) c )", 106 | "SELECT * FROM (a NATURAL JOIN b AS c)", 107 | ); 108 | 109 | snowflake_and_generic().one_statement_parses_to( 110 | "SELECT * FROM (a NATURAL JOIN ((b)) c )", 111 | "SELECT * FROM (a NATURAL JOIN b AS c)", 112 | ); 113 | 114 | snowflake_and_generic().one_statement_parses_to( 115 | "SELECT * FROM (a NATURAL JOIN ( (b) c ) )", 116 | "SELECT * FROM (a NATURAL JOIN b AS c)", 117 | ); 118 | 119 | snowflake_and_generic().one_statement_parses_to( 120 | "SELECT * FROM (a NATURAL 
JOIN ( (b) as c ) )", 121 | "SELECT * FROM (a NATURAL JOIN b AS c)", 122 | ); 123 | 124 | snowflake_and_generic().one_statement_parses_to( 125 | "SELECT * FROM (a alias1 NATURAL JOIN ( (b) c ) )", 126 | "SELECT * FROM (a AS alias1 NATURAL JOIN b AS c)", 127 | ); 128 | 129 | snowflake_and_generic().one_statement_parses_to( 130 | "SELECT * FROM (a as alias1 NATURAL JOIN ( (b) as c ) )", 131 | "SELECT * FROM (a AS alias1 NATURAL JOIN b AS c)", 132 | ); 133 | 134 | let res = snowflake_and_generic().parse_sql_statements("SELECT * FROM (a NATURAL JOIN b) c"); 135 | assert_eq!( 136 | ParserError::ParserError("Expected end of statement, found: c".to_string()), 137 | res.unwrap_err() 138 | ); 139 | 140 | let res = snowflake().parse_sql_statements("SELECT * FROM (a b) c"); 141 | assert_eq!( 142 | ParserError::ParserError("duplicate alias b".to_string()), 143 | res.unwrap_err() 144 | ); 145 | } 146 | 147 | fn snowflake() -> TestedDialects { 148 | TestedDialects { 149 | dialects: vec![Box::new(SnowflakeDialect {})], 150 | } 151 | } 152 | 153 | fn snowflake_and_generic() -> TestedDialects { 154 | TestedDialects { 155 | dialects: vec![Box::new(SnowflakeDialect {}), Box::new(GenericDialect {})], 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /models-parser/tests/sqlite.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | #![warn(clippy::all)] 14 | //! Test SQL syntax specific to SQLite. The parser based on the 15 | //! generic dialect is also tested (on the inputs it can handle). 16 | 17 | #[macro_use] 18 | mod test_utils; 19 | use test_utils::*; 20 | 21 | use models_parser::ast::*; 22 | use models_parser::dialect::{GenericDialect, SQLiteDialect}; 23 | use models_parser::tokenizer::Token; 24 | 25 | #[test] 26 | fn parse_create_table_without_rowid() { 27 | let sql = "CREATE TABLE t (a INTEGER) WITHOUT ROWID"; 28 | match sqlite_and_generic().verified_stmt(sql) { 29 | Statement::CreateTable(table) => { 30 | assert!(table.without_rowid); 31 | assert_eq!("t", table.name.to_string()); 32 | } 33 | _ => unreachable!(), 34 | } 35 | } 36 | 37 | #[test] 38 | fn parse_create_virtual_table() { 39 | let sql = "CREATE VIRTUAL TABLE IF NOT EXISTS t USING module_name (arg1, arg2)"; 40 | match sqlite_and_generic().verified_stmt(sql) { 41 | Statement::CreateVirtualTable(CreateVirtualTable { 42 | name, 43 | if_not_exists: true, 44 | module_name, 45 | module_args, 46 | }) => { 47 | let args = vec![Ident::new("arg1"), Ident::new("arg2")]; 48 | assert_eq!("t", name.to_string()); 49 | assert_eq!("module_name", module_name.to_string()); 50 | assert_eq!(args, module_args); 51 | } 52 | _ => unreachable!(), 53 | } 54 | 55 | let sql = "CREATE VIRTUAL TABLE t USING module_name"; 56 | sqlite_and_generic().verified_stmt(sql); 57 | } 58 | 59 | #[test] 60 | fn parse_create_table_auto_increment() { 61 | let sql = "CREATE TABLE foo (bar INTEGER PRIMARY KEY AUTOINCREMENT)"; 62 | match sqlite_and_generic().verified_stmt(sql) { 63 | Statement::CreateTable(table) => { 64 | let name = table.name; 65 | let columns = table.columns; 66 | assert_eq!(name.to_string(), "foo"); 67 | assert_eq!( 68 | vec![ColumnDef { 69 | name: "bar".into(), 70 | data_type: DataType::Int(None), 71 | collation: None, 72 | 
options: vec![ 73 | ColumnOptionDef { 74 | name: None, 75 | option: ColumnOption::Unique { is_primary: true } 76 | }, 77 | ColumnOptionDef { 78 | name: None, 79 | option: ColumnOption::DialectSpecific(vec![Token::make_keyword( 80 | "AUTOINCREMENT" 81 | )]) 82 | } 83 | ], 84 | }], 85 | columns 86 | ); 87 | } 88 | _ => unreachable!(), 89 | } 90 | } 91 | 92 | #[test] 93 | fn parse_create_sqlite_quote() { 94 | let sql = "CREATE TABLE `PRIMARY` (\"KEY\" INTEGER, [INDEX] INTEGER)"; 95 | match sqlite().verified_stmt(sql) { 96 | Statement::CreateTable(table) => { 97 | let columns = table.columns; 98 | let name = table.name; 99 | assert_eq!(name.to_string(), "`PRIMARY`"); 100 | assert_eq!( 101 | vec![ 102 | ColumnDef { 103 | name: Ident::with_quote('"', "KEY"), 104 | data_type: DataType::Int(None), 105 | collation: None, 106 | options: vec![], 107 | }, 108 | ColumnDef { 109 | name: Ident::with_quote('[', "INDEX"), 110 | data_type: DataType::Int(None), 111 | collation: None, 112 | options: vec![], 113 | }, 114 | ], 115 | columns 116 | ); 117 | } 118 | _ => unreachable!(), 119 | } 120 | } 121 | 122 | fn sqlite() -> TestedDialects { 123 | TestedDialects { 124 | dialects: vec![Box::new(SQLiteDialect {})], 125 | } 126 | } 127 | 128 | fn sqlite_and_generic() -> TestedDialects { 129 | TestedDialects { 130 | // we don't have a separate SQLite dialect, so test only the generic dialect for now 131 | dialects: vec![Box::new(SQLiteDialect {}), Box::new(GenericDialect {})], 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /models-parser/tests/test_utils/mod.rs: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | // Re-export everything from `src/test_utils.rs`. 14 | pub use models_parser::test_utils::*; 15 | 16 | // For the test-only macros we take a different approach of keeping them here 17 | // rather than in the library crate. 18 | // 19 | // This is because we don't need any of them to be shared between the 20 | // integration tests (i.e. `tests/*`) and the unit tests (i.e. `src/*`), 21 | // but also because Rust doesn't scope macros to a particular module 22 | // (and while we export internal helpers as models_parser::test_utils::<...>, 23 | // expecting our users to abstain from relying on them, exporting internal 24 | // macros at the top level, like `models_parser::nest` was deemed too confusing). 25 | 26 | #[macro_export] 27 | macro_rules! 
nest { 28 | ($base:expr $(, $join:expr)*) => { 29 | TableFactor::NestedJoin(Box::new(TableWithJoins { 30 | relation: $base, 31 | joins: vec![$(join($join)),*] 32 | })) 33 | }; 34 | } 35 | -------------------------------------------------------------------------------- /models-proc-macro/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "models-proc-macro" 3 | version = "0.1.1" 4 | edition = "2018" 5 | license = "Apache-2.0" 6 | description = "A helper crate for `models`" 7 | 8 | [lib] 9 | proc-macro = true 10 | 11 | [features] 12 | helpers = [] 13 | 14 | [dependencies] 15 | proc-macro2 = "1.0.28" 16 | quote = "1.0.9" 17 | models-parser ={version = "0.2.0", path = "../models-parser"} 18 | syn = "1.0.74" 19 | -------------------------------------------------------------------------------- /models-proc-macro/src/getters.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; // TODO: 2 | // getters for foreign keys 3 | 4 | use super::Model; 5 | 6 | struct Getters<'a> { 7 | model: &'a Model, 8 | getters: Vec, 9 | } 10 | enum Getter { 11 | Unique { 12 | table_name: Ident, 13 | column_name: Ident, 14 | column_type: Type, 15 | }, 16 | 17 | Foreign { 18 | table_name: Ident, 19 | referred: Ident, 20 | }, 21 | } 22 | 23 | impl<'a> ToTokens for Getters<'a> { 24 | fn to_tokens(&self, tokens: &mut TokenStream2) { 25 | let ident = &self.model.name; 26 | let getters = &self.getters; 27 | tokens.extend(quote! { 28 | impl #ident { 29 | #(#getters)* 30 | } 31 | }) 32 | } 33 | } 34 | 35 | impl<'a> ToTokens for Getter { 36 | fn to_tokens(&self, tokens: &mut TokenStream2) { 37 | match &self { 38 | Self::Unique { 39 | table_name, 40 | column_name, 41 | column_type, 42 | } => { 43 | 44 | let query = format!("select * from {} where co;", table_name, column_name); 45 | tokens.extend(quote! 
{ 46 | fn #name(&self, val: #dtype) -> ::std::result::Result, ::sqlx::Error>{ 47 | ::sqlx::query(#query). 48 | .fetch_all(&conn) 49 | .await 50 | } 51 | }); 52 | } 53 | } 54 | } 55 | } 56 | 57 | impl Getter {} 58 | -------------------------------------------------------------------------------- /models-proc-macro/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod migration_generation; 2 | // mod getters; 3 | mod model; 4 | mod prelude; 5 | use migration_generation::*; 6 | use model::*; 7 | use prelude::*; 8 | 9 | #[proc_macro_derive(Model, attributes(model, primary_key, foreign_key, unique, default))] 10 | pub fn model(input: TokenStream) -> TokenStream { 11 | let derive = parse_macro_input!(input as Model); 12 | 13 | let migrations = generate_migration(&derive.name); 14 | let template = quote! { 15 | #derive 16 | #migrations 17 | }; 18 | template.into() 19 | } 20 | -------------------------------------------------------------------------------- /models-proc-macro/src/migration_generation.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | // SQLX_MODELS_GENERATE_MIGRATION=true 3 | // MODELS_GENERATE_MIGRATIONS 4 | 5 | pub fn generate_migration(name: &Ident) -> TokenStream2 { 6 | if let Ok(value) = std::env::var("MODELS_GENERATE_MIGRATIONS") { 7 | if value.to_lowercase() == "true" { 8 | generate_migration_unchecked(name) 9 | } else { 10 | quote!() 11 | } 12 | } else { 13 | quote!() 14 | } 15 | } 16 | 17 | fn generate_migration_unchecked(name: &Ident) -> TokenStream2 { 18 | let test_name = Ident::new( 19 | &format!("__models_generate_migration_{}", name), 20 | proc_macro2::Span::call_site(), 21 | ); 22 | quote! 
{ 23 | #[test] 24 | fn #test_name() { 25 | ::models::private::SCHEDULER.register( 26 | <#name as ::models::private::Model>::target() 27 | ); 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /models-proc-macro/src/model/column/default.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | use proc_macro2::Span; 3 | 4 | pub struct DefaultExpr { 5 | is_string: bool, 6 | expr: String, 7 | } 8 | 9 | impl ToTokens for DefaultExpr { 10 | fn to_tokens(&self, tokens: &mut TokenStream2) { 11 | let expr = &self.expr; 12 | if !self.is_string { 13 | tokens.extend(quote!(#expr)); 14 | } else { 15 | let expr = format!("{:?}", self.expr); 16 | let len = expr.chars().count(); 17 | let mut out: String = "'".into(); 18 | for char in expr.chars().skip(1).take(len - 2) { 19 | out.push(char); 20 | } 21 | out.push('\''); 22 | tokens.extend(quote!(#out)) 23 | } 24 | } 25 | } 26 | 27 | impl Parse for DefaultExpr { 28 | fn parse(input: parse::ParseStream) -> Result { 29 | use models_parser::{dialect::*, parser::Parser, tokenizer::*}; 30 | 31 | let content; 32 | let _paren = parenthesized!(content in input); 33 | let span = Span::call_site(); 34 | let mut is_string = false; 35 | let expr = match content.parse::() { 36 | Ok(Lit::Bool(boolean)) => boolean.value().to_string(), 37 | Ok(Lit::Int(int)) => int.to_string(), 38 | Ok(Lit::Float(float)) => float.to_string(), 39 | Ok(Lit::Str(string)) => { 40 | is_string = true; 41 | string.value() 42 | } 43 | Ok(lit) => Err(Error::new( 44 | lit.span(), 45 | "Expected string, boolean, or numeric literal", 46 | ))?, 47 | Err(err) => Err(Error::new( 48 | err.span(), 49 | "Expected string, boolean, or numeric literal", 50 | ))?, 51 | }; 52 | 53 | let mut lexer = Tokenizer::new(&GenericDialect {}, &expr); 54 | 55 | let tokens = lexer.tokenize().map_err(|err| { 56 | syn::Error::new( 57 | span, 58 | format!("Failed to tokenize default 
expression: {:?}", err.message), 59 | ) 60 | })?; 61 | 62 | let _ = Parser::new(tokens, &GenericDialect {}) 63 | .parse_expr() 64 | .map_err(|err| { 65 | syn::Error::new(span, format!("Failed to parse default expression: {}", err)) 66 | }); 67 | Ok(DefaultExpr { is_string, expr }) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /models-proc-macro/src/model/column/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | mod default; 3 | 4 | use default::*; 5 | 6 | pub struct Column { 7 | name: Ident, 8 | ty: Type, 9 | default: Option, 10 | } 11 | 12 | impl ToTokens for Column { 13 | fn to_tokens(&self, tokens: &mut TokenStream2) { 14 | let col_name = &self.name; 15 | let ty = &self.ty; 16 | let default = &self.default; 17 | let temp = if let Some(default) = default { 18 | quote! { 19 | __models_table.columns.push( 20 | ::models::private::Column::new_with_default( 21 | stringify!(#col_name), 22 | <#ty as ::models::types::IntoSQL>::into_sql(), 23 | <#ty as ::models::types::IntoSQL>::IS_NULLABLE, 24 | #default 25 | )); 26 | } 27 | } else { 28 | quote! 
{ 29 | __models_table.columns.push( 30 | ::models::private::Column::new( 31 | stringify!(#col_name), 32 | <#ty as ::models::types::IntoSQL>::into_sql(), 33 | <#ty as ::models::types::IntoSQL>::IS_NULLABLE, 34 | )); 35 | } 36 | }; 37 | tokens.extend(temp); 38 | } 39 | } 40 | 41 | impl Column { 42 | pub fn new(field: &Field) -> Result { 43 | let ty = field.ty.clone(); 44 | let default = Self::get_default(field.attrs.clone())?; 45 | let name = field.ident.clone().unwrap(); 46 | Ok(Self { ty, default, name }) 47 | } 48 | 49 | fn get_default(attrs: Vec) -> Result> { 50 | for attr in attrs { 51 | if attr.path.is_ident("default") { 52 | return Ok(Some(syn::parse(attr.tokens.into())?)); 53 | } 54 | } 55 | Ok(None) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /models-proc-macro/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | mod column; 2 | mod constraint; 3 | use crate::prelude::*; 4 | use constraint::*; 5 | use Data::*; 6 | 7 | use self::column::Column; 8 | pub struct Model { 9 | pub name: Ident, 10 | name_lowercase: Ident, 11 | data: DataStruct, 12 | columns: Vec, 13 | constraints: Vec, 14 | } 15 | 16 | struct ForeignKey { 17 | tables: Vec, 18 | columns: Vec, 19 | } 20 | impl Parse for ForeignKey { 21 | fn parse(input: parse::ParseStream) -> Result { 22 | let mut out = ForeignKey { 23 | tables: vec![], 24 | columns: vec![], 25 | }; 26 | let content; 27 | let _paren = parenthesized!(content in input); 28 | while !content.is_empty() { 29 | out.tables.push(content.parse::()?); 30 | content.parse::()?; 31 | out.columns.push(content.parse::()?); 32 | } 33 | Ok(out) 34 | } 35 | } 36 | 37 | impl Parse for Model { 38 | fn parse(input: parse::ParseStream) -> Result { 39 | let input: DeriveInput = input.parse()?; 40 | let name = input.ident; 41 | let name_lowercase = Ident::new(&name.to_string().to_lowercase(), name.span()); 42 | match input.data { 43 | Struct(data) => { 
44 | let mut model = Self { 45 | name, 46 | data, 47 | name_lowercase, 48 | columns: Default::default(), 49 | constraints: Default::default(), 50 | }; 51 | model.init()?; 52 | Ok(model) 53 | } 54 | _ => panic!("Sql models have to be structs, enums and unions are not supported."), 55 | } 56 | } 57 | } 58 | 59 | impl ToTokens for Model { 60 | fn to_tokens(&self, tokens: &mut TokenStream2) { 61 | let name = &self.name; 62 | let name_lowercase = &self.name_lowercase; 63 | let columns = &self.get_columns(); 64 | let constraints = &self.get_constraints(); 65 | let template = quote! { 66 | impl ::models::private::Model for #name { 67 | fn target() -> ::models::private::Table { 68 | let mut __models_table = ::models::private::Table::new(stringify!(#name_lowercase)); 69 | #columns 70 | #constraints 71 | __models_table 72 | } 73 | } 74 | }; 75 | tokens.extend(template); 76 | } 77 | } 78 | 79 | impl Model { 80 | // include 81 | fn init(&mut self) -> Result<()> { 82 | for field in &self.data.fields { 83 | let col_name = field.ident.clone().unwrap(); 84 | let constrs: Vec<_> = Constraints::from_attrs(&field.attrs)? 85 | .0 86 | .into_iter() 87 | .map(|constr| NamedConstraint { 88 | name: self.constr_name(&constr.method(), &col_name, &constr.column_names()), 89 | field_name: col_name.clone(), 90 | constr, 91 | }) 92 | .collect(); 93 | self.constraints.extend(constrs); 94 | 95 | let column = Column::new(field)?; 96 | self.columns.push(column); 97 | } 98 | Ok(()) 99 | } 100 | 101 | fn get_columns(&self) -> TokenStream2 { 102 | let columns = self.columns.iter(); 103 | quote! { 104 | #(#columns;)* 105 | } 106 | } 107 | 108 | fn get_constraints(&self) -> TokenStream2 { 109 | let columns = self 110 | .constraints 111 | .iter() 112 | .map(|constr| constr.into_tokens(&self.name)); 113 | 114 | quote! 
{#(#columns;)*} 115 | } 116 | 117 | pub fn constr_name( 118 | &self, 119 | method: &impl ToString, 120 | name: &impl ToString, 121 | cols: &[impl ToString], 122 | ) -> String { 123 | let mut constr_name = String::new(); 124 | constr_name += &self.name_lowercase.to_string(); 125 | constr_name += "_"; 126 | constr_name += &method.to_string(); 127 | constr_name += "_"; 128 | constr_name += &name.to_string(); 129 | 130 | for col in cols.iter() { 131 | constr_name += "_"; 132 | 133 | constr_name += &col.to_string(); 134 | } 135 | constr_name 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /models-proc-macro/src/prelude.rs: -------------------------------------------------------------------------------- 1 | pub use collections::HashMap; 2 | pub use proc_macro::{TokenStream, *}; 3 | pub use proc_macro2::TokenStream as TokenStream2; 4 | pub use quote::{quote, *}; 5 | pub use std::*; 6 | pub use syn::parse::Parse; 7 | pub use syn::{Ident, *}; 8 | -------------------------------------------------------------------------------- /models/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "models" 3 | version = "0.1.3" 4 | edition = "2018" 5 | license = "Apache-2.0" 6 | description = "A migration management library for applications using PostgresSQL, MySQL or SQLite." 
7 | keywords = ["database", "postgres", "sqlite", "sql", "migration"] 8 | readme = "../README.md" 9 | authors = ["Tomas Vallotton "] 10 | 11 | [features] 12 | default = ["sqlformat"] 13 | json = ["serde", "serde_json"] 14 | sqlx-postgres = ["sqlx", "sqlx/postgres", "sqlx/json"] 15 | sqlx-mysql = ["sqlx", "sqlx/mysql", "sqlx/json"] 16 | sqlx-sqlite = ["sqlx", "sqlx/sqlite", "sqlx/json"] 17 | # postgres = [] 18 | # sqlx = [] 19 | # rusqlite = [] 20 | # tokio_postgres = [] 21 | # serde features 22 | # json = [] 23 | # binary = [] 24 | 25 | 26 | [dependencies] 27 | thiserror = "1.0.29" 28 | models-parser = { version = "0.2.0", path = "../models-parser"} 29 | models-proc-macro ={version = "0.1.1", path = "../models-proc-macro"} 30 | once_cell = "1.8.0" 31 | url = "2.2.2" 32 | sqlformat = { version = "0.1.8", optional = true } 33 | serde = { version = "1.0.130", features = ["derive"], optional = true} 34 | serde_json = {version = "1.0.68", optional = true} 35 | sqlx = {version = "0.5.9", optional = true} 36 | chrono = {version = "0.4.19", optional = true} 37 | 38 | 39 | [dev-dependencies] 40 | 41 | sqlx = {version = "0.5.9", features = ["runtime-async-std-native-tls", "postgres"] } 42 | models = {path = "../models", features = ["sqlformat", "json", "sqlx", "chrono"]} 43 | -------------------------------------------------------------------------------- /models/src/dialect.rs: -------------------------------------------------------------------------------- 1 | use self::Dialect::*; 2 | use dialect::*; 3 | use models_parser::dialect; 4 | #[derive(Clone, Copy, Debug)] 5 | pub(crate) enum Dialect { 6 | SQLite, 7 | PostgreSQL, 8 | MySQL, 9 | MsSQL, 10 | Any, 11 | } 12 | 13 | impl Dialect { 14 | pub(crate) fn requires_move(&self) -> bool { 15 | matches!(self, Dialect::SQLite | Dialect::Any) 16 | } 17 | 18 | pub(crate) fn _has_default_constr_name(&self) -> bool { 19 | matches!(self, Dialect::PostgreSQL) 20 | } 21 | 22 | pub(crate) fn supports_cascade(&self) -> bool { 23 | 
!matches!(self, SQLite) 24 | } 25 | } 26 | 27 | impl dialect::Dialect for Dialect { 28 | fn is_delimited_identifier_start(&self, ch: char) -> bool { 29 | match self { 30 | SQLite => SQLiteDialect {}.is_delimited_identifier_start(ch), 31 | PostgreSQL => PostgreSqlDialect {}.is_delimited_identifier_start(ch), 32 | MySQL => MySqlDialect {}.is_delimited_identifier_start(ch), 33 | MsSQL => MsSqlDialect {}.is_delimited_identifier_start(ch), 34 | Any => GenericDialect {}.is_delimited_identifier_start(ch), 35 | } 36 | } 37 | fn is_identifier_start(&self, ch: char) -> bool { 38 | match self { 39 | SQLite => SQLiteDialect {}.is_identifier_start(ch), 40 | PostgreSQL => PostgreSqlDialect {}.is_identifier_start(ch), 41 | MySQL => MySqlDialect {}.is_identifier_start(ch), 42 | MsSQL => MsSqlDialect {}.is_identifier_start(ch), 43 | Any => GenericDialect {}.is_identifier_start(ch), 44 | } 45 | } 46 | 47 | fn is_identifier_part(&self, ch: char) -> bool { 48 | match self { 49 | SQLite => SQLiteDialect {}.is_identifier_part(ch), 50 | PostgreSQL => PostgreSqlDialect {}.is_identifier_part(ch), 51 | MySQL => MySqlDialect {}.is_identifier_part(ch), 52 | MsSQL => MsSqlDialect {}.is_identifier_part(ch), 53 | Any => GenericDialect {}.is_identifier_part(ch), 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /models/src/error.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | 3 | use models_parser::parser::ParserError; 4 | 5 | use std::sync::Arc; 6 | use thiserror::Error; 7 | 8 | macro_rules! 
error { 9 | ($($args:expr),+) => { 10 | Error::Message(format!($($args),*)) 11 | }; 12 | } 13 | 14 | #[derive(Error, Debug, Clone)] 15 | pub enum Error { 16 | #[error("syntax error: {0}")] 17 | Syntax(#[from] ParserError), 18 | #[error("syntax error: {0}.\n found at file \"{1}\".")] 19 | SyntaxAtFile(ParserError, path::PathBuf), 20 | #[error("{0}")] 21 | Message(String), 22 | #[error("could not read or create migration file. {0}")] 23 | IO(#[from] Arc), 24 | #[error("dependency cycle detected invlonving the tables: {0:?}. help: consider removing redundant foreign key constraints. ")] 25 | Cycle(Vec), 26 | } 27 | 28 | impl Error { 29 | pub(crate) fn kind(&self) -> &'static str { 30 | match self { 31 | Self::Cycle(_) => "CycleError", 32 | Self::Message(_) => "error", 33 | Self::IO(_) => "IOError", 34 | Self::Syntax(_) => "SyntaxError", 35 | Self::SyntaxAtFile(_, _) => "SyntaxAtFile", 36 | } 37 | } 38 | 39 | pub(crate) fn as_json(&self) -> String { 40 | let err_msg = format!("{}", self); 41 | let kind = self.kind(); 42 | 43 | format!( 44 | r#"{{"kind":{kind:?},"message":{message:?}}}"#, 45 | kind = kind, 46 | message = err_msg 47 | ) 48 | } 49 | } 50 | 51 | impl From for Error { 52 | fn from(err: io::Error) -> Error { 53 | Error::IO(Arc::new(err)) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /models/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Models 2 | //! Models is a SQL migration management tool. It supports PostgreSQL, MySQL, and SQLite. 3 | //! 4 | //! 5 | //! # Quick Start 6 | //! 7 | //! install the CLI by running the following command: 8 | //! ```ignore 9 | //! $ cargo install models-cli 10 | //! ``` 11 | //! 12 | //! Now run the following command to create an environment file with the `DATABASE_URL` variable set: 13 | //! ```ignore 14 | //! $ echo "DATABASE_URL=sqlite://database.db" > .env 15 | //! ``` 16 | //! 
Alternatively it can be set as an environment variable with the following command: 17 | //! ```ignore 18 | //! $ export DATABASE_URL=sqlite://database.db 19 | //! ``` 20 | //! We now can create the database running the following command: 21 | //! ```ignore 22 | //! $ models database create 23 | //! ``` 24 | //! This command will have created an SQLite file called `database.db`. 25 | //! You can now derive the `Model` trait on your structures, 26 | //! and `models` will manage the migrations for you. For example, write at `src/main.rs`: 27 | //! ```rust 28 | //! #![allow(dead_code)] 29 | //! use models::Model; 30 | //! 31 | //! #[derive(Model)] 32 | //! struct Profile { 33 | //! #[primary_key] 34 | //! id: i32, 35 | //! #[unique] 36 | //! email: String, 37 | //! password: String, 38 | //! is_admin: bool, 39 | //! } 40 | //! 41 | //! #[derive(Model)] 42 | //! struct Post { 43 | //! #[primary_key] 44 | //! id: i32, 45 | //! #[foreign_key(Profile.id)] 46 | //! author: String, 47 | //! #[default("")] 48 | //! title: String, 49 | //! content: String, 50 | //! } 51 | //! 52 | //! #[derive(Model)] 53 | //! struct PostLike { 54 | //! #[foreign_key(Profile.id, on_delete="cascade")] 55 | //! #[primary_key(post_id)] 56 | //! profile_id: i32, 57 | //! #[foreign_key(Post.id, on_delete="cascade")] 58 | //! post_id: i32, 59 | //! } 60 | //! 61 | //! #[derive(Model)] 62 | //! struct CommentLike { 63 | //! #[foreign_key(Profile.id)] 64 | //! #[primary_key(comment_id)] 65 | //! profile_id: i32, 66 | //! #[foreign_key(Comment.id)] 67 | //! comment_id: i32, 68 | //! is_dislike: bool, 69 | //! } 70 | //! 71 | //! #[derive(Model)] 72 | //! struct Comment { 73 | //! #[primary_key] 74 | //! id: i32, 75 | //! #[foreign_key(Profile.id)] 76 | //! author: i32, 77 | //! #[foreign_key(Post.id)] 78 | //! post: i32, 79 | //! } 80 | //! fn main() {} 81 | //! ``` 82 | //! 83 | //! If you now run the following command, your migrations should be automatically created. 84 | //! ```ignore 85 | //!
$ models generate 86 | //! ``` 87 | //! The output should look like this: 88 | //! ```ignore 89 | //! Generated: migrations/1632280793452 profile 90 | //! Generated: migrations/1632280793459 post 91 | //! Generated: migrations/1632280793465 postlike 92 | //! Generated: migrations/1632280793471 comment 93 | //! Generated: migrations/1632280793476 commentlike 94 | //! ``` 95 | //! You can check out the generated migrations at the `migrations/` folder. 96 | //! To execute these migrations you can execute the following command: 97 | //! ```ignore 98 | //! models migrate run 99 | //! ``` 100 | //! The output should look like this: 101 | //! ``` ignore 102 | //! Applied 1631716729974/migrate profile (342.208µs) 103 | //! Applied 1631716729980/migrate post (255.958µs) 104 | //! Applied 1631716729986/migrate comment (287.792µs) 105 | //! Applied 1631716729993/migrate postlike (349.834µs) 106 | //! Applied 1631716729998/migrate commentlike (374.625µs) 107 | //! ``` 108 | //! If we later modify those structures in our application, we can generate new migrations to update the tables. 109 | //! 110 | //! ## Reverting migration 111 | //! Models can generate down migrations with the `-r` flag. Note that simple and reversible migrations cannot be mixed: 112 | //! ```ignore 113 | //! $ models generate -r 114 | //! ``` 115 | //! In order to revert the last migration executed you can run: 116 | //! ```ignore 117 | //! $ models migrate revert 118 | //! ``` 119 | //! If you later want to see which migrations are yet to be applied you can also execute: 120 | //! ```ignore 121 | //! $ models migrate info 122 | //! ``` 123 | //! ## Available Attributes 124 | //! ### primary_key 125 | //! It's used to mark the primary key of the table. 126 | //! ```ignore 127 | //! #[primary_key] 128 | //! id: i32, 129 | //! ``` 130 | //! for tables with multicolumn primary keys, the following syntax is used: 131 | //! ```ignore 132 | //! #[primary_key(second_id)] 133 | //! first_id: i32, 134 | //!
second_id: i32, 135 | //! ``` 136 | //! This is equivalent to: 137 | //! ```sql 138 | //! PRIMARY KEY (first_id, second_id), 139 | //! ``` 140 | //! 141 | //! ### foreign_key 142 | //! It is used to mark a foreign key constraint. 143 | //! ```ignore 144 | //! #[foreign_key(Profile.id)] 145 | //! profile: i32, 146 | //! ``` 147 | //! It can also specify `on_delete` and `on_update` constraints: 148 | //! ```ignore 149 | //! #[foreign_key(Profile.id, on_delete="cascade")] 150 | //! profile_id: i32, 151 | //! ``` 152 | //! This is equivalent to: 153 | //! ```sql 154 | //! FOREIGN KEY (profile_id) REFERENCES profile (id) ON DELETE CASCADE, 155 | //! ``` 156 | //! ### default 157 | //! It can be used to set a default value for a column. 158 | //! ```ignore 159 | //! #[default(false)] // when using SQLite use 0 or 1 160 | //! is_admin: bool, 161 | //! #[default("")] 162 | //! text: String, 163 | //! #[default(0)] 164 | //! number: i32, 165 | //! ``` 166 | //! 167 | //! ### unique 168 | //! It is used to mark a unique constraint. 169 | //! ```ignore 170 | //! #[unique] 171 | //! email: String, 172 | //! ``` 173 | //! For multicolumn unique constraints the following syntax is used: 174 | //! ```ignore 175 | //! #[unique(post_id)] 176 | //! profile_id: String, 177 | //! post_id: i32, 178 | //! ``` 179 | //! This is equivalent to: 180 | //! ```sql 181 | //! UNIQUE (profile_id, post_id), 182 | //! 
``` 183 | #![allow(unused_imports)] 184 | pub use models_proc_macro::Model; 185 | 186 | #[macro_use] 187 | pub mod error; 188 | mod dialect; 189 | mod prelude; 190 | pub mod private; 191 | #[cfg(tests)] 192 | mod tests; 193 | pub mod types; 194 | 195 | pub use types::*; 196 | -------------------------------------------------------------------------------- /models/src/postgres.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/src/prelude.rs: -------------------------------------------------------------------------------- 1 | pub use crate::error::Error; 2 | pub(crate) use crate::{dialect::Dialect, private::*}; 3 | pub(crate) use convert::{TryFrom, TryInto}; 4 | pub(crate) use models_parser::{ast::*, *}; 5 | pub(crate) use once_cell::sync::Lazy; 6 | pub(crate) use std::{collections::HashMap, sync::Mutex, *}; 7 | pub(crate) use Dialect::*; 8 | pub(crate) type Result = std::result::Result; 9 | pub(crate) use crate::types::IntoSQL; 10 | use url::Url; 11 | 12 | pub(crate) static DATABASE_URL: Lazy = Lazy::new(|| { 13 | let database_url = env::var("DATABASE_URL").unwrap(); 14 | Url::parse(&database_url).unwrap() 15 | }); 16 | pub(crate) static MIGRATIONS_DIR: Lazy = Lazy::new(|| { 17 | let dir = env::var("MIGRATIONS_DIR"); 18 | dir.unwrap() 19 | }); 20 | pub(crate) static DIALECT: Lazy = Lazy::new(|| match DATABASE_URL.scheme() { 21 | "sqlite" => SQLite, 22 | "postgres" => PostgreSQL, 23 | "mysql" => MySQL, 24 | "mssql" => MsSQL, 25 | _ => Any, 26 | }); 27 | #[cfg(feature = "sqlformat")] 28 | use sqlformat::{FormatOptions, Indent}; 29 | #[cfg(feature = "sqlformat")] 30 | pub static FORMAT_OPTIONS: FormatOptions = FormatOptions { 31 | indent: Indent::Spaces(4), 32 | uppercase: true, 33 | lines_between_queries: 2, 34 | }; 35 | 36 | pub static MODELS_GENERATE_DOWN: Lazy = Lazy::new(|| { 37 | let down = env::var("MODELS_GENERATE_DOWN").as_deref() 
== Ok("true"); 38 | down 39 | }); 40 | 41 | pub(crate) fn parse_sql(sql: &str) -> Result, parser::ParserError> { 42 | let stmts = parser::Parser::parse_sql(&*DIALECT, sql)?; 43 | Ok(stmts) 44 | } 45 | -------------------------------------------------------------------------------- /models/src/private/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module is publicly accessible, but the interface can be subject to changes. 2 | //! This module is intended for macros only. 3 | //! Changes to elements in this module are not considered a breaking change. Do not depend directly on this module. 4 | mod scheduler; 5 | use once_cell::sync::Lazy; 6 | pub(crate) use scheduler::driver::migration::Migration; 7 | pub use scheduler::{ 8 | table::{constraint, Column, Table}, 9 | Scheduler, 10 | }; 11 | 12 | pub trait Model { 13 | fn target() -> Table; 14 | } 15 | 16 | pub static SCHEDULER: Lazy = Lazy::new(Scheduler::new); 17 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/actions/action/mod.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | mod temp_move; 3 | use temp_move::Move; 4 | #[derive(Debug)] 5 | pub(crate) struct Action<'table> { 6 | pub table_name: &'table ObjectName, 7 | pub variant: ActionVariant<'table>, 8 | } 9 | #[derive(Debug)] 10 | pub(crate) enum ActionVariant<'table> { 11 | CreateCol(&'table Column), 12 | 13 | DropCol(Ident), 14 | 15 | CreateConstr(&'table TableConstraint), 16 | 17 | DropConstr(Ident), 18 | 19 | TempMove(Move<'table>), 20 | 21 | CreateTable(&'table Table), 22 | } 23 | 24 | impl<'table> Action<'table> { 25 | pub fn is_fallible(&self) -> bool { 26 | if let ActionVariant::CreateCol(col) = &self.variant { 27 | col.has_default() || col.is_nullable() 28 | } else { 29 | false 30 | } 31 | } 32 | 33 | pub(super) fn create_table(target: &'table Table) -> Self { 34 | Self { 35 
| table_name: &target.name, 36 | variant: ActionVariant::CreateTable(target), 37 | } 38 | } 39 | pub(super) fn drop_cons( 40 | name: &'table ObjectName, 41 | cons: &'table TableConstraint, 42 | ) -> Result { 43 | Ok(Self { 44 | table_name: name, 45 | variant: ActionVariant::DropConstr(Ident::new(cons.name()?)), 46 | }) 47 | } 48 | 49 | pub(super) fn drop_col(name: &'table ObjectName, col: &'table Column) -> Self { 50 | Self { 51 | table_name: name, 52 | variant: ActionVariant::DropCol(col.name.clone()), 53 | } 54 | } 55 | pub(super) fn create_column(table_name: &'table ObjectName, col: &'table Column) -> Self { 56 | Self { 57 | table_name, 58 | variant: ActionVariant::CreateCol(col), 59 | } 60 | } 61 | pub(super) fn create_cons(name: &'table ObjectName, cons: &'table TableConstraint) -> Self { 62 | Self { 63 | table_name: name, 64 | variant: ActionVariant::CreateConstr(cons), 65 | } 66 | } 67 | pub fn move_to(old: &'table Table, cols: &ColCRUD<'table>, cons: &ConsCRUD<'table>) -> Self { 68 | let move_ = Move::new(old, cons, cols); 69 | Self { 70 | table_name: &old.name, 71 | variant: ActionVariant::TempMove(move_), 72 | } 73 | } 74 | 75 | pub fn to_statements(self) -> Result> { 76 | use ActionVariant::*; 77 | let mut out = vec![]; 78 | let table_name = self.table_name.clone(); 79 | match self.variant { 80 | TempMove(r#move) => { 81 | return r#move.to_statements(table_name); 82 | } 83 | CreateTable(table) => { 84 | let statement = Statement::from(table.clone()); 85 | out.push(statement); 86 | } 87 | other => { 88 | let operation = match other { 89 | CreateCol(column) => AlterTableOperation::AddColumn { 90 | column_def: ColumnDef::from(column.clone()), 91 | }, 92 | 93 | DropCol(column_name) => AlterTableOperation::DropColumn { 94 | column_name, 95 | if_exists: false, 96 | cascade: DIALECT.supports_cascade(), 97 | }, 98 | DropConstr(name) => AlterTableOperation::DropConstraint { 99 | name, 100 | cascade: DIALECT.supports_cascade(), 101 | restrict: false, 102 | }, 103 
| CreateConstr(constr) => AlterTableOperation::DropConstraint { 104 | name: Ident::new(constr.name().unwrap()), 105 | cascade: DIALECT.supports_cascade(), 106 | restrict: false, 107 | }, 108 | 109 | _ => todo!(), 110 | }; 111 | 112 | let statement = Statement::AlterTable(AlterTable { 113 | name: table_name, 114 | operation, 115 | }); 116 | out.push(statement); 117 | } 118 | } 119 | Ok(out) 120 | } 121 | } 122 | 123 | pub fn depends(cons: &TableConstraint, tables: &[&Column]) -> bool { 124 | let names = match cons { 125 | TableConstraint::ForeignKey(fk) => &fk.columns, 126 | TableConstraint::Unique(unique) => &unique.columns, 127 | _ => return false, 128 | }; 129 | let names = names.iter().map(ToString::to_string); 130 | 131 | for col in names { 132 | for table_name in tables.iter().map(|t| t.name().unwrap()) { 133 | if col.to_string() == table_name { 134 | return true; 135 | } 136 | } 137 | } 138 | false 139 | } 140 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/actions/action/temp_move.rs: -------------------------------------------------------------------------------- 1 | use super::{Compare, *}; 2 | use crate::prelude::*; 3 | #[derive(Debug)] 4 | pub(crate) struct Move<'table> { 5 | pub(super) new_cols: Vec<&'table Column>, 6 | pub(super) old_cols: Vec<&'table Column>, 7 | pub(super) constraints: Vec<&'table TableConstraint>, 8 | } 9 | 10 | impl<'table> Move<'table> { 11 | pub fn new(old: &'table Table, cons: &ConsCRUD<'table>, cols: &ColCRUD<'table>) -> Self { 12 | let mut new_cols = vec![]; 13 | let mut old_cols = vec![]; 14 | let mut constraints = vec![]; 15 | for col in &old.columns { 16 | if !cols.to_delete(col) && !cols.to_update(col) { 17 | new_cols.push(col); 18 | old_cols.push(col); 19 | } 20 | } 21 | for &col in &cols.update { 22 | new_cols.push(col); 23 | old_cols.push(col); 24 | } 25 | for con in &old.constraints { 26 | let to_delete = cons.to_delete(con); 27 | let to_update = 
cons.to_update(con); 28 | if !to_delete && !to_update { 29 | constraints.push(con); 30 | } 31 | } 32 | for con in &cons.update { 33 | if !depends(con, &cols.create) || matches!(*DIALECT, SQLite) { 34 | constraints.push(con); 35 | } 36 | } 37 | for con in &cons.create { 38 | if !depends(con, &cols.create) || matches!(*DIALECT, SQLite) { 39 | constraints.push(con); 40 | } 41 | } 42 | Self { 43 | new_cols, 44 | old_cols, 45 | constraints, 46 | } 47 | } 48 | 49 | pub fn to_statements(self, table_name: ObjectName) -> Result> { 50 | let mut stmt = vec![]; 51 | let create_table = self.create_table(); 52 | let insert = self.insert_statement(table_name.clone())?; 53 | let drop = self.drop_statement(table_name.clone()); 54 | let rename = self.rename(table_name); 55 | stmt.push(create_table); 56 | stmt.push(insert); 57 | stmt.push(drop); 58 | stmt.push(rename); 59 | Ok(stmt) 60 | } 61 | 62 | fn create_table(&self) -> Statement { 63 | Table { 64 | name: ObjectName(vec![Ident::new("temp")]), 65 | columns: self.new_cols.iter().map(|&c| c.clone()).collect(), 66 | constraints: self.constraints.iter().map(|&c| c.clone()).collect(), 67 | if_not_exists: false, 68 | or_replace: false, 69 | } 70 | .into() 71 | } 72 | fn insert_statement(&self, table_name: ObjectName) -> Result { 73 | let new = self 74 | .new_cols 75 | .iter() 76 | .map(|&col| col.ident()) // 77 | .collect(); 78 | let old = self 79 | .old_cols 80 | .iter() 81 | .map(|&col| col.ident()) // 82 | .collect(); 83 | 84 | let insert = format!( 85 | "INSERT INTO temp ({}) SELECT {} FROM {};", 86 | to_string(new), 87 | to_string(old), 88 | table_name 89 | ); 90 | let insert = parse_sql(&insert)? 
91 | .into_iter() // 92 | .next() 93 | .unwrap(); 94 | 95 | Ok(insert) 96 | } 97 | 98 | fn drop_statement(&self, table_name: ObjectName) -> Statement { 99 | Statement::Drop(Drop { 100 | object_type: ObjectType::Table, 101 | if_exists: false, 102 | names: vec![table_name], 103 | cascade: !DIALECT.requires_move(), 104 | purge: false, 105 | }) 106 | } 107 | 108 | fn rename(self, table_name: ObjectName) -> Statement { 109 | Statement::AlterTable(AlterTable { 110 | name: ObjectName(vec![Ident::new("temp")]), 111 | operation: AlterTableOperation::RenameTable { 112 | table_name: table_name, 113 | }, 114 | }) 115 | } 116 | } 117 | 118 | fn to_string(collection: Vec) -> String { 119 | let mut out = String::new(); 120 | for (i, c) in collection.iter().enumerate() { 121 | out += &c.to_string(); 122 | if collection.len() != i + 1 { 123 | out += "," 124 | } 125 | } 126 | out 127 | } 128 | 129 | pub fn depends(cons: &TableConstraint, tables: &[&Column]) -> bool { 130 | let names = match cons { 131 | TableConstraint::ForeignKey(fk) => &fk.columns, 132 | TableConstraint::Unique(unique) => &unique.columns, 133 | _ => return false, 134 | }; 135 | let names = names.iter().map(ToString::to_string); 136 | 137 | for col in names { 138 | for table_name in tables.iter().map(|t| t.name().unwrap()) { 139 | if col.to_string() == table_name { 140 | return true; 141 | } 142 | } 143 | } 144 | false 145 | } 146 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/actions/compare.rs: -------------------------------------------------------------------------------- 1 | pub use crate::prelude::*; 2 | pub use collections::HashSet; 3 | 4 | pub(crate) trait Compare: std::fmt::Debug { 5 | fn bodies_are_equal(&self, other: &Self) -> bool; 6 | fn name(&self) -> Result; 7 | fn are_modified(&self, other: &Self) -> bool { 8 | let names = self.names_are_equal(other); 9 | 10 | names && !self.bodies_are_equal(other) 11 | } 12 | fn 
names_are_equal(&self, other: &Self) -> bool { 13 | let first = match self.name() { 14 | Ok(name) => name, 15 | Err(_) => return false, 16 | }; 17 | let second = match other.name() { 18 | Ok(name) => name, 19 | Err(_) => return false, 20 | }; 21 | 22 | first == second 23 | } 24 | 25 | fn are_equal(&self, other: &Self) -> bool { 26 | self.names_are_equal(other) && self.bodies_are_equal(other) 27 | } 28 | 29 | fn ident(&self) -> Ident; 30 | } 31 | 32 | impl Compare for Column { 33 | fn ident(&self) -> Ident { 34 | self.name.clone() 35 | } 36 | fn name(&self) -> Result { 37 | Ok(self.name.to_string().to_lowercase()) 38 | } 39 | 40 | fn bodies_are_equal(&self, other: &Self) -> bool { 41 | let type1 = &self.r#type; 42 | let type2 = &other.r#type; 43 | 44 | type1 == type2 && { 45 | let h1 = self 46 | .options 47 | .iter() 48 | .map(ToString::to_string) 49 | .map(|string| string.to_lowercase()) 50 | .collect::>(); 51 | let h2 = other 52 | .options 53 | .iter() 54 | .map(ToString::to_string) 55 | .map(|string| string.to_lowercase()) 56 | .collect::>(); 57 | h1 == h2 58 | } 59 | } 60 | } 61 | 62 | impl Compare for TableConstraint { 63 | fn ident(&self) -> Ident { 64 | use TableConstraint::*; 65 | match self { 66 | Unique(ast::Unique { name, .. }) => name, 67 | ForeignKey(ast::ForeignKey { name, .. }) => name, 68 | Check(ast::Check { name, .. }) => name, 69 | } 70 | .clone() 71 | .unwrap() 72 | } 73 | fn name(&self) -> Result { 74 | use TableConstraint::*; 75 | match self { 76 | Unique(ast::Unique { name, .. }) => name, 77 | ForeignKey(ast::ForeignKey { name, .. }) => name, 78 | Check(ast::Check { name, .. 
}) => name, 79 | } 80 | .as_ref() 81 | .ok_or_else(|| error!("anonymous constraints are not supported.")) 82 | .map(|name| name.to_string().to_lowercase()) 83 | } 84 | 85 | fn bodies_are_equal(&self, other: &Self) -> bool { 86 | use TableConstraint::*; 87 | match (self, other) { 88 | (Unique(u0), Unique(u1)) => { 89 | u0.is_primary == u1.is_primary && { 90 | let cols0 = u0 91 | .columns 92 | .iter() 93 | .map(ToString::to_string) 94 | .map(|str| str.to_lowercase()) 95 | .collect::>(); 96 | let cols1 = u1 97 | .columns 98 | .iter() 99 | .map(ToString::to_string) 100 | .map(|str| str.to_lowercase()) 101 | .collect::>(); 102 | cols0 == cols1 103 | } 104 | } 105 | (ForeignKey(f0), ForeignKey(f1)) => { 106 | f1.on_delete == f0.on_delete 107 | && f1.on_update == f0.on_update 108 | && { 109 | let cols0 = f1 110 | .referred_columns 111 | .iter() 112 | .map(ToString::to_string) 113 | .map(|str| str.to_lowercase()) 114 | .collect::>(); 115 | let cols1 = f0 116 | .referred_columns 117 | .iter() 118 | .map(ToString::to_string) 119 | .map(|str| str.to_lowercase()) 120 | .collect::>(); 121 | cols0 == cols1 122 | } 123 | && { 124 | let name0 = f0.foreign_table.to_string().to_lowercase(); 125 | let name1 = f1.foreign_table.to_string().to_lowercase(); 126 | name0 == name1 127 | } 128 | && { 129 | let cols0 = f0 130 | .columns 131 | .iter() 132 | .map(ToString::to_string) 133 | .map(|str| str.to_lowercase()) 134 | .collect::>(); 135 | let cols1 = f1 136 | .columns 137 | .iter() 138 | .map(ToString::to_string) 139 | .map(|str| str.to_lowercase()) 140 | .collect::>(); 141 | cols0 == cols1 142 | } 143 | } 144 | _ => false, 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/actions/crud.rs: -------------------------------------------------------------------------------- 1 | use super::Compare; 2 | use crate::prelude::*; 3 | #[derive(Debug)] 4 | pub(crate) struct CRUD<'table, T> { 5 | pub 
create: Vec<&'table T>, 6 | pub delete: Vec<&'table T>, 7 | pub update: Vec<&'table T>, 8 | // pub keep: Vec<&'table T>, 9 | } 10 | 11 | pub(crate) type ColCRUD<'table> = CRUD<'table, Column>; 12 | pub(crate) type ConsCRUD<'table> = CRUD<'table, TableConstraint>; 13 | 14 | impl<'table, T: Compare> CRUD<'table, T> { 15 | pub fn to_delete(&self, obj: &T) -> bool { 16 | self.delete.iter().any(|&del| del.names_are_equal(&obj)) 17 | } 18 | pub fn to_update(&self, obj: &T) -> bool { 19 | self.update.iter().any(|&up| up.names_are_equal(&obj)) 20 | } 21 | pub fn _to_create(&self, obj: &T) -> bool { 22 | self.create.iter().any(|&cr| cr.names_are_equal(&obj)) 23 | } 24 | } 25 | 26 | impl<'table, T: Compare + PartialEq> CRUD<'table, T> { 27 | pub fn new(current: &'table [T], target: &'table [T]) -> Self { 28 | let mut update = vec![]; 29 | let mut delete = vec![]; 30 | let mut create = vec![]; 31 | 32 | for c1 in target { 33 | for c0 in current { 34 | if c1.are_modified(c0) { 35 | update.push(c1); 36 | } 37 | } 38 | 39 | if !current 40 | .iter() 41 | .any(|c0| c0.are_equal(c1) || c0.are_modified(c1)) 42 | { 43 | create.push(c1); 44 | } 45 | } 46 | 47 | for c0 in current { 48 | if target 49 | .iter() 50 | .all(|t| !c0.are_equal(t) && !c0.are_modified(t)) 51 | { 52 | delete.push(c0.clone()); 53 | } 54 | } 55 | 56 | CRUD { 57 | create, 58 | update, 59 | delete, 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/actions/inner.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | pub(super) struct Inner<'table> { 3 | pub table: Option<&'table Table>, 4 | pub target: &'table Table, 5 | } 6 | 7 | impl<'table> Inner<'table> { 8 | pub fn columns(&self) -> ColCRUD<'table> { 9 | let current = &self.table.unwrap().columns; 10 | let target = &self.target.columns; 11 | 12 | CRUD::new(current, target) 13 | } 14 | 15 | pub fn constraints(&self) -> 
ConsCRUD<'table> { 16 | let current = &self.table.unwrap().constraints; 17 | let target = &self.target.constraints; 18 | CRUD::new(current, target) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/actions/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod action; 2 | mod compare; 3 | mod crud; 4 | mod inner; 5 | 6 | use super::schema::Schema; 7 | use crate::prelude::*; 8 | use action::{depends, Action}; 9 | pub use compare::*; 10 | use crud::*; 11 | 12 | use inner::*; 13 | #[derive(Debug)] 14 | pub(crate) struct Actions<'table> { 15 | name: &'table ObjectName, 16 | actions: Vec>, 17 | } 18 | impl<'table> Actions<'table> { 19 | pub fn new(schema: &'table Schema, target: &'table Table) -> Result { 20 | let table = schema.get_table(&target.name); 21 | 22 | let mut out = Self { 23 | name: &target.name, 24 | actions: vec![], 25 | }; 26 | out.init(Inner { table, target })?; 27 | Ok(out) 28 | } 29 | 30 | fn init(&mut self, inner: Inner<'table>) -> Result<()> { 31 | if inner.table.is_none() { 32 | let action = Action::create_table(inner.target); 33 | self.actions.push(action); 34 | return Ok(()); 35 | } 36 | let columns = inner.columns(); 37 | let constraints = inner.constraints(); 38 | 39 | if move_required(&columns, &constraints) { 40 | self.perform_move(&inner, columns, constraints)?; 41 | } else { 42 | let table_name = &inner.target.name; 43 | for col in columns.delete { 44 | let action = Action::drop_col(table_name, col); 45 | self.actions.push(action); 46 | } 47 | for cons in constraints.delete { 48 | let action = Action::drop_cons(table_name, cons)?; 49 | self.actions.push(action); 50 | } 51 | for cons in &constraints.update { 52 | let action = Action::drop_cons(table_name, cons)?; 53 | self.actions.push(action); 54 | } 55 | 56 | for col in columns.create { 57 | let action = Action::create_column(table_name, col); 58 | 
// ==== tail of models/src/private/scheduler/driver/actions/mod.rs ====
// (continuation of a method started before this chunk: queues the
// constraint-creation actions for the target table, then returns)
            self.actions.push(action);
            }
            for cons in &constraints.create {
                let action = Action::create_cons(table_name, cons);
                self.actions.push(action);
            }

            for cons in &constraints.update {
                let action = Action::create_cons(table_name, cons);
                self.actions.push(action);
            }
        }
        Ok(())
    }

    /// Rebuilds `inner.target` through a temporary-table move, for dialects
    /// that cannot alter columns/constraints in place.
    fn perform_move(
        &mut self,
        inner: &Inner<'table>,
        cols: ColCRUD<'table>,
        cons: ConsCRUD<'table>,
    ) -> Result<()> {
        // `inner.table` is only absent for brand-new tables, which never
        // reach a move; hoist the unwrap instead of repeating it three times.
        let table = inner.table.unwrap();
        // Constraints are dropped first so they do not conflict with the move.
        if matches!(*DIALECT, PostgreSQL | MySQL) {
            for con in &table.constraints {
                let drop_cons = Action::drop_cons(&table.name, con)?;
                self.actions.push(drop_cons);
            }
        }
        self.actions.push(Action::move_to(table, &cols, &cons));
        let table_name = &inner.target.name;

        // Moves do not create columns, as their names may conflict with
        // constraints; create them afterwards.
        for &col in &cols.create {
            self.actions.push(Action::create_column(table_name, col));
        }
        // Create the constraints that could not be created in the move
        // because they depend on columns that did not yet exist. SQLite does
        // not enforce constraints, so there they are all created in the move.
        for &cons in &cons.create {
            if depends(cons, &cols.create) && !matches!(*DIALECT, SQLite) {
                self.actions.push(Action::create_cons(table_name, cons));
            }
        }
        Ok(())
    }

    /// Splits the queued actions into migrations so that every fallible
    /// action starts a fresh migration file (an earlier failure can then not
    /// leave a half-applied migration behind).
    pub fn as_migrations(self) -> Result<Vec<Migration>> {
        let mut migrations = vec![];
        let mut migr = Migration::new(self.name.clone());
        for action in self.actions {
            if action.is_fallible() && !migr.is_empty() {
                migrations.push(migr);
                migr = Migration::new(self.name.clone());
            }
            migr.push_up(action)?;
        }
        migrations.push(migr);
        Ok(migrations)
    }
}

/// Whether this diff can only be applied by rebuilding (moving) the table.
pub(crate) fn move_required<'table>(cols: &ColCRUD<'table>, cons: &ConsCRUD<'table>) -> bool {
    // Dialects such as SQLite cannot alter columns/constraints in place.
    let sqlite_conditions = DIALECT.requires_move()
        && !(cols.update.is_empty()
            && cols.delete.is_empty()
            && cons.delete.is_empty()
            && cons.create.is_empty()
            && cons.update.is_empty());
    // Column updates always require a move, regardless of dialect.
    sqlite_conditions || !cols.update.is_empty()
}

// ==== models/src/private/scheduler/driver/migration.rs ====
use super::{
    actions::{action::Action, Actions},
    schema::Schema,
    Report,
};
use crate::prelude::*;
use fs::File;
use std::io::Write;

/// A single migration: the `up` statements to apply and the `down`
/// statements that revert them.
#[derive(Debug)]
pub(crate) struct Migration {
    up: Vec<Statement>,
    down: Vec<Statement>,
    name: ObjectName,
}

/// Microseconds since the Unix epoch; used as the migration file prefix so
/// files sort chronologically.
fn timestamp() -> u128 {
    time::SystemTime::now()
        .duration_since(time::UNIX_EPOCH)
        .unwrap()
        .as_micros()
}

impl Migration {
    pub fn new(name: ObjectName) -> Self {
        Self {
            up: vec![],
            down: vec![],
            name,
        }
    }

    /// Computes the `down` statements by diffing the post-migration schema
    /// (`new`) against the pre-migration one (`old`). If the table did not
    /// exist before, the down migration simply drops it.
    pub fn create_down(&mut self, old: Schema, new: &Schema, table: &ObjectName) -> Result<()> {
        if let Some(target) = old.get_table(table) {
            let actions = Actions::new(new, target)?;
            // Flatten all the generated migrations' `up` scripts into one
            // `down` script (replaces the original fold/append).
            self.down = actions
                .as_migrations()?
                .into_iter()
                .flat_map(|mig| mig.up)
                .collect();
        } else {
            let drop_stmt = Statement::Drop(Drop {
                object_type: ObjectType::Table,
                if_exists: false,
                names: vec![table.clone()],
                cascade: DIALECT.supports_cascade(),
                purge: false,
            });
            self.down.push(drop_stmt);
        }
        Ok(())
    }

    pub fn up(&self) -> &[Statement] {
        &self.up[..]
    }

    pub fn is_empty(&self) -> bool {
        self.up.is_empty()
    }

    /// Converts the action to SQL and appends it to the `up` script.
    pub fn push_up(&mut self, action: Action) -> Result<()> {
        let stmts = action.to_statements()?;
        self.up.extend(stmts);
        Ok(())
    }

    fn write_to_file(file_name: &str, stmts: &[Statement]) -> Result<()> {
        let mut file = File::create(file_name)?;
        for stmt in stmts {
            #[cfg(feature = "sqlformat")]
            let stmt = Self::formatted_stmt(stmt);
            write!(file, "{};\n\n", stmt)?;
        }
        Ok(())
    }

    /// Writes the migration to disk — an `.up.sql`/`.down.sql` pair when
    /// `MODELS_GENERATE_DOWN` is set, a single `.sql` file otherwise — and
    /// reports the generated file's timestamp and table name. Empty
    /// migrations produce no file and return `None`.
    pub fn commit(self) -> Result<Option<Report>> {
        if self.is_empty() {
            return Ok(None);
        }
        let timestamp = timestamp();
        let file_name = format!("{}/{}_{}", *MIGRATIONS_DIR, timestamp, self.name);
        let name = self.name.to_string().to_lowercase();
        if !*MODELS_GENERATE_DOWN {
            Self::write_to_file(&format!("{}.sql", file_name), &self.up)?;
        } else {
            Self::write_to_file(&format!("{}.up.sql", file_name), &self.up)?;
            Self::write_to_file(&format!("{}.down.sql", file_name), &self.down)?;
        }
        Ok(Some(Report { timestamp, name }))
    }

    #[cfg(feature = "sqlformat")]
    fn formatted_stmt(stmt: &Statement) -> String {
        use sqlformat::QueryParams;
        let stmt = format!("{}", stmt);
        sqlformat::format(&stmt, &QueryParams::None, FORMAT_OPTIONS)
    }
}

// ==== models/src/private/scheduler/driver/mod.rs ====
use crate::prelude::*;
pub(crate) mod actions;
pub mod migration;
mod queue;
mod report;
mod schema;
use actions::Actions;
use queue::*;
pub(crate) use report::*;
use schema::*;

/// Drives migration generation: tables are registered into a queue, then
/// migrated in dependency order against the schema loaded from disk.
pub(crate) struct Driver {
    // The schema replayed from the migrations directory, or the first error
    // encountered; errors are sticky and reported through `as_json`.
    result: Result<Schema>,
    queue: Queue,
    success: Vec<Report>,
}

impl Driver {
    pub fn new() -> Self {
        Self {
            result: Schema::new(),
            queue: Queue::new(),
            success: vec![],
        }
    }

    /// True until the first table is registered.
    pub fn is_first(&self) -> bool {
        self.queue.len() == 0
    }

    pub fn register(&mut self, table: Table) {
        self.queue.insert(table)
    }

    /// Serializes the outcome (generated reports plus any error) as JSON.
    pub fn as_json(&self) -> String {
        let error = if let Err(err) = &self.result {
            err.as_json()
        } else {
            "null".into()
        };
        format!(
            r#"{{"success": {success:?},"error": {error}}}"#,
            success = &self.success,
            error = error
        )
    }

    /// Pops tables in topological order; anything left in the queue when no
    /// independent table remains indicates a dependency cycle.
    pub fn migrate(&mut self) {
        self.queue.remove_unregistered();
        loop {
            match self.queue.pop() {
                Some(target) => self.migrate_table(target),
                None => {
                    if self.queue.len() != 0 && self.result.is_ok() {
                        self.result = Err(Error::Cycle(self.queue.remaining_tables()));
                    }
                    break;
                }
            }
        }
    }

    pub fn migrate_table(&mut self, target: Table) {
        if let Err(error) = self.try_migration(target) {
            self.result = Err(error);
        }
    }

    fn try_migration(&mut self, target: Table) -> Result<()> {
        for mig in self.get_migrations(target)? {
            if let Some(report) = mig.commit()? {
                self.success.push(report);
            }
        }
        Ok(())
    }

    /// Diffs `target` against the current schema and produces the migrations
    /// to apply, updating the in-memory schema and generating the matching
    /// `down` scripts along the way.
    fn get_migrations(&mut self, target: Table) -> Result<Vec<Migration>> {
        // NOTE(review): removed a leftover `println!("get_migrations")`
        // debug statement that polluted the CLI's JSON output stream.
        let schema = self.result.as_mut().map_err(|x| x.clone())?;
        let actions = Actions::new(schema, &target)?;
        let mut migrations = actions.as_migrations()?;
        for migr in &mut migrations {
            let old_schema = schema.clone();
            for stmt in migr.up() {
                schema.update(stmt)?;
            }
            migr.create_down(old_schema, schema, &target.name)?;
        }
        Ok(migrations)
    }
}

// ==== models/src/private/scheduler/driver/queue/mod.rs ====
mod sorter;
use super::*;
pub use sorter::Sorter;
pub use std::collections::HashSet;

/// Holds registered tables and yields them in dependency order.
pub(crate) struct Queue {
    tables: HashMap<String, Table>,
    sorter: Sorter,
}

impl Queue {
    pub fn new() -> Self {
        Self {
            tables: HashMap::new(),
            sorter: Sorter::new(),
        }
    }

    pub fn len(&self) -> usize {
        self.tables.len()
    }

    /// Registers a table and its foreign-key dependencies with the sorter.
    pub fn insert(&mut self, table: Table) {
        let table_name = table.name();
        self.sorter.insert(table_name.clone());
        for dep in table.deps() {
            self.sorter.add_dependency(dep, table_name.clone());
        }
        // Deps are read before the table is moved into the map, so the full
        // `Table` clone the original code performed is unnecessary.
        self.tables.insert(table_name, table);
    }

    /// Next table with no unresolved dependencies, if any.
    pub fn pop(&mut self) -> Option<Table> {
        self.sorter
            .pop()
            .and_then(|value| self.tables.remove(&value))
    }

    /// Drops dependencies on tables that were never registered (e.g. tables
    /// managed outside of this crate) so they do not block the sort.
    pub fn remove_unregistered(&mut self) {
        self.sorter.remove_unregistered_depedencies()
    }

    pub fn remaining_tables(&self) -> Vec<String> {
        self.tables.keys().cloned().collect()
    }
}

// ==== models/src/private/scheduler/driver/queue/sorter.rs ====
-------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | 3 | pub struct Sorter { 4 | dependencies: HashMap>, 5 | } 6 | 7 | impl Sorter { 8 | pub fn new() -> Self { 9 | Self { 10 | dependencies: HashMap::new(), 11 | } 12 | } 13 | pub fn insert(&mut self, dep: String) { 14 | self.dependencies.insert(dep, HashSet::new()); 15 | } 16 | pub fn add_dependency(&mut self, dep: String, name: String) { 17 | if let Some(deps) = self.dependencies.get_mut(&name) { 18 | deps.insert(dep); 19 | } else { 20 | self.dependencies.insert(dep, HashSet::new()); 21 | } 22 | } 23 | 24 | pub fn remove_unregistered_depedencies(&mut self) { 25 | let mut deps = HashSet::new(); 26 | 27 | for (_, v) in &self.dependencies { 28 | deps.extend(v); 29 | } 30 | deps.retain(|dep| !self.dependencies.contains_key(*dep)); 31 | let deps: HashSet<_> = deps.into_iter().cloned().collect(); 32 | for dep in deps { 33 | self.remove_dep(&dep.clone()); 34 | } 35 | } 36 | 37 | fn remove_dep(&mut self, dep: &str) { 38 | for (_k, v) in &mut self.dependencies { 39 | v.remove(dep); 40 | } 41 | } 42 | pub fn pop(&mut self) -> Option { 43 | let indep = self.find_independent()?; 44 | self.remove_dep(&indep); 45 | Some(indep) 46 | } 47 | 48 | fn find_independent(&mut self) -> Option { 49 | let mut out = None; 50 | for (k, v) in &self.dependencies { 51 | if v.is_empty() { 52 | let key = &k.clone(); 53 | out = self.dependencies.remove_entry(key).map(|x| x.0); 54 | break; 55 | } 56 | } 57 | let out = out?; 58 | self.dependencies.remove(&out); 59 | Some(out) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/report.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | pub(crate) struct Report { 3 | pub timestamp: u128, 4 | pub name: String, 5 | } 6 | 7 | impl fmt::Debug for Report { 8 | fn fmt(&self, f: &mut 
fmt::Formatter<'_>) -> fmt::Result { 9 | write!(f, r#"[{}, {:?}]"#, self.timestamp, self.name,) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /models/src/private/scheduler/driver/schema.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | use fs::*; 3 | 4 | use path::PathBuf; 5 | #[derive(Clone)] 6 | pub struct Schema { 7 | tables: HashMap, 8 | } 9 | 10 | impl Schema { 11 | pub fn new() -> Result { 12 | let mut out = Self { 13 | tables: HashMap::new(), 14 | }; 15 | out.init()?; 16 | Ok(out) 17 | } 18 | #[cfg(test)] 19 | fn _from_sql(sql: &str) -> Result { 20 | let stmts = parse_sql(sql)?; 21 | let mut out = Self { 22 | tables: HashMap::new(), 23 | }; 24 | for stmt in stmts { 25 | out.update(&stmt)?; 26 | } 27 | Ok(out) 28 | } 29 | 30 | pub fn get_table(&self, name: &ObjectName) -> Option<&Table> { 31 | self.tables.get(&name) 32 | } 33 | 34 | pub fn init(&mut self) -> Result { 35 | let stmts = self.get_statements()?; 36 | for stmt in stmts { 37 | self.update(&stmt)?; 38 | } 39 | Ok(()) 40 | } 41 | 42 | fn get_statements(&mut self) -> Result> { 43 | let mut out = vec![]; 44 | for path in self.read_dir()? { 45 | if !is_up_file(&path) { 46 | continue; 47 | } 48 | let sql = read_to_string(&path)?; 49 | let stmts = match parse_sql(&sql) { 50 | Ok(stmts) => stmts, 51 | Err(err) => return Err(Error::SyntaxAtFile(err, path)), 52 | }; 53 | out.extend(stmts); 54 | } 55 | Ok(out) 56 | } 57 | fn read_dir(&self) -> Result> { 58 | let directory = &*MIGRATIONS_DIR; 59 | let mut dir: Vec<_> = read_dir(directory) 60 | .map_err(|_| error!("could not read the \"{}\" directiory.", directory))? 
61 | .map(|x| x.unwrap().path()) 62 | .collect(); 63 | dir.sort(); 64 | Ok(dir) 65 | } 66 | 67 | pub fn update(&mut self, stmt: &Statement) -> Result { 68 | use Statement::*; 69 | match stmt { 70 | CreateTable(_) => self.create_table(stmt.clone().try_into().unwrap()), 71 | AlterTable(ast::AlterTable { 72 | name, 73 | operation: AlterTableOperation::RenameTable { table_name }, 74 | }) => self.rename_table(name, table_name), 75 | AlterTable(alter) => self.alter_table(&alter.name, &alter.operation), 76 | Drop(drop) => self.drop_tables(drop), 77 | _ => Ok(()), 78 | } 79 | } 80 | 81 | fn rename_table(&mut self, old_name: &ObjectName, new_name: &ObjectName) -> Result { 82 | let mut table = self.tables.remove(&old_name).ok_or_else(|| { 83 | error!( 84 | "attempt to rename table {:?} to {:?}, but it does not exist", 85 | &old_name, &new_name 86 | ) 87 | })?; 88 | if !DIALECT.requires_move() { 89 | self.cascade(&old_name); 90 | } 91 | table.name = new_name.clone(); 92 | self.tables.insert(new_name.clone(), table); 93 | Ok(()) 94 | } 95 | 96 | fn cascade(&mut self, name: &ObjectName) { 97 | use TableConstraint::*; 98 | self.tables // 99 | .values_mut() 100 | .for_each(|table| { 101 | table.constraints = table 102 | .constraints 103 | .drain(..) 104 | .filter(|constr| match constr { 105 | ForeignKey(ast::ForeignKey { foreign_table, .. }) => foreign_table == name, 106 | _ => true, 107 | }) 108 | .collect() 109 | }); 110 | } 111 | 112 | // pub(crate) fn get_changes(&self, target: Table) -> Result { 113 | // if let Some(table) = self.tables.get(&target.name) { 114 | // table.get_changes(&target)? 115 | // } else { 116 | // vec![target.clone().into()] 117 | // } 118 | // } 119 | 120 | fn drop_tables(&mut self, drop: &ast::Drop) -> Result { 121 | for name in drop.names.iter() { 122 | if !drop.if_exists && !self.tables.contains_key(name) { 123 | return Err(error!( 124 | "failed to load migrations. 
Table \"{}\" cannot be dropped as it does not exist.", 125 | name 126 | )); 127 | } 128 | if drop.cascade { 129 | self.cascade(name); 130 | } 131 | self.tables.remove(name); 132 | } 133 | Ok(()) 134 | } 135 | fn alter_table(&mut self, name: &ObjectName, op: &AlterTableOperation) -> Result { 136 | self.tables 137 | .get_mut(&name) // 138 | .map(|table| table.alter_table(op)) 139 | .ok_or_else(|| { 140 | error!( 141 | "failed to load migrations. Could not find the table \"{}\"", 142 | name 143 | ) 144 | })??; 145 | Ok(()) 146 | } 147 | fn create_table(&mut self, table: Table) -> Result { 148 | let table = table; 149 | let tables = &mut self.tables; 150 | if !table.if_not_exists && tables.contains_key(&table.name) && !table.or_replace { 151 | return Err(error!( 152 | "attempting to create table \"{}\", but it already exists.", 153 | table.name 154 | )); 155 | } 156 | tables.insert(table.name.clone(), table); 157 | Ok(()) 158 | } 159 | } 160 | 161 | fn is_up_file(file_name: &PathBuf) -> bool { 162 | file_name.is_file() && !file_name.to_str().unwrap().contains(".down.sql") 163 | } 164 | -------------------------------------------------------------------------------- /models/src/private/scheduler/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | pub mod driver; 3 | pub mod table; 4 | 5 | use table::*; 6 | 7 | use driver::*; 8 | pub struct Scheduler(Mutex); 9 | 10 | impl Scheduler { 11 | pub(super) fn new() -> Self { 12 | Self(Mutex::new(Driver::new())) 13 | } 14 | 15 | pub fn register(&self, table: Table) { 16 | let is_first; 17 | { 18 | let mut driver = self.0.lock().unwrap(); 19 | is_first = driver.is_first(); 20 | driver.register(table) 21 | // release the lock 22 | } 23 | 24 | if is_first { 25 | std::thread::sleep(time::Duration::from_millis(250)); 26 | self.commit() 27 | } 28 | } 29 | fn commit(&self) { 30 | let mut driver = self.0.lock().unwrap(); 31 | driver.migrate(); 32 | let json = 
driver.as_json(); 33 | println!("{0}", json); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /models/src/private/scheduler/table/column.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | use models_parser::{dialect::*, parser::*}; 3 | #[derive(Clone, Debug, PartialEq)] 4 | pub struct Column { 5 | pub name: Ident, 6 | pub r#type: DataType, 7 | pub options: Vec, 8 | } 9 | 10 | impl Column { 11 | pub fn new(name: &str, r#type: DataType, is_nullable: bool) -> Self { 12 | let options; 13 | if !is_nullable { 14 | options = vec![ColumnOptionDef { 15 | name: None, 16 | option: ColumnOption::NotNull, 17 | }]; 18 | } else { 19 | options = vec![] 20 | } 21 | 22 | Column { 23 | name: Ident::new(name.to_lowercase()), 24 | r#type, 25 | options, 26 | } 27 | } 28 | 29 | pub fn new_with_default(name: &str, r#type: DataType, is_nullable: bool, def: &str) -> Self { 30 | let dialect = GenericDialect {}; 31 | let mut tokens = tokenizer::Tokenizer::new(&dialect, def); 32 | let mut parser = Parser::new(tokens.tokenize().unwrap(), &dialect); 33 | let expr = parser.parse_expr().unwrap(); 34 | 35 | let mut col = Column { 36 | name: Ident::new(name.to_lowercase()), 37 | r#type, 38 | options: vec![ast::ColumnOptionDef { 39 | name: None, 40 | option: ast::ColumnOption::Default(expr), 41 | }], 42 | }; 43 | if !is_nullable { 44 | col.options.push(ColumnOptionDef { 45 | name: None, 46 | option: ColumnOption::NotNull, 47 | }); 48 | }; 49 | col 50 | } 51 | 52 | pub fn has_default(&self) -> bool { 53 | for option in &self.options { 54 | if matches!(option.option, ColumnOption::Default(_)) { 55 | return true; 56 | } 57 | } 58 | false 59 | } 60 | 61 | pub fn is_nullable(&self) -> bool { 62 | for option in &self.options { 63 | if matches!(option.option, ColumnOption::NotNull) { 64 | return false; 65 | } 66 | } 67 | true 68 | } 69 | } 70 | 71 | impl From for Column { 72 | fn from(col: 
ColumnDef) -> Self { 73 | Column { 74 | name: col.name, 75 | options: col.options, 76 | r#type: col.data_type, 77 | } 78 | } 79 | } 80 | 81 | impl From for ColumnDef { 82 | fn from(col: Column) -> Self { 83 | ColumnDef { 84 | name: col.name, 85 | options: col.options, 86 | data_type: col.r#type, 87 | collation: None, 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /models/src/private/scheduler/table/constraint.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | use TableConstraint::*; 3 | 4 | pub fn name(constr: &TableConstraint) -> &Option { 5 | match constr { 6 | Unique(ast::Unique { name, .. }) => name, 7 | ForeignKey(ast::ForeignKey { name, .. }) => name, 8 | Check(ast::Check { name, .. }) => name, 9 | } 10 | } 11 | 12 | pub fn primary(name: &str, fields: &[&str]) -> TableConstraint { 13 | let name = Some(Ident::new(name)); 14 | let mut columns = vec![]; 15 | for field in fields { 16 | columns.push(Ident::new(*field)); 17 | } 18 | Unique(ast::Unique { 19 | name, 20 | columns, 21 | is_primary: true, 22 | }) 23 | } 24 | 25 | pub fn unique(name: &str, fields: &[&str]) -> TableConstraint { 26 | let name = Some(Ident::new(name)); 27 | let mut columns = vec![]; 28 | for field in fields { 29 | columns.push(Ident::new(*field)); 30 | } 31 | Unique(ast::Unique { 32 | name, 33 | columns, 34 | is_primary: false, 35 | }) 36 | } 37 | 38 | pub fn foreign_key( 39 | name: &str, 40 | local_col: &str, 41 | foreign_table: &str, 42 | foreign_col: &str, 43 | on_delete: &str, 44 | on_update: &str, 45 | ) -> TableConstraint { 46 | ForeignKey(ast::ForeignKey { 47 | name: Some(Ident::new(name)), 48 | foreign_table: ObjectName(vec![Ident::new(foreign_table)]), 49 | referred_columns: vec![Ident::new(foreign_col)], 50 | columns: vec![Ident::new(local_col)], 51 | on_delete: match &*on_delete.to_lowercase() { 52 | "cascade" => Some(ast::ReferentialAction::Cascade), 53 | 
"no action" => Some(ast::ReferentialAction::NoAction), 54 | "restrict" => Some(ast::ReferentialAction::Restrict), 55 | "set default" => Some(ast::ReferentialAction::SetDefault), 56 | "set null" => Some(ast::ReferentialAction::SetNull), 57 | _ => None, 58 | }, 59 | on_update: match &*on_update.to_lowercase() { 60 | "cascade" => Some(ast::ReferentialAction::Cascade), 61 | "no action" => Some(ast::ReferentialAction::NoAction), 62 | "restrict" => Some(ast::ReferentialAction::Restrict), 63 | "set default" => Some(ast::ReferentialAction::SetDefault), 64 | "set null" => Some(ast::ReferentialAction::SetNull), 65 | _ => None, 66 | }, 67 | }) 68 | } 69 | -------------------------------------------------------------------------------- /models/src/private/scheduler/table/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | mod column; 3 | pub mod constraint; 4 | use crate::private::scheduler::driver::actions::Compare; 5 | pub use column::*; 6 | 7 | #[derive(Clone, Debug)] 8 | pub struct Table { 9 | pub(crate) name: ObjectName, 10 | pub if_not_exists: bool, 11 | pub or_replace: bool, 12 | pub columns: Vec, 13 | pub constraints: Vec, 14 | } 15 | 16 | impl Table { 17 | pub fn new(name: &str) -> Self { 18 | Table { 19 | name: ObjectName(vec![Ident::new(name)]), 20 | columns: vec![], 21 | constraints: vec![], 22 | if_not_exists: false, 23 | or_replace: false, 24 | } 25 | } 26 | 27 | pub(crate) fn name(&self) -> String { 28 | self.name.to_string().to_lowercase() 29 | } 30 | /// returns depenedencies of the table 31 | pub(crate) fn deps(&self) -> Vec { 32 | self.constraints 33 | .iter() 34 | .filter_map(|constr| match constr { 35 | TableConstraint::ForeignKey(ForeignKey { foreign_table, .. 
}) => { 36 | Some(foreign_table.to_string().to_lowercase()) 37 | } 38 | _ => None, 39 | }) 40 | .collect() 41 | } 42 | 43 | pub(super) fn alter_table(&mut self, op: &AlterTableOperation) -> Result { 44 | use AlterTableOperation::*; 45 | match op { 46 | AddColumn { column_def } => self.columns.push(column_def.clone().into()), 47 | AddConstraint(constr) => self.constraints.push(constr.clone()), 48 | DropConstraint { name, .. } => self.drop_constraint(name.to_string()), 49 | 50 | DropColumn { 51 | column_name, 52 | if_exists, 53 | .. 54 | } => self.drop_col(column_name, *if_exists), 55 | RenameColumn { 56 | old_column_name, 57 | new_column_name, 58 | } => self.rename_col(old_column_name, new_column_name), 59 | op => return Err(error!("unsupported operation: \"{}\"", op)), 60 | } 61 | Ok(()) 62 | } 63 | 64 | pub(super) fn drop_col(&mut self, name: &Ident, if_exists: bool) { 65 | let len = self.columns.len(); 66 | self.columns = self 67 | .columns 68 | .drain(..) 69 | .filter(|col| &col.name != name) 70 | .collect(); 71 | assert!( 72 | len != self.columns.len() || if_exists, 73 | "Column \"{}\" does not exists", 74 | name 75 | ); 76 | } 77 | 78 | pub fn drop_constraint(&mut self, rm_name: String) { 79 | self.constraints = self 80 | .constraints 81 | .drain(..) 
82 | .filter(|constr| constr.name().ok().as_ref() != Some(&rm_name)) 83 | .collect(); 84 | } 85 | pub(super) fn rename_col(&mut self, old: &Ident, new: &Ident) { 86 | self.columns = self 87 | .columns 88 | .iter() 89 | .map(Clone::clone) 90 | .map(|mut col| { 91 | if &col.name == old { 92 | col.name = new.clone(); 93 | } 94 | col 95 | }) 96 | .collect() 97 | } 98 | } 99 | 100 | impl TryFrom for Table { 101 | type Error = Error; 102 | fn try_from(value: Statement) -> Result { 103 | if let Statement::CreateTable(table) = value { 104 | let name = table 105 | .name 106 | .0 107 | .into_iter() 108 | .map(|ident| ident.value.to_lowercase()) 109 | .map(Ident::new) 110 | .collect(); 111 | Ok(Table { 112 | name: ObjectName(name), 113 | if_not_exists: false, 114 | or_replace: false, 115 | columns: table.columns.into_iter().map(Into::into).collect(), 116 | constraints: table.constraints, 117 | }) 118 | } else { 119 | Err(error!( 120 | "Expected a \"CREATE TABLE\" statement, found {}", 121 | value 122 | )) 123 | } 124 | } 125 | } 126 | 127 | impl From
for Statement { 128 | fn from(table: Table) -> Self { 129 | Statement::CreateTable(Box::new(ast::CreateTable { 130 | or_replace: false, 131 | temporary: false, 132 | external: false, 133 | if_not_exists: false, 134 | name: table.name, 135 | columns: table.columns.into_iter().map(Into::into).collect(), 136 | constraints: table.constraints, 137 | hive_distribution: HiveDistributionStyle::NONE, 138 | hive_formats: None, 139 | table_properties: vec![], 140 | with_options: vec![], 141 | file_format: None, 142 | location: None, 143 | query: None, 144 | without_rowid: false, 145 | like: None, 146 | })) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /models/src/rusqlite.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/src/sqlx.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/src/tests/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::private::{driver::schema::Schema, scheduler, *}; 2 | -------------------------------------------------------------------------------- /models/src/tokio_postgres.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/src/types/bytes.rs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tvallotton/models/87092ddd62492e8c5aa6be5a07f9bcfbc1b9ed84/models/src/types/bytes.rs -------------------------------------------------------------------------------- /models/src/types/chrono_impl.rs: -------------------------------------------------------------------------------- 1 | // | Rust type | MySQL 
| Postgres | SQLite | 2 | // |-------------------------------|-------------------------|--------------------|--------------------| 3 | // | `chrono::DateTime` | TIMESTAMP | TIMESTAMPTZ | DATETIME | 4 | // | `chrono::DateTime` | TIMESTAMP | TIMESTAMPTZ | DATETIME | 5 | // | `chrono::NaiveDateTime` | DATETIME | TIMESTAMP | DATETIME | 6 | // | `chrono::NaiveDate` | DATE | DATE | DATETIME | 7 | // | `chrono::NaiveTime` | TIME | TIME | DATETIME | 8 | // 9 | use super::*; 10 | use chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, Utc}; 11 | use models_parser::ast::DataType; 12 | 13 | impl IntoSQL for DateTime { 14 | fn into_sql() -> DataType { 15 | match *DIALECT { 16 | PostgreSQL => DataType::custom("TIMESTAMPTZ"), 17 | SQLite => DataType::custom("DATETIME"), 18 | _ => DataType::Timestamp, 19 | } 20 | } 21 | } 22 | impl IntoSQL for DateTime { 23 | fn into_sql() -> DataType { 24 | match *DIALECT { 25 | PostgreSQL => DataType::custom("TIMESTAMPTZ"), 26 | SQLite => DataType::custom("DATETIME"), 27 | _ => DataType::Timestamp, 28 | } 29 | } 30 | } 31 | 32 | impl IntoSQL for NaiveDateTime { 33 | fn into_sql() -> DataType { 34 | match *DIALECT { 35 | PostgreSQL => DataType::Timestamp, 36 | _ => DataType::custom("DATETIME"), 37 | } 38 | } 39 | } 40 | 41 | impl IntoSQL for NaiveDate { 42 | fn into_sql() -> DataType { 43 | match *DIALECT { 44 | SQLite => DataType::custom("DATETIME"), 45 | _ => DataType::Date, 46 | } 47 | } 48 | } 49 | 50 | impl IntoSQL for NaiveTime { 51 | fn into_sql() -> DataType { 52 | match *DIALECT { 53 | SQLite => DataType::custom("DATETIME"), 54 | _ => DataType::Time, 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /models/src/types/json.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use models_parser::ast::DataType; 3 | use serde::*; 4 | use std::ops::{Deref, DerefMut}; 5 | 6 | /// Wrapper type used to hold serilizable data. 
The type generated is `JSON`. 7 | /// ```rust 8 | /// struct Author { 9 | /// books: Json> 10 | /// } 11 | /// ``` 12 | /// The previous structure would generate: 13 | /// ```sql 14 | /// CREATE TABLE author ( 15 | /// books JSON NOT NULL, 16 | /// ); 17 | /// ``` 18 | 19 | #[derive(Serialize, Deserialize, Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] 20 | pub struct Json(pub T); 21 | 22 | impl Deref for Json { 23 | type Target = T; 24 | fn deref(&self) -> &Self::Target { 25 | &self.0 26 | } 27 | } 28 | 29 | impl DerefMut for Json { 30 | fn deref_mut(&mut self) -> &mut Self::Target { 31 | &mut self.0 32 | } 33 | } 34 | 35 | impl AsRef for Json { 36 | fn as_ref(&self) -> &T { 37 | &self.0 38 | } 39 | } 40 | 41 | impl AsMut for Json { 42 | fn as_mut(&mut self) -> &mut T { 43 | &mut self.0 44 | } 45 | } 46 | 47 | impl IntoSQL for Json { 48 | const IS_NULLABLE: bool = false; 49 | fn into_sql() -> DataType { 50 | DataType::Json 51 | } 52 | } 53 | #[allow(unused_imports)] 54 | #[cfg(feature = "sqlx")] 55 | mod sqlx_impl { 56 | use super::*; 57 | use serde::{Deserialize, Serialize}; 58 | #[cfg(feature = "sqlx-mysql")] 59 | use sqlx::sqlite::{Sqlite, SqliteTypeInfo}; 60 | #[cfg(feature = "sqlx-mysql")] 61 | use sqlx::mysql::{MySql, MySqlTypeInfo}; 62 | #[cfg(feature = "sqlx-postgres")] 63 | use sqlx::postgres::{PgTypeInfo, Postgres}; 64 | use sqlx::{ 65 | database::{HasArguments, HasValueRef}, 66 | decode::Decode, 67 | encode::{Encode, IsNull}, 68 | Database, Type, 69 | }; 70 | use std::io::Write; 71 | #[cfg(feature = "sqlx-postgres")] 72 | impl Type for Json 73 | where 74 | DB: Database, 75 | sqlx::types::Json: Type, 76 | { 77 | fn type_info() -> DB::TypeInfo { 78 | sqlx::types::Json::type_info() 79 | } 80 | 81 | fn compatible(ty: &DB::TypeInfo) -> bool { 82 | sqlx::types::Json::compatible(ty) 83 | } 84 | } 85 | impl<'q, T, DB> Encode<'q, DB> for Json 86 | where 87 | DB: Database, 88 | T: Serialize, 89 | >::ArgumentBuffer: Write, 90 | { 91 | fn 
encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { 92 | serde_json::to_writer(buf, self).ok(); 93 | IsNull::No 94 | } 95 | } 96 | 97 | impl<'r, DB, T> Decode<'r, DB> for Json 98 | where 99 | &'r str: Decode<'r, DB>, 100 | DB: Database, 101 | T: Deserialize<'r>, 102 | { 103 | fn decode( 104 | value: >::ValueRef, 105 | ) -> Result, Box> { 106 | let string_value = <&str as Decode>::decode(value)?; 107 | serde_json::from_str(string_value) 108 | .map(Json) 109 | .map_err(Into::into) 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /models/src/types/mod.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! # Types 3 | //! 4 | //! | Rust | PostgreSQL | MySQL | SQLite | 5 | //! |------------ |---------------|--------------------------|---------------------| 6 | //! | `bool` | BOOLEAN | BOOLEAN | BOOLEAN | 7 | //! | `i8` | SMALLINT | TINYINT | INTEGER | 8 | //! | `i16` | SMALLINT | SMALLINT | INTEGER | 9 | //! | `i32` | INT | INT | INTEGER | 10 | //! | `i64` | BIGINT | BIGINT | INTEGER | 11 | //! | `f32` | REAL | FLOAT | REAL | 12 | //! | `f64` | REAL | REAL | REAL | 13 | //! | `String` | TEXT | TEXT | TEXT | 14 | //! | `VarChar` | VARCHAR(SIZE) | VARCHAR(SIZE) | TEXT | 15 | //! | `VarBinary`| BYTEA | VARBINARY(SIZE) | BLOB | 16 | //! | `Vec` | BYTEA | BLOB | BLOB | 17 | //! | `[u8; SIZE]` | BYTEA | BLOB(SIZE) | BLOB | 18 | //! | 19 | //! 20 | //! ### [`chrono`](https://crates.io/crates/chrono) 21 | //! 22 | //! Requires the `chrono` Cargo feature flag. 23 | //! 24 | //! | Rust type | Postgres | MySQL | SQLite | 25 | //! |------------------------------|------------------|------------------|--------------------| 26 | //! | `chrono::DateTime` | TIMESTAMPTZ | TIMESTAMP | DATETIME | 27 | //! | `chrono::DateTime` | TIMESTAMPTZ | TIMESTAMP | DATETIME | 28 | //! | `chrono::NaiveDateTime` | TIMESTAMP | DATETIME | DATETIME | 29 | //! 
| `chrono::NaiveDate` | DATE | DATE | DATETIME | 30 | //! | `chrono::NaiveTime` | TIME | TIME | DATETIME | 31 | //! 32 | #[cfg(feature = "chrono")] 33 | mod chrono_impl; 34 | #[cfg(feature = "json")] 35 | mod json; 36 | mod serial; 37 | mod time; 38 | mod var_binary; 39 | mod var_char; 40 | 41 | #[cfg(feature = "json")] 42 | pub use json::*; 43 | use models_parser::ast::DataType; 44 | pub use serial::Serial; 45 | pub use time::*; 46 | pub use var_binary::VarBinary; 47 | pub use var_char::VarChar; 48 | 49 | use crate::prelude::*; 50 | 51 | /// Do not use this trait in your production code. 52 | /// Its intended use is for migration generation only. 53 | /// It will panic if used outside its intended API. 54 | pub trait IntoSQL { 55 | fn into_sql() -> DataType; 56 | const IS_NULLABLE: bool = false; 57 | } 58 | 59 | impl IntoSQL for i32 { 60 | fn into_sql() -> DataType { 61 | DataType::Int(None) 62 | } 63 | } 64 | impl IntoSQL for i16 { 65 | fn into_sql() -> DataType { 66 | match *DIALECT { 67 | SQLite => DataType::Int(None), 68 | PostgreSQL => DataType::SmallInt(None), 69 | _ => DataType::SmallInt(None), 70 | } 71 | } 72 | } 73 | impl IntoSQL for i8 { 74 | fn into_sql() -> DataType { 75 | match *DIALECT { 76 | SQLite => DataType::Int(None), 77 | PostgreSQL => DataType::SmallInt(None), 78 | _ => DataType::TinyInt(None), 79 | } 80 | } 81 | } 82 | 83 | impl IntoSQL for u32 { 84 | fn into_sql() -> DataType { 85 | match *DIALECT { 86 | MySQL => DataType::BigInt(None), 87 | PostgreSQL => DataType::BigInt(None), 88 | _ => DataType::Int(None), 89 | } 90 | } 91 | } 92 | // impl IntoSQL for u16 { 93 | // fn into_sql() -> DataType { 94 | // match *DIALECT { 95 | // MySQL => DataType::Int(None), 96 | // _ => DataType::Int(None), 97 | // } 98 | // } 99 | // } 100 | // impl IntoSQL for u8 { 101 | // fn into_sql() -> DataType { 102 | // match *DIALECT { 103 | // MySQL => DataType::Int(None), 104 | // PostgreSQL => DataType::custom("SMALLINT"), 105 | // _ => DataType::Int(None), 106 
| // } 107 | // } 108 | // } 109 | 110 | // impl IntoSQL for u64 { 111 | // fn into_sql() -> DataType { 112 | // DataType::BigInt(None) 113 | // } 114 | // } 115 | impl IntoSQL for i64 { 116 | fn into_sql() -> DataType { 117 | match *DIALECT { 118 | SQLite => DataType::Int(None), 119 | _ => DataType::BigInt(None), 120 | } 121 | } 122 | } 123 | impl IntoSQL for f64 { 124 | fn into_sql() -> DataType { 125 | match *DIALECT { 126 | PostgreSQL => DataType::Double, 127 | _ => DataType::Real, 128 | } 129 | } 130 | } 131 | impl IntoSQL for f32 { 132 | fn into_sql() -> DataType { 133 | match *DIALECT { 134 | MySQL => DataType::Real, 135 | _ => DataType::Real, 136 | } 137 | } 138 | } 139 | 140 | impl IntoSQL for String { 141 | fn into_sql() -> DataType { 142 | DataType::Text 143 | } 144 | } 145 | impl IntoSQL for [u8; N] { 146 | fn into_sql() -> DataType { 147 | match *DIALECT { 148 | PostgreSQL => DataType::Bytea, 149 | SQLite => DataType::Blob(None), 150 | _ => DataType::Blob(Some(N as u64)), 151 | } 152 | } 153 | } 154 | impl IntoSQL for Vec { 155 | fn into_sql() -> DataType { 156 | match *DIALECT { 157 | PostgreSQL => DataType::Bytea, 158 | _ => DataType::Blob(None), 159 | } 160 | } 161 | } 162 | 163 | impl IntoSQL for Option { 164 | fn into_sql() -> DataType { 165 | T::into_sql() 166 | } 167 | const IS_NULLABLE: bool = false; 168 | } 169 | impl IntoSQL for bool { 170 | fn into_sql() -> DataType { 171 | DataType::Boolean 172 | } 173 | } 174 | 175 | #[test] 176 | fn func() { 177 | let x = &models_parser::parser::Parser::parse_sql( 178 | &models_parser::dialect::GenericDialect {}, 179 | " 180 | 181 | CREATE TABLE Persons ( 182 | Personid int NOT NULL AUTO_INCREMENT, 183 | LastName varchar(255) NOT NULL, 184 | FirstName varchar(255), 185 | Age int, 186 | PRIMARY KEY (Personid) 187 | ); 188 | ", 189 | ) 190 | .unwrap()[0]; 191 | 192 | println!("{}", x); 193 | } 194 | -------------------------------------------------------------------------------- /models/src/types/serial.rs: 
-------------------------------------------------------------------------------- 1 | use models_parser::ast::DataType; 2 | #[cfg(feature = "serde")] 3 | use serde::*; 4 | use std::ops::{Deref, DerefMut}; 5 | 6 | use crate::prelude::*; 7 | 8 | 9 | /// PostgreSQL `SERIAL` type. It enables autoincrementing functionality. 10 | /// Example: 11 | /// ``` 12 | /// struct Profile { 13 | /// id: Serial, 14 | /// } 15 | /// ``` 16 | /// The previus structure would generate: 17 | /// ```sql 18 | /// CREATE TABLE profile ( 19 | /// id SERIAL NOT NULL 20 | /// ); 21 | /// ``` 22 | /// 23 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 24 | #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] 25 | #[cfg_attr(feature = "sqlx", sqlx(transparent))] 26 | #[derive(Debug, Clone, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] 27 | pub struct Serial(pub i32); 28 | 29 | impl From for Serial 30 | where 31 | T: Into, 32 | { 33 | fn from(obj: T) -> Self { 34 | Self(obj.into()) 35 | } 36 | } 37 | 38 | impl Deref for Serial { 39 | type Target = i32; 40 | fn deref(&self) -> &Self::Target { 41 | &self.0 42 | } 43 | } 44 | 45 | impl DerefMut for Serial { 46 | fn deref_mut(&mut self) -> &mut Self::Target { 47 | &mut self.0 48 | } 49 | } 50 | 51 | impl AsMut for Serial { 52 | fn as_mut(&mut self) -> &mut i32 { 53 | &mut self.0 54 | } 55 | } 56 | 57 | impl AsRef for Serial { 58 | fn as_ref(&self) -> &i32 { 59 | &self.0 60 | } 61 | } 62 | 63 | impl IntoSQL for Serial { 64 | fn into_sql() -> DataType { 65 | DataType::Serial 66 | } 67 | } 68 | 69 | 70 | -------------------------------------------------------------------------------- /models/src/types/time.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | #[cfg(feature = "serde")] 3 | use serde::*; 4 | use std::ops::{Deref, DerefMut}; 5 | 6 | 7 | /// Wrapper type that defaults to `DATE`. 
8 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 9 | #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] 10 | #[cfg_attr(feature = "sqlx", sqlx(transparent))] 11 | #[derive(Debug, Clone, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] 12 | pub struct Date(pub T); 13 | impl Deref for Date { 14 | type Target = T; 15 | fn deref(&self) -> &Self::Target { 16 | &self.0 17 | } 18 | } 19 | impl DerefMut for Date { 20 | fn deref_mut(&mut self) -> &mut Self::Target { 21 | &mut self.0 22 | } 23 | } 24 | impl AsRef for Date { 25 | fn as_ref(&self) -> &T { 26 | &self.0 27 | } 28 | } 29 | impl AsMut for Date { 30 | fn as_mut(&mut self) -> &mut T { 31 | &mut self.0 32 | } 33 | } 34 | impl IntoSQL for Date { 35 | const IS_NULLABLE: bool = false; 36 | fn into_sql() -> DataType { 37 | DataType::Date 38 | } 39 | } 40 | /// Wrapper type that defaults to `DATETIME`. 41 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 42 | #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] 43 | #[cfg_attr(feature = "sqlx", sqlx(transparent))] 44 | #[derive(Debug, Clone, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] 45 | pub struct DateTime(pub T); 46 | impl Deref for DateTime { 47 | type Target = T; 48 | fn deref(&self) -> &Self::Target { 49 | &self.0 50 | } 51 | } 52 | impl DerefMut for DateTime { 53 | fn deref_mut(&mut self) -> &mut Self::Target { 54 | &mut self.0 55 | } 56 | } 57 | impl AsRef for DateTime { 58 | fn as_ref(&self) -> &T { 59 | &self.0 60 | } 61 | } 62 | impl AsMut for DateTime { 63 | fn as_mut(&mut self) -> &mut T { 64 | &mut self.0 65 | } 66 | } 67 | impl IntoSQL for DateTime { 68 | const IS_NULLABLE: bool = false; 69 | fn into_sql() -> DataType { 70 | DataType::custom("DATETIME") 71 | } 72 | } 73 | /// Wrapper type that defaults to `TIMESTAMP`. 
74 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 75 | #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] 76 | #[cfg_attr(feature = "sqlx", sqlx(transparent))] 77 | #[derive(Debug, Clone, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] 78 | pub struct Timestamp(pub T); 79 | impl Deref for Timestamp { 80 | type Target = T; 81 | fn deref(&self) -> &Self::Target { 82 | &self.0 83 | } 84 | } 85 | impl DerefMut for Timestamp { 86 | fn deref_mut(&mut self) -> &mut Self::Target { 87 | &mut self.0 88 | } 89 | } 90 | impl AsRef for Timestamp { 91 | fn as_ref(&self) -> &T { 92 | &self.0 93 | } 94 | } 95 | impl AsMut for Timestamp { 96 | fn as_mut(&mut self) -> &mut T { 97 | &mut self.0 98 | } 99 | } 100 | impl IntoSQL for Timestamp { 101 | const IS_NULLABLE: bool = false; 102 | fn into_sql() -> DataType { 103 | DataType::Timestamp 104 | } 105 | } -------------------------------------------------------------------------------- /models/src/types/var_binary.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | use models_parser::ast::DataType; 3 | use std::{ 4 | convert::AsMut, 5 | ops::{Deref, DerefMut}, 6 | }; 7 | 8 | #[cfg(feature = "serde")] 9 | use serde::*; 10 | 11 | /// Used for MySQL when to specify that the datatype should be 12 | /// a `VARBINARY(N)`. The database will make sure the field does not 13 | /// go over the specified length. 
14 | /// ``` 15 | /// use models::{Model, VarChar}; 16 | /// #[derive(Model)] 17 | /// struct Example { 18 | /// bin_data: VarBinary<255> 19 | /// } 20 | /// ``` 21 | /// The previous structure would generate: 22 | /// ```sql 23 | /// CREATE TABLE example ( 24 | /// bin_data VarBinary(255) NOT NULL 25 | /// ); 26 | /// ``` 27 | 28 | #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] 29 | #[cfg_attr(feature = "sqlx", sqlx(transparent))] 30 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 31 | #[derive(Debug, Clone, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] 32 | pub struct VarBinary(pub Vec); 33 | 34 | impl VarBinary { 35 | pub fn new() -> Self { 36 | Self::default() 37 | } 38 | } 39 | 40 | impl Deref for VarBinary { 41 | type Target = Vec; 42 | fn deref(&self) -> &Self::Target { 43 | &self.0 44 | } 45 | } 46 | 47 | impl DerefMut for VarBinary { 48 | fn deref_mut(&mut self) -> &mut Self::Target { 49 | &mut self.0 50 | } 51 | } 52 | 53 | impl AsRef> for VarBinary { 54 | fn as_ref(&self) -> &Vec { 55 | &self.0 56 | } 57 | } 58 | 59 | impl AsMut> for VarBinary { 60 | fn as_mut(&mut self) -> &mut Vec { 61 | &mut self.0 62 | } 63 | } 64 | 65 | impl IntoSQL for VarBinary { 66 | const IS_NULLABLE: bool = false; 67 | fn into_sql() -> DataType { 68 | if !matches!(*DIALECT, SQLite) { 69 | DataType::Varbinary(Some(N)) 70 | } else { 71 | DataType::Blob(None) 72 | } 73 | } 74 | } 75 | 76 | impl From for VarBinary 77 | where 78 | T: Into>, 79 | { 80 | fn from(obj: T) -> Self { 81 | VarBinary(obj.into()) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /models/src/types/var_char.rs: -------------------------------------------------------------------------------- 1 | use crate::{prelude::*, types::IntoSQL}; 2 | use models_parser::ast::DataType; 3 | #[cfg(feature = "serde")] 4 | use serde::*; 5 | use std::{ 6 | convert::AsMut, 7 | ops::{Deref, DerefMut}, 8 | }; 9 | 10 | /// Used for MySQL when to specify that the 
datatype should be 11 | /// a `VARCHAR(N)`. The database will make sure the field does not 12 | /// go over the specified length. 13 | /// ``` 14 | /// use models::{Model, VarChar}; 15 | /// #[derive(Model)] 16 | /// struct Profile { 17 | /// email: VarChar<255> 18 | /// } 19 | /// ``` 20 | /// The previous structure would generate: 21 | /// ```sql 22 | /// CREATE TABLE profile ( 23 | /// email VARCHAR(255) NOT NULL 24 | /// ); 25 | /// ``` 26 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 27 | #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] 28 | #[cfg_attr(feature = "sqlx", sqlx(transparent))] 29 | #[derive(Clone, Default, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] 30 | pub struct VarChar(pub String); 31 | 32 | impl VarChar { 33 | pub fn new() -> Self { 34 | Self::default() 35 | } 36 | } 37 | 38 | impl Deref for VarChar { 39 | type Target = String; 40 | fn deref(&self) -> &Self::Target { 41 | &self.0 42 | } 43 | } 44 | 45 | impl DerefMut for VarChar { 46 | fn deref_mut(&mut self) -> &mut Self::Target { 47 | &mut self.0 48 | } 49 | } 50 | 51 | impl AsRef for VarChar { 52 | fn as_ref(&self) -> &String { 53 | &self.0 54 | } 55 | } 56 | 57 | impl AsMut for VarChar { 58 | fn as_mut(&mut self) -> &mut String { 59 | &mut self.0 60 | } 61 | } 62 | 63 | impl From for VarChar 64 | where 65 | T: Into, 66 | { 67 | fn from(obj: T) -> Self { 68 | let string: String = obj.into(); 69 | VarChar(string) 70 | } 71 | } 72 | 73 | impl IntoSQL for VarChar { 74 | const IS_NULLABLE: bool = false; 75 | fn into_sql() -> DataType { 76 | match *DIALECT { 77 | SQLite => DataType::Text, 78 | _ => DataType::Varchar(Some(N)), 79 | } 80 | } 81 | } 82 | --------------------------------------------------------------------------------