├── .dockerignore
├── .github
│   └── workflows
│       ├── lint.yaml
│       ├── publish-crate.yaml
│       ├── publish-docker.yaml
│       ├── release.yaml
│       └── test.yaml
├── .gitignore
├── Cargo.toml
├── Dockerfile
├── LICENSE.md
├── README.md
├── src
│   ├── db.rs
│   ├── helpers.rs
│   ├── lib.rs
│   ├── main.rs
│   ├── migrations
│   │   ├── add_column.rs
│   │   ├── add_foreign_key.rs
│   │   ├── add_index.rs
│   │   ├── alter_column.rs
│   │   ├── common.rs
│   │   ├── create_enum.rs
│   │   ├── create_table.rs
│   │   ├── custom.rs
│   │   ├── mod.rs
│   │   ├── remove_column.rs
│   │   ├── remove_enum.rs
│   │   ├── remove_foreign_key.rs
│   │   ├── remove_index.rs
│   │   ├── remove_table.rs
│   │   └── rename_table.rs
│   ├── schema.rs
│   └── state.rs
└── tests
    ├── add_column.rs
    ├── add_foreign_key.rs
    ├── add_index.rs
    ├── alter_column.rs
    ├── common.rs
    ├── complex.rs
    ├── create_enum.rs
    ├── create_table.rs
    ├── custom.rs
    ├── failure.rs
    ├── remove_column.rs
    ├── remove_enum.rs
    ├── remove_foreign_key.rs
    ├── remove_index.rs
    ├── remove_table.rs
    └── rename_table.rs

/.dockerignore:
--------------------------------------------------------------------------------
/target
Cargo.lock
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
name: Lint

on:
  push:
    branches: [ main ]
  pull_request:
    branches:
      - main

env:
  CARGO_TERM_COLOR: always

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Select Rust toolchain with Clippy
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
          components: clippy
          override: true
      - name: Use cache for Rust dependencies
        uses: Swatinem/rust-cache@v2
      - name: Lint using Clippy
        run: cargo clippy
--------------------------------------------------------------------------------
/.github/workflows/publish-crate.yaml:
--------------------------------------------------------------------------------
name: Publish crate

on: workflow_dispatch

jobs:
  publish-crate:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Select Rust toolchain
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
      - name: Publish to crates.io
        uses: actions-rs/cargo@v1
        with:
          command: publish
        env:
          CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_IO_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/publish-docker.yaml:
--------------------------------------------------------------------------------
name: Publish Docker image

on:
  workflow_dispatch:
    inputs:
      version:
        description: "Version (without 'v' prefix)"
        required: true
        type: string

jobs:
  publish-docker:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Log in to Docker
        env:
          DOCKER_PASSWORD: ${{ secrets.DOCKER_ACCESS_TOKEN }}
        run: |
          docker login -u ${{ secrets.DOCKER_USER }} -p $DOCKER_PASSWORD
      - name: Build Docker image
        run: docker build . --tag ${{ secrets.DOCKER_USER }}/reshape:${{ inputs.version }} --tag ${{ secrets.DOCKER_USER }}/reshape:latest
      - name: Push Docker image
        run: docker push --all-tags ${{ secrets.DOCKER_USER }}/reshape
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
name: Release

on:
  release:
    types: [ created ]

env:
  CARGO_TERM_COLOR: always

jobs:
  build-binary:
    strategy:
      matrix:
        include:
          - name: Linux 64-bit
            file-name: linux_amd64
            target: x86_64-unknown-linux-gnu
            host: ubuntu-latest
            use-cross: true
          - name: Linux 32-bit
            file-name: linux_386
            target: i686-unknown-linux-gnu
            host: ubuntu-latest
            use-cross: true
          - name: Linux ARM 64-bit
            file-name: linux_aarch64
            target: aarch64-unknown-linux-gnu
            host: ubuntu-latest
            use-cross: true
          - name: macOS
            file-name: darwin_amd64
            target: x86_64-apple-darwin
            host: macos-11
            use-cross: false
          - name: macOS Apple Silicon
            file-name: darwin_aarch64
            target: aarch64-apple-darwin
            host: macos-11
            use-cross: false

    runs-on: ${{ matrix.host }}

    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Select Rust toolchain
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
          override: true
          target: ${{ matrix.target }}
      - name: Use cache for Rust dependencies
        uses: Swatinem/rust-cache@v1
      - name: Build
        uses: actions-rs/cargo@v1
        with:
          command: build
          args: --release --target ${{ matrix.target }}
          use-cross: ${{ matrix.use-cross }}
      - name: Rename binary
        run: mv target/${{ matrix.target }}/release/reshape ./reshape-${{ matrix.file-name }}
      - name: Upload binary to release
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: gh release upload ${GITHUB_REF##*/} "reshape-${{ matrix.file-name }}#${{ matrix.name }}" --clobber

  publish-docker:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Log in to Docker
        env:
          DOCKER_PASSWORD: ${{ secrets.DOCKER_ACCESS_TOKEN }}
        run: |
          docker login -u ${{ secrets.DOCKER_USER }} -p $DOCKER_PASSWORD
      - name: Build Docker image
        # GITHUB_REF is formatted as: refs/tags/v0.0.1
        # The shell expansion used below will remove everything up to the version number, leaving 0.0.1
        run: docker build . --tag ${{ secrets.DOCKER_USER }}/reshape:${GITHUB_REF##*/v} --tag ${{ secrets.DOCKER_USER }}/reshape:latest
      - name: Push Docker image
        run: docker push --all-tags ${{ secrets.DOCKER_USER }}/reshape

  publish-crate:
    runs-on: ubuntu-latest
    needs: ["build-binary", "publish-docker"]
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Select Rust toolchain
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
      - name: Publish to crates.io
        uses: actions-rs/cargo@v1
        with:
          command: publish
        env:
          CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_IO_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
name: Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches:
      - main

env:
  CARGO_TERM_COLOR: always

jobs:
  integration-tests:
    runs-on: ubuntu-latest

    services:
      postgres:
        image: postgres
        ports:
          - 5432:5432
        env:
          POSTGRES_DB: migra_test
          POSTGRES_USER: postgres
          POSTGRES_PASSWORD: postgres
        # Set health checks to wait until postgres has started
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Select Rust toolchain
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
      - name: Use cache for Rust dependencies
        uses: Swatinem/rust-cache@v1
      - name: Run integration tests
        uses: actions-rs/cargo@v1
        with:
          command: test
          args: -- --test-threads=1
        env:
          POSTGRES_CONNECTION_STRING: "postgres://postgres:postgres@127.0.0.1/migra_test"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/target
Cargo.lock
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "reshape"
version = "0.7.0"
description = "An easy-to-use, zero-downtime schema migration tool for Postgres"
homepage = "https://github.com/fabianlindfors/reshape"
documentation = "https://github.com/fabianlindfors/reshape"
repository = "https://github.com/fabianlindfors/reshape"
license = "MIT"
keywords = ["postgres", "migrations"]
edition = "2021"
authors = ["Fabian Lindfors"]
rust-version = "1.70"

[dependencies]
postgres = { version = "0.19.2", features = ["with-serde_json-1"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
typetag = "0.1.7"
anyhow = { version = "1.0.44", features = ["backtrace"] }
clap = { version = "3.1.9", features = ["derive"] }
toml = "0.5"
version = "3.0.0"
colored = "2"
rand = "0.8"
dotenv = "0.15.0"
lexical-sort = "0.3.1"
--------------------------------------------------------------------------------
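Since the crate ships both a library (src/lib.rs) and the CLI binary, migrations can also be driven programmatically. A minimal sketch using only entry points that src/main.rs below also calls (`Reshape::new`, `migrate`, `complete`); the connection URL is illustrative:

```rust
use reshape::{migrations::Migration, Reshape};

fn apply_migrations(migrations: Vec<Migration>) -> anyhow::Result<()> {
    // Same entry points the CLI uses in src/main.rs
    let mut reshape = Reshape::new("postgres://postgres:postgres@localhost:5432/app_db")?;

    // Start the migration, setting up the new schema alongside the old one
    reshape.migrate(migrations)?;

    // Once no old application instances remain, finish the migration
    reshape.complete()
}
```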
/Dockerfile:
--------------------------------------------------------------------------------
FROM rust:1.70.0 AS builder
WORKDIR /usr/src/reshape
COPY . .
RUN cargo build --release

FROM debian:bullseye AS runtime
WORKDIR /usr/share/app
COPY --from=builder /usr/src/reshape/target/release/reshape /usr/local/bin/reshape
CMD ["reshape"]
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Fabian Lindfors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/db.rs:
--------------------------------------------------------------------------------
use std::{cmp::min, time::Duration};

use anyhow::{anyhow, Context};
use postgres::{types::ToSql, NoTls, Row};
use rand::prelude::*;

// DbLocker wraps a regular DbConn, only allowing access using the
// `lock` method. This method will acquire the advisory lock before
// allowing access to the database, and then release it afterwards.
//
// We use advisory locks to avoid multiple Reshape instances working
// on the same database at the same time. DbLocker is the only way to
// get a DbConn, which ensures that all database access is protected by
// a lock.
//
// Postgres docs on advisory locks:
// https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS
pub struct DbLocker {
    client: DbConn,
}

impl DbLocker {
    // Advisory lock keys in Postgres are 64-bit integers.
    // The key we use was chosen randomly.
    const LOCK_KEY: i64 = 4036779288569897133;

    pub fn connect(config: &postgres::Config) -> anyhow::Result<Self> {
        let mut pg = config.connect(NoTls)?;

        // When running DDL queries that acquire locks, we risk causing a "lock queue".
        // When attempting to acquire a lock, Postgres will wait for any long-running queries to complete.
        // At the same time, it will block other queries until the lock has been acquired and released.
        // This has the bad effect of the long-running query blocking other queries because of us, forming
        // a queue of other queries until we release our lock.
        //
        // We set the lock_timeout setting to avoid this. This puts an upper bound on how long Postgres will
        // wait to acquire locks and also on the maximum amount of time a long-running query can block other queries.
        // These timeouts are handled gracefully by the automatic retries (see retry_automatically below).
        //
        // Reference: https://medium.com/paypal-tech/postgresql-at-scale-database-schema-changes-without-downtime-20d3749ed680
        //
        // TODO: Make lock_timeout configurable
        pg.simple_query("SET lock_timeout = '1s'")
            .context("failed to set lock_timeout")?;

        Ok(Self {
            client: DbConn::new(pg),
        })
    }

    pub fn lock(
        &mut self,
        f: impl FnOnce(&mut DbConn) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
        self.acquire_lock()?;
        let result = f(&mut self.client);
        self.release_lock()?;

        result
    }

    fn acquire_lock(&mut self) -> anyhow::Result<()> {
        let success = self
            .client
            .query(&format!("SELECT pg_try_advisory_lock({})", Self::LOCK_KEY))?
            .first()
            .ok_or_else(|| anyhow!("unexpectedly failed when acquiring advisory lock"))
            .map(|row| row.get::<'_, _, bool>(0))?;

        if success {
            Ok(())
        } else {
            Err(anyhow!("another instance of Reshape is already running"))
        }
    }

    fn release_lock(&mut self) -> anyhow::Result<()> {
        self.client
            .query(&format!("SELECT pg_advisory_unlock({})", Self::LOCK_KEY))?
            .first()
            .ok_or_else(|| anyhow!("unexpectedly failed when releasing advisory lock"))?;
        Ok(())
    }
}
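A sketch of how the rest of the crate is expected to use `DbLocker`: every database operation happens inside `lock`, so two concurrently running Reshape processes cannot interleave. The query is illustrative.

```rust
fn example(config: &postgres::Config) -> anyhow::Result<()> {
    let mut locker = DbLocker::connect(config)?;
    locker.lock(|db| {
        // `db` is a &mut DbConn; everything here runs while the advisory
        // lock is held, and the lock is released when the closure returns
        db.run("SELECT 1")
    })
}
```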
pub trait Conn {
    fn run(&mut self, query: &str) -> anyhow::Result<()>;
    fn query(&mut self, query: &str) -> anyhow::Result<Vec<Row>>;
    fn query_with_params(
        &mut self,
        query: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> anyhow::Result<Vec<Row>>;
    fn transaction(&mut self) -> anyhow::Result<Transaction>;
}

pub struct DbConn {
    client: postgres::Client,
}

impl DbConn {
    fn new(client: postgres::Client) -> Self {
        DbConn { client }
    }
}

impl Conn for DbConn {
    fn run(&mut self, query: &str) -> anyhow::Result<()> {
        retry_automatically(|| self.client.batch_execute(query))?;
        Ok(())
    }

    fn query(&mut self, query: &str) -> anyhow::Result<Vec<Row>> {
        let rows = retry_automatically(|| self.client.query(query, &[]))?;
        Ok(rows)
    }

    fn query_with_params(
        &mut self,
        query: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> anyhow::Result<Vec<Row>> {
        let rows = retry_automatically(|| self.client.query(query, params))?;
        Ok(rows)
    }

    fn transaction(&mut self) -> anyhow::Result<Transaction> {
        let transaction = self.client.transaction()?;
        Ok(Transaction { transaction })
    }
}

pub struct Transaction<'a> {
    transaction: postgres::Transaction<'a>,
}

impl Transaction<'_> {
    pub fn commit(self) -> anyhow::Result<()> {
        self.transaction.commit()?;
        Ok(())
    }

    pub fn rollback(self) -> anyhow::Result<()> {
        self.transaction.rollback()?;
        Ok(())
    }
}

impl Conn for Transaction<'_> {
    fn run(&mut self, query: &str) -> anyhow::Result<()> {
        self.transaction.batch_execute(query)?;
        Ok(())
    }

    fn query(&mut self, query: &str) -> anyhow::Result<Vec<Row>> {
        let rows = self.transaction.query(query, &[])?;
        Ok(rows)
    }

    fn query_with_params(
        &mut self,
        query: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> anyhow::Result<Vec<Row>> {
        let rows = self.transaction.query(query, params)?;
        Ok(rows)
    }

    fn transaction(&mut self) -> anyhow::Result<Transaction> {
        let transaction = self.transaction.transaction()?;
        Ok(Transaction { transaction })
    }
}

// Retry a database operation with exponential backoff and jitter
fn retry_automatically<T>(
    mut f: impl FnMut() -> Result<T, postgres::Error>,
) -> Result<T, postgres::Error> {
    const STARTING_WAIT_TIME: u64 = 100;
    const MAX_WAIT_TIME: u64 = 3_200;
    const MAX_ATTEMPTS: u32 = 10;

    let mut rng = rand::thread_rng();
    let mut attempts = 0;
    loop {
        let result = f();

        let error = match result {
            Ok(_) => return result,
            Err(err) => err,
        };

        // If we got a database error, we check if it's retryable.
        // If we didn't get a database error, then it's most likely some kind of connection
        // error which should also be retried.
        if let Some(db_error) = error.as_db_error() {
            if !error_retryable(db_error) {
                return Err(error);
            }
        }

        attempts += 1;
        if attempts >= MAX_ATTEMPTS {
            return Err(error);
        }

        // The wait time increases exponentially, starting at 100ms and doubling up to a max of 3.2s.
        let wait_time = min(
            MAX_WAIT_TIME,
            STARTING_WAIT_TIME * u64::pow(2, attempts - 1),
        );

        // The jitter is up to half the wait time
        let jitter: u64 = rng.gen_range(0..wait_time / 2);

        std::thread::sleep(Duration::from_millis(wait_time + jitter));
    }
}

// Check if a database error can be retried
fn error_retryable(error: &postgres::error::DbError) -> bool {
    // LOCK_NOT_AVAILABLE is caused by lock_timeout being exceeded
    matches!(error.code(), &postgres::error::SqlState::LOCK_NOT_AVAILABLE)
}
--------------------------------------------------------------------------------
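The retry schedule above doubles from 100 ms and caps at 3.2 s, with up to 50% jitter added on top. A quick sanity check of the formula, written as a standalone test (not part of the crate):

```rust
fn wait_time_ms(attempt: u32) -> u64 {
    // Mirrors the min(MAX_WAIT_TIME, STARTING_WAIT_TIME * 2^(attempt - 1)) logic above
    std::cmp::min(3_200, 100 * u64::pow(2, attempt - 1))
}

#[test]
fn backoff_doubles_then_caps() {
    assert_eq!(wait_time_ms(1), 100);
    assert_eq!(wait_time_ms(4), 800);
    assert_eq!(wait_time_ms(6), 3_200); // 100 * 2^5 hits the cap exactly
    assert_eq!(wait_time_ms(9), 3_200);
}
```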
/src/helpers.rs:
--------------------------------------------------------------------------------
use anyhow::Context;

use crate::db::Conn;

pub fn set_up_helpers(db: &mut dyn Conn, target_migration: &str) -> anyhow::Result<()> {
    let query = format!(
        "
        CREATE OR REPLACE FUNCTION reshape.is_new_schema()
        RETURNS BOOLEAN AS $$
        DECLARE
            setting TEXT := current_setting('reshape.is_new_schema', TRUE);
            setting_bool BOOLEAN := setting IS NOT NULL AND setting = 'YES';
        BEGIN
            RETURN current_setting('search_path') = 'migration_{}' OR setting_bool;
        END
        $$ language 'plpgsql';
        ",
        target_migration,
    );
    db.query(&query)
        .context("failed creating helper function reshape.is_new_schema()")?;

    Ok(())
}

pub fn tear_down_helpers(db: &mut dyn Conn) -> anyhow::Result<()> {
    db.query("DROP FUNCTION IF EXISTS reshape.is_new_schema;")?;
    Ok(())
}
--------------------------------------------------------------------------------
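is_new_schema() keys off the connection's search_path, which the application is expected to set using the output of `reshape schema-query`. A sketch of the application side, assuming the generated query is a plain `SET search_path` statement; the migration name and URL are illustrative:

```rust
fn connect_with_schema() -> anyhow::Result<postgres::Client> {
    let mut client = postgres::Client::connect(
        "postgres://postgres:postgres@localhost:5432/app_db",
        postgres::NoTls,
    )?;
    // The exact statement comes from `reshape schema-query`, run once per
    // deploy; shown here for a hypothetical migration named 2_rename_column
    client.simple_query("SET search_path TO migration_2_rename_column")?;
    Ok(client)
}
```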
/src/main.rs:
--------------------------------------------------------------------------------
use std::{
    fs::{self, File},
    io::Read,
    path::Path,
};

use anyhow::Context;
use clap::{Args, Parser};
use reshape::{
    migrations::{Action, Migration},
    Reshape,
};
use serde::{Deserialize, Serialize};

#[derive(Parser)]
#[clap(name = "Reshape", version, about)]
struct Opts {
    #[clap(subcommand)]
    cmd: Command,
}

#[derive(Parser)]
#[clap(about)]
enum Command {
    #[clap(subcommand)]
    Migration(MigrationCommand),

    #[clap(
        about = "Output the query your application should use to select the right schema",
        display_order = 2
    )]
    SchemaQuery(FindMigrationsOptions),

    #[clap(
        about = "Deprecated. Use `reshape schema-query` instead",
        display_order = 3
    )]
    GenerateSchemaQuery(FindMigrationsOptions),

    #[clap(
        about = "Deprecated. Use `reshape migration start` instead",
        display_order = 4
    )]
    Migrate(MigrateOptions),
    #[clap(
        about = "Deprecated. Use `reshape migration complete` instead",
        display_order = 5
    )]
    Complete(ConnectionOptions),
    #[clap(
        about = "Deprecated. Use `reshape migration abort` instead",
        display_order = 6
    )]
    Abort(ConnectionOptions),
}

#[derive(Parser)]
#[clap(about = "Commands for managing migrations", display_order = 1)]
enum MigrationCommand {
    #[clap(
        about = "Starts a new migration, applying any migrations which haven't yet been applied",
        display_order = 1
    )]
    Start(MigrateOptions),

    #[clap(about = "Completes an in-progress migration", display_order = 2)]
    Complete(ConnectionOptions),

    #[clap(
        about = "Aborts an in-progress migration without losing any data",
        display_order = 3
    )]
    Abort(ConnectionOptions),
}

#[derive(Args)]
struct MigrateOptions {
    // Automatically complete the migration after it has been started
    #[clap(long, short)]
    complete: bool,
    #[clap(flatten)]
    connection_options: ConnectionOptions,
    #[clap(flatten)]
    find_migrations_options: FindMigrationsOptions,
}

#[derive(Parser)]
struct ConnectionOptions {
    #[clap(long)]
    url: Option<String>,
    #[clap(long, default_value = "localhost")]
    host: String,
    #[clap(long, default_value = "5432")]
    port: u16,
    #[clap(long, short, default_value = "postgres")]
    database: String,
    #[clap(long, short, default_value = "postgres")]
    username: String,
    #[clap(long, short, default_value = "postgres")]
    password: String,
}

#[derive(Parser)]
struct FindMigrationsOptions {
    #[clap(long, default_value = "migrations")]
    dirs: Vec<String>,
}

fn main() -> anyhow::Result<()> {
    let opts: Opts = Opts::parse();
    run(opts)
}

fn run(opts: Opts) -> anyhow::Result<()> {
    match opts.cmd {
        Command::Migration(MigrationCommand::Start(opts)) | Command::Migrate(opts) => {
            let mut reshape = reshape_from_connection_options(&opts.connection_options)?;
            let migrations = find_migrations(&opts.find_migrations_options)?;
            reshape.migrate(migrations)?;

            // Automatically complete migration if --complete flag is set
            if opts.complete {
                reshape.complete()?;
            }

            Ok(())
        }
        Command::Migration(MigrationCommand::Complete(opts)) | Command::Complete(opts) => {
            let mut reshape = reshape_from_connection_options(&opts)?;
            reshape.complete()
        }
        Command::Migration(MigrationCommand::Abort(opts)) | Command::Abort(opts) => {
            let mut reshape = reshape_from_connection_options(&opts)?;
            reshape.abort()
        }
        Command::SchemaQuery(opts) | Command::GenerateSchemaQuery(opts) => {
            let migrations = find_migrations(&opts)?;
            let query = migrations
                .last()
                .map(|migration| reshape::schema_query_for_migration(&migration.name));
            println!("{}", query.unwrap_or_else(|| "".to_string()));

            Ok(())
        }
    }
}
fn reshape_from_connection_options(opts: &ConnectionOptions) -> anyhow::Result<Reshape> {
    // Load environment variables from .env file if it exists
    dotenv::dotenv().ok();

    let url_env = std::env::var("DB_URL").ok();
    let url = url_env.as_ref().or_else(|| opts.url.as_ref());

    // Use the connection URL if it has been set
    if let Some(url) = url {
        return Reshape::new(url);
    }

    let host_env = std::env::var("DB_HOST").ok();
    let host = host_env.as_ref().unwrap_or_else(|| &opts.host);

    let port = std::env::var("DB_PORT")
        .ok()
        .and_then(|port| port.parse::<u16>().ok())
        .unwrap_or(opts.port);

    let username_env = std::env::var("DB_USERNAME").ok();
    let username = username_env.as_ref().unwrap_or_else(|| &opts.username);

    let password_env = std::env::var("DB_PASSWORD").ok();
    let password = password_env.as_ref().unwrap_or_else(|| &opts.password);

    let database_env = std::env::var("DB_NAME").ok();
    let database = database_env.as_ref().unwrap_or_else(|| &opts.database);

    Reshape::new_with_options(host, port, database, username, password)
}

fn find_migrations(opts: &FindMigrationsOptions) -> anyhow::Result<Vec<Migration>> {
    let search_paths = opts
        .dirs
        .iter()
        .map(Path::new)
        // Filter out all directories that don't exist
        .filter(|path| path.exists());

    // Find all files in the search paths
    let mut file_paths = Vec::new();
    for search_path in search_paths {
        let entries = fs::read_dir(search_path)?;
        for entry in entries {
            let path = entry?.path();
            file_paths.push(path);
        }
    }

    // Sort all files by their file names (without extension).
    // The files are sorted naturally, e.g. "1_test_migration" < "10_test_migration".
    file_paths.sort_unstable_by(|path1, path2| {
        let file1 = path1.as_path().file_stem().unwrap().to_str().unwrap();
        let file2 = path2.as_path().file_stem().unwrap().to_str().unwrap();

        lexical_sort::natural_cmp(file1, file2)
    });

    file_paths
        .iter()
        .map(|path| {
            let mut file = File::open(path)?;

            // Read file data
            let mut data = String::new();
            file.read_to_string(&mut data)?;

            Ok((path, data))
        })
        .map(|result| {
            result.and_then(|(path, data)| {
                let extension = path.extension().and_then(|ext| ext.to_str()).unwrap();
                let file_migration =
                    decode_migration_file(&data, extension).with_context(|| {
                        format!("failed to parse migration file {}", path.display())
                    })?;

                let file_name = path.file_stem().and_then(|name| name.to_str()).unwrap();
                Ok(Migration {
                    name: file_migration.name.unwrap_or_else(|| file_name.to_string()),
                    description: file_migration.description,
                    actions: file_migration.actions,
                })
            })
        })
        .collect()
}

fn decode_migration_file(data: &str, extension: &str) -> anyhow::Result<FileMigration> {
    let migration: FileMigration = match extension {
        "json" => serde_json::from_str(data)?,
        "toml" => toml::from_str(data)?,
        extension => {
            return Err(anyhow::anyhow!(
                "unrecognized file extension '{}'",
                extension
            ))
        }
    };

    Ok(migration)
}

#[derive(Serialize, Deserialize)]
struct FileMigration {
    name: Option<String>,
    description: Option<String>,
    actions: Vec<Box<dyn Action>>,
}
--------------------------------------------------------------------------------
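Why natural ordering matters for migration file names: plain lexicographic comparison would run migration 10 before migration 2. A standalone check of the same `lexical_sort` function used above:

```rust
use std::cmp::Ordering;

#[test]
fn migrations_sort_numerically() {
    // Natural comparison treats the numeric prefix as a number...
    assert_eq!(
        lexical_sort::natural_cmp("2_create_users", "10_add_email"),
        Ordering::Less
    );
    // ...while plain string comparison orders "10..." before "2..."
    assert_eq!("2_create_users".cmp("10_add_email"), Ordering::Greater);
}
```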
"1_test_migration" < "10_test_migration" 200 | file_paths.sort_unstable_by(|path1, path2| { 201 | let file1 = path1.as_path().file_stem().unwrap().to_str().unwrap(); 202 | let file2 = path2.as_path().file_stem().unwrap().to_str().unwrap(); 203 | 204 | lexical_sort::natural_cmp(file1, file2) 205 | }); 206 | 207 | file_paths 208 | .iter() 209 | .map(|path| { 210 | let mut file = File::open(path)?; 211 | 212 | // Read file data 213 | let mut data = String::new(); 214 | file.read_to_string(&mut data)?; 215 | 216 | Ok((path, data)) 217 | }) 218 | .map(|result| { 219 | result.and_then(|(path, data)| { 220 | let extension = path.extension().and_then(|ext| ext.to_str()).unwrap(); 221 | let file_migration = 222 | decode_migration_file(&data, extension).with_context(|| { 223 | format!("failed to parse migration file {}", path.display()) 224 | })?; 225 | 226 | let file_name = path.file_stem().and_then(|name| name.to_str()).unwrap(); 227 | Ok(Migration { 228 | name: file_migration.name.unwrap_or_else(|| file_name.to_string()), 229 | description: file_migration.description, 230 | actions: file_migration.actions, 231 | }) 232 | }) 233 | }) 234 | .collect() 235 | } 236 | 237 | fn decode_migration_file(data: &str, extension: &str) -> anyhow::Result { 238 | let migration: FileMigration = match extension { 239 | "json" => serde_json::from_str(data)?, 240 | "toml" => toml::from_str(data)?, 241 | extension => { 242 | return Err(anyhow::anyhow!( 243 | "unrecognized file extension '{}'", 244 | extension 245 | )) 246 | } 247 | }; 248 | 249 | Ok(migration) 250 | } 251 | 252 | #[derive(Serialize, Deserialize)] 253 | struct FileMigration { 254 | name: Option, 255 | description: Option, 256 | actions: Vec>, 257 | } 258 | -------------------------------------------------------------------------------- /src/migrations/add_column.rs: -------------------------------------------------------------------------------- 1 | use super::{common, Action, Column, MigrationContext}; 2 | use crate::{ 3 | db::{Conn, Transaction}, 4 | schema::Schema, 5 | }; 6 | use anyhow::{bail, Context}; 7 | use serde::{Deserialize, Serialize}; 8 | 9 | #[derive(Serialize, Deserialize, Debug)] 10 | pub struct AddColumn { 11 | pub table: String, 12 | pub column: Column, 13 | pub up: Option, 14 | } 15 | 16 | #[derive(Serialize, Deserialize, Debug)] 17 | #[serde(untagged)] 18 | pub enum Transformation { 19 | Simple(String), 20 | Update { 21 | table: String, 22 | value: String, 23 | r#where: String, 24 | }, 25 | } 26 | 27 | impl AddColumn { 28 | fn temp_column_name(&self, ctx: &MigrationContext) -> String { 29 | format!( 30 | "{}_temp_column_{}_{}", 31 | ctx.prefix(), 32 | self.table, 33 | self.column.name, 34 | ) 35 | } 36 | 37 | fn trigger_name(&self, ctx: &MigrationContext) -> String { 38 | format!( 39 | "{}_add_column_{}_{}", 40 | ctx.prefix(), 41 | self.table, 42 | self.column.name 43 | ) 44 | } 45 | 46 | fn reverse_trigger_name(&self, ctx: &MigrationContext) -> String { 47 | format!( 48 | "{}_add_column_{}_{}_rev", 49 | ctx.prefix(), 50 | self.table, 51 | self.column.name 52 | ) 53 | } 54 | 55 | fn not_null_constraint_name(&self, ctx: &MigrationContext) -> String { 56 | format!( 57 | "{}_add_column_not_null_{}_{}", 58 | ctx.prefix(), 59 | self.table, 60 | self.column.name 61 | ) 62 | } 63 | } 64 | 65 | #[typetag::serde(name = "add_column")] 66 | impl Action for AddColumn { 67 | fn describe(&self) -> String { 68 | format!( 69 | "Adding column \"{}\" to \"{}\"", 70 | self.column.name, self.table 71 | ) 72 | } 73 | 74 | fn run( 75 | &self, 76 | ctx: 
#[typetag::serde(name = "add_column")]
impl Action for AddColumn {
    fn describe(&self) -> String {
        format!(
            "Adding column \"{}\" to \"{}\"",
            self.column.name, self.table
        )
    }

    fn run(
        &self,
        ctx: &MigrationContext,
        db: &mut dyn Conn,
        schema: &Schema,
    ) -> anyhow::Result<()> {
        let table = schema.get_table(db, &self.table)?;
        let temp_column_name = self.temp_column_name(ctx);

        let mut definition_parts = vec![
            format!("\"{}\"", temp_column_name),
            self.column.data_type.to_string(),
        ];

        if let Some(default) = &self.column.default {
            definition_parts.push("DEFAULT".to_string());
            definition_parts.push(default.to_string());
        }

        if let Some(generated) = &self.column.generated {
            definition_parts.push("GENERATED".to_string());
            definition_parts.push(generated.to_string());
        }

        // Add the new column under its temporary name (NOT NULL is enforced separately below)
        let query = format!(
            r#"
            ALTER TABLE "{table}"
            ADD COLUMN IF NOT EXISTS {definition};
            "#,
            table = self.table,
            definition = definition_parts.join(" "),
        );
        db.run(&query).context("failed to add column")?;

        let declarations: Vec<String> = table
            .columns
            .iter()
            .map(|column| {
                format!(
                    "\"{alias}\" public.{table}.{real_name}%TYPE := NEW.{real_name};",
                    table = table.real_name,
                    alias = column.name,
                    real_name = column.real_name,
                )
            })
            .collect();

        if let Some(up) = &self.up {
            if let Transformation::Simple(up) = up {
                // Add triggers to fill in values as they are inserted/updated
                let query = format!(
                    r#"
                    CREATE OR REPLACE FUNCTION {trigger_name}()
                    RETURNS TRIGGER AS $$
                    BEGIN
                        IF NOT reshape.is_new_schema() THEN
                            DECLARE
                                {declarations}
                            BEGIN
                                NEW."{temp_column_name}" = {up};
                            END;
                        END IF;
                        RETURN NEW;
                    END
                    $$ language 'plpgsql';

                    DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
                    CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{table}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
                    "#,
                    temp_column_name = temp_column_name,
                    trigger_name = self.trigger_name(ctx),
                    up = up,
                    table = self.table,
                    declarations = declarations.join("\n"),
                );
                db.run(&query).context("failed to create up trigger")?;

                // Backfill values in batches
                common::batch_touch_rows(db, &table.real_name, Some(&temp_column_name))
                    .context("failed to batch update existing rows")?;
            }

            if let Transformation::Update {
                table: from_table,
                value,
                r#where,
            } = up
            {
                let existing_schema_name = match &ctx.existing_schema_name {
                    Some(name) => name,
                    None => bail!("can't use update without previous migration"),
                };

                let from_table = schema.get_table(db, &from_table)?;

                let from_table_assignments: Vec<String> = from_table
                    .columns
                    .iter()
                    .map(|column| {
                        format!(
                            "{table}.{alias} = NEW.{real_name};",
                            table = from_table.name,
                            alias = column.name,
                            real_name = column.real_name,
                        )
                    })
                    .collect();

                // Add triggers to fill in values as they are inserted/updated
                let query = format!(
                    r#"
                    CREATE OR REPLACE FUNCTION {trigger_name}()
                    RETURNS TRIGGER AS $$
                    #variable_conflict use_variable
                    BEGIN
                        IF NOT reshape.is_new_schema() THEN
                            DECLARE
                                {from_table} migration_{existing_schema_name}.{from_table}%ROWTYPE;
                            BEGIN
                                {assignments}

                                -- Don't trigger reverse trigger when making this update
                                perform set_config('reshape.disable_triggers', 'TRUE', TRUE);

                                UPDATE public."{changed_table_real}"
                                SET "{temp_column_name}" = {value}
                                WHERE {where};

                                perform set_config('reshape.disable_triggers', '', TRUE);
                            END;
                        END IF;
                        RETURN NEW;
                    END
                    $$ language 'plpgsql';

                    DROP TRIGGER IF EXISTS "{trigger_name}" ON "{from_table_real}";
                    CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{from_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
                    "#,
                    assignments = from_table_assignments.join("\n"),
                    changed_table_real = table.real_name,
                    from_table = from_table.name,
                    from_table_real = from_table.real_name,
                    trigger_name = self.trigger_name(ctx),
                    // declarations = from_table_declarations.join("\n"),
                    temp_column_name = temp_column_name,
                );
                db.run(&query).context("failed to create up trigger")?;
                let from_table_columns = from_table
                    .columns
                    .iter()
                    .map(|column| format!("{} as {}", column.real_name, column.name))
                    .collect::<Vec<String>>()
                    .join(", ");

                let changed_table_assignments: Vec<String> = table
                    .columns
                    .iter()
                    .map(|column| {
                        format!(
                            "{table}.{alias} := NEW.{real_name};",
                            table = table.name,
                            alias = column.name,
                            real_name = column.real_name,
                        )
                    })
                    .collect();

                // Add triggers to fill in values as they are inserted/updated
                let query = format!(
                    r#"
                    CREATE OR REPLACE FUNCTION {trigger_name}()
                    RETURNS TRIGGER AS $$
                    #variable_conflict use_variable
                    BEGIN
                        IF NOT reshape.is_new_schema() AND NOT current_setting('reshape.disable_triggers', TRUE) = 'TRUE' THEN
                            DECLARE
                                {changed_table} migration_{existing_schema_name}.{changed_table}%ROWTYPE;
                                __temp_row migration_{existing_schema_name}.{from_table}%ROWTYPE;
                            BEGIN
                                {changed_table_assignments}

                                SELECT {from_table_columns}
                                INTO __temp_row
                                FROM migration_{existing_schema_name}.{from_table} {from_table}
                                WHERE {where};

                                DECLARE
                                    {from_table} migration_{existing_schema_name}.{from_table}%ROWTYPE;
                                BEGIN
                                    {from_table} = __temp_row;
                                    NEW.{temp_column_name} = {value};
                                END;
                            END;
                        END IF;
                        RETURN NEW;
                    END
                    $$ language 'plpgsql';

                    DROP TRIGGER IF EXISTS "{trigger_name}" ON "{changed_table_real}";
                    CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{changed_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
                    "#,
                    changed_table_assignments = changed_table_assignments.join("\n"),
                    changed_table_real = table.real_name,
                    changed_table = table.name,
                    from_table = from_table.name,
                    trigger_name = self.reverse_trigger_name(ctx),
                    temp_column_name = temp_column_name,
                    // declarations = declarations.join("\n"),
                );
                db.run(&query)
                    .context("failed to create reverse up trigger")?;

                // Backfill values in batches by touching the from table
                common::batch_touch_rows(db, &from_table.real_name, None)
                    .context("failed to batch update existing rows")?;
            }
        }
        // Add a temporary NOT NULL constraint if the column shouldn't be nullable.
        // This constraint is set as NOT VALID so it doesn't apply to existing rows and
        // the existing rows don't need to be scanned under an exclusive lock.
        // Thanks to this, we can set the full column as NOT NULL later with minimal locking.
        if !self.column.nullable {
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                ADD CONSTRAINT "{constraint_name}"
                CHECK ("{column}" IS NOT NULL) NOT VALID
                "#,
                table = self.table,
                constraint_name = self.not_null_constraint_name(ctx),
                column = temp_column_name,
            );
            db.run(&query)
                .context("failed to add NOT NULL constraint")?;
        }

        Ok(())
    }

    fn complete<'a>(
        &self,
        ctx: &MigrationContext,
        db: &'a mut dyn Conn,
    ) -> anyhow::Result<Option<Transaction<'a>>> {
        let mut transaction = db.transaction().context("failed to create transaction")?;

        // Remove triggers and procedures
        let query = format!(
            r#"
            DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
            DROP FUNCTION IF EXISTS "{reverse_trigger_name}" CASCADE;
            "#,
            trigger_name = self.trigger_name(ctx),
            reverse_trigger_name = self.reverse_trigger_name(ctx),
        );
        transaction
            .run(&query)
            .context("failed to drop up trigger")?;

        // Update column to be NOT NULL if necessary
        if !self.column.nullable {
            // Validate the temporary constraint (should always be valid).
            // This performs a sequential scan but does not take an exclusive lock.
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                VALIDATE CONSTRAINT "{constraint_name}"
                "#,
                table = self.table,
                constraint_name = self.not_null_constraint_name(ctx),
            );
            transaction
                .run(&query)
                .context("failed to validate NOT NULL constraint")?;

            // Update the column to be NOT NULL.
            // This requires an exclusive lock but since PG 12 it can check
            // the existing constraint for correctness, which makes the lock short-lived.
            // Source: https://dba.stackexchange.com/a/268128
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                ALTER COLUMN "{column}" SET NOT NULL
                "#,
                table = self.table,
                column = self.temp_column_name(ctx),
            );
            transaction
                .run(&query)
                .context("failed to set column as NOT NULL")?;

            // Drop the temporary constraint
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                DROP CONSTRAINT "{constraint_name}"
                "#,
                table = self.table,
                constraint_name = self.not_null_constraint_name(ctx),
            );
            transaction
                .run(&query)
                .context("failed to drop NOT NULL constraint")?;
        }

        // Rename the temporary column to its real name
        transaction
            .run(&format!(
                r#"
                ALTER TABLE "{table}"
                RENAME COLUMN "{temp_column_name}" TO "{column_name}"
                "#,
                table = self.table,
                temp_column_name = self.temp_column_name(ctx),
                column_name = self.column.name,
            ))
            .context("failed to rename column to final name")?;

        Ok(Some(transaction))
    }

    fn update_schema(&self, ctx: &MigrationContext, schema: &mut Schema) {
        schema.change_table(&self.table, |table_changes| {
            table_changes.change_column(&self.column.name, |column_changes| {
                column_changes.set_column(&self.temp_column_name(ctx));
            })
        });
    }

    fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
        // Remove column
        let query = format!(
            r#"
            ALTER TABLE "{table}"
            DROP COLUMN IF EXISTS "{column}"
            "#,
            table = self.table,
            column = self.temp_column_name(ctx),
        );
        db.run(&query).context("failed to drop column")?;

        // Remove triggers and procedures
        let query = format!(
            r#"
            DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
            DROP FUNCTION IF EXISTS "{reverse_trigger_name}" CASCADE;
            "#,
            trigger_name = self.trigger_name(ctx),
            reverse_trigger_name = self.reverse_trigger_name(ctx),
        );
        db.run(&query).context("failed to drop up trigger")?;

        Ok(())
    }
}
--------------------------------------------------------------------------------
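The four-step NOT NULL dance used by run() and complete() above, condensed into one illustrative sketch against the crate's Conn trait; table, column and constraint names are placeholders:

```rust
fn set_not_null_with_minimal_locking(db: &mut dyn Conn) -> anyhow::Result<()> {
    // 1. NOT VALID skips checking existing rows, so no long scan under an exclusive lock
    db.run(r#"ALTER TABLE "users" ADD CONSTRAINT "users_name_not_null" CHECK ("name" IS NOT NULL) NOT VALID"#)?;
    // 2. Validation scans the table, but only under a weak (SHARE UPDATE EXCLUSIVE) lock
    db.run(r#"ALTER TABLE "users" VALIDATE CONSTRAINT "users_name_not_null""#)?;
    // 3. On Postgres 12+, SET NOT NULL reuses the validated constraint, keeping its exclusive lock brief
    db.run(r#"ALTER TABLE "users" ALTER COLUMN "name" SET NOT NULL"#)?;
    // 4. The CHECK constraint is now redundant
    db.run(r#"ALTER TABLE "users" DROP CONSTRAINT "users_name_not_null""#)?;
    Ok(())
}
```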
/src/migrations/add_foreign_key.rs:
--------------------------------------------------------------------------------
use super::{common::ForeignKey, Action, MigrationContext};
use crate::{
    db::{Conn, Transaction},
    schema::Schema,
};
use anyhow::Context;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
pub struct AddForeignKey {
    pub table: String,
    foreign_key: ForeignKey,
}

#[typetag::serde(name = "add_foreign_key")]
impl Action for AddForeignKey {
    fn describe(&self) -> String {
        format!(
            "Adding foreign key from table \"{}\" to \"{}\"",
            self.table, self.foreign_key.referenced_table
        )
    }

    fn run(
        &self,
        ctx: &MigrationContext,
        db: &mut dyn Conn,
        schema: &Schema,
    ) -> anyhow::Result<()> {
        let table = schema.get_table(db, &self.table)?;
        let referenced_table = schema.get_table(db, &self.foreign_key.referenced_table)?;

        // Add quotes around all column names
        let columns: Vec<String> = table
            .real_column_names(&self.foreign_key.columns)
            .map(|col| format!("\"{}\"", col))
            .collect();
        let referenced_columns: Vec<String> = referenced_table
            .real_column_names(&self.foreign_key.referenced_columns)
            .map(|col| format!("\"{}\"", col))
            .collect();

        // Create the foreign key but set it as NOT VALID.
        // This means the foreign key will be enforced for inserts and updates,
        // but the existing data won't be checked upfront, as that would cause a long-lived lock.
        db.run(&format!(
            r#"
            ALTER TABLE "{table}"
            ADD CONSTRAINT {constraint_name}
            FOREIGN KEY ({columns})
            REFERENCES "{referenced_table}" ({referenced_columns})
            NOT VALID
            "#,
            table = table.real_name,
            constraint_name = self.temp_constraint_name(ctx),
            columns = columns.join(", "),
            referenced_table = referenced_table.real_name,
            referenced_columns = referenced_columns.join(", "),
        ))
        .context("failed to create foreign key")?;

        db.run(&format!(
            r#"
            ALTER TABLE "{table}"
            VALIDATE CONSTRAINT "{constraint_name}"
            "#,
            table = table.real_name,
            constraint_name = self.temp_constraint_name(ctx),
        ))
        .context("failed to validate foreign key")?;

        Ok(())
    }

    fn complete<'a>(
        &self,
        ctx: &MigrationContext,
        db: &'a mut dyn Conn,
    ) -> anyhow::Result<Option<Transaction<'a>>> {
        db.run(&format!(
            r#"
            ALTER TABLE {table}
            RENAME CONSTRAINT {temp_constraint_name} TO {constraint_name}
            "#,
            table = self.table,
            temp_constraint_name = self.temp_constraint_name(ctx),
            constraint_name = self.final_constraint_name(),
        ))
        .context("failed to rename temporary constraint")?;
        Ok(None)
    }

    fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}

    fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
        db.run(&format!(
            r#"
            ALTER TABLE "{table}"
            DROP CONSTRAINT IF EXISTS "{constraint_name}"
            "#,
            table = self.table,
            constraint_name = self.temp_constraint_name(ctx),
        ))
        .context("failed to drop temporary foreign key")?;

        Ok(())
    }
}

impl AddForeignKey {
    fn temp_constraint_name(&self, ctx: &MigrationContext) -> String {
        format!("{}_temp_fkey", ctx.prefix())
    }

    fn final_constraint_name(&self) -> String {
        format!(
            "{table}_{columns}_fkey",
            table = self.table,
            columns = self.foreign_key.columns.join("_")
        )
    }
}
--------------------------------------------------------------------------------
/src/migrations/add_index.rs:
--------------------------------------------------------------------------------
use super::{Action, MigrationContext};
use crate::{
    db::{Conn, Transaction},
    schema::Schema,
};
use anyhow::Context;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
pub struct AddIndex {
    pub table: String,
    pub index: Index,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Index {
    pub name: String,
    pub columns: Vec<String>,
    #[serde(default)]
    pub unique: bool,
    #[serde(rename = "type")]
    pub index_type: Option<String>,
}

#[typetag::serde(name = "add_index")]
impl Action for AddIndex {
    fn describe(&self) -> String {
        format!(
            "Adding index \"{}\" to table \"{}\"",
            self.index.name, self.table
        )
    }

    fn run(
        &self,
        _ctx: &MigrationContext,
        db: &mut dyn Conn,
        schema: &Schema,
    ) -> anyhow::Result<()> {
        let table = schema.get_table(db, &self.table)?;

        let column_real_names: Vec<String> = table
            .columns
            .iter()
            .filter(|column| self.index.columns.contains(&column.name))
            .map(|column| format!("\"{}\"", column.real_name))
            .collect();

        let unique = if self.index.unique { "UNIQUE" } else { "" };
        let index_type_def = if let Some(index_type) = &self.index.index_type {
            format!("USING {index_type}")
        } else {
            "".to_string()
        };

        db.run(&format!(
            r#"
            CREATE {unique} INDEX CONCURRENTLY "{name}" ON "{table}" {index_type_def} ({columns})
            "#,
            name = self.index.name,
            table = self.table,
            columns = column_real_names.join(", "),
        ))
        .context("failed to create index")?;
        Ok(())
    }

    fn complete<'a>(
        &self,
        _ctx: &MigrationContext,
        _db: &'a mut dyn Conn,
    ) -> anyhow::Result<Option<Transaction<'a>>> {
        Ok(None)
    }

    fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}

    fn abort(&self, _ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
        db.run(&format!(
            r#"
            DROP INDEX CONCURRENTLY IF EXISTS "{name}"
            "#,
            name = self.index.name,
        ))
        .context("failed to drop index")?;
        Ok(())
    }
}
--------------------------------------------------------------------------------
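CREATE/DROP INDEX CONCURRENTLY avoid blocking writes but cannot run inside a transaction block, which is why this action issues them directly on the connection instead of returning a transaction from complete(). An illustrative programmatic equivalent of an add_index migration (values are hypothetical):

```rust
let action = AddIndex {
    table: "users".to_string(),
    index: Index {
        name: "users_email_idx".to_string(),
        columns: vec!["email".to_string()],
        unique: true,
        index_type: Some("btree".to_string()),
    },
};
```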
/src/migrations/alter_column.rs:
--------------------------------------------------------------------------------
use super::{Action, MigrationContext};
use crate::{
    db::{Conn, Transaction},
    migrations::common,
    schema::Schema,
};
use anyhow::{anyhow, Context};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
pub struct AlterColumn {
    pub table: String,
    pub column: String,
    pub up: Option<String>,
    pub down: Option<String>,
    #[serde(default)]
    pub changes: ColumnChanges,
}

#[derive(Serialize, Deserialize, Default, Debug)]
pub struct ColumnChanges {
    pub name: Option<String>,
    #[serde(rename = "type")]
    pub data_type: Option<String>,
    pub nullable: Option<bool>,
    pub default: Option<String>,
}
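An illustrative rename-only alteration: with every other change unset, it takes the can_short_circuit() fast path below and needs no temporary column, triggers, or backfill (values are hypothetical):

```rust
let action = AlterColumn {
    table: "users".to_string(),
    column: "name".to_string(),
    up: None,
    down: None,
    changes: ColumnChanges {
        name: Some("full_name".to_string()),
        data_type: None,
        nullable: None,
        default: None,
    },
};
```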
#[typetag::serde(name = "alter_column")]
impl Action for AlterColumn {
    fn describe(&self) -> String {
        format!("Altering column \"{}\" on \"{}\"", self.column, self.table)
    }

    fn run(
        &self,
        ctx: &MigrationContext,
        db: &mut dyn Conn,
        schema: &Schema,
    ) -> anyhow::Result<()> {
        // If we are only changing the name of a column, we don't have to do anything at this stage.
        // We'll set the new schema to point to the old column. When the migration is completed,
        // we rename the actual column.
        if self.can_short_circuit() {
            return Ok(());
        }

        let table = schema.get_table(db, &self.table)?;

        let column = table
            .get_column(&self.column)
            .ok_or_else(|| anyhow!("no such column {} exists", self.column))?;

        let temporary_column_name = self.temporary_column_name(ctx);
        let temporary_column_type = self.changes.data_type.as_ref().unwrap_or(&column.data_type);

        // Add temporary, nullable column
        let mut temp_column_definition_parts: Vec<&str> =
            vec![&temporary_column_name, temporary_column_type];

        // Use either new default value or existing one if one exists
        let default_value = self
            .changes
            .default
            .as_ref()
            .or_else(|| column.default.as_ref());
        if let Some(default) = default_value {
            temp_column_definition_parts.push("DEFAULT");
            temp_column_definition_parts.push(default);
        }

        let query = format!(
            r#"
            ALTER TABLE "{table}"
            ADD COLUMN IF NOT EXISTS {temp_column_definition}
            "#,
            table = self.table,
            temp_column_definition = temp_column_definition_parts.join(" "),
        );
        db.run(&query).context("failed to add temporary column")?;

        // If up or down wasn't provided, we default to simply moving the value over.
        // This is the correct behaviour for example when only changing the default value.
        let up = self.up.as_ref().unwrap_or(&self.column);
        let down = self.down.as_ref().unwrap_or(&self.column);

        let declarations: Vec<String> = table
            .columns
            .iter()
            .filter(|column| column.name != self.column)
            .map(|column| {
                format!(
                    "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
                    table = table.real_name,
                    alias = column.name,
                    real_name = column.real_name,
                )
            })
            .collect();

        let query = format!(
            r#"
            CREATE OR REPLACE FUNCTION {up_trigger}()
            RETURNS TRIGGER AS $$
            BEGIN
                IF NOT reshape.is_new_schema() THEN
                    DECLARE
                        {declarations}
                        {existing_column} public.{table}.{existing_column_real}%TYPE := NEW.{existing_column_real};
                    BEGIN
                        NEW.{temp_column} = {up};
                    END;
                END IF;
                RETURN NEW;
            END
            $$ language 'plpgsql';

            DROP TRIGGER IF EXISTS "{up_trigger}" ON "{table}";
            CREATE TRIGGER "{up_trigger}" BEFORE INSERT OR UPDATE ON "{table}" FOR EACH ROW EXECUTE PROCEDURE {up_trigger}();

            CREATE OR REPLACE FUNCTION {down_trigger}()
            RETURNS TRIGGER AS $$
            BEGIN
                IF reshape.is_new_schema() THEN
                    DECLARE
                        {declarations}
                        {existing_column} public.{table}.{temp_column}%TYPE := NEW.{temp_column};
                    BEGIN
                        NEW.{existing_column_real} = {down};
                    END;
                END IF;
                RETURN NEW;
            END
            $$ language 'plpgsql';

            DROP TRIGGER IF EXISTS "{down_trigger}" ON "{table}";
            CREATE TRIGGER "{down_trigger}" BEFORE INSERT OR UPDATE ON "{table}" FOR EACH ROW EXECUTE PROCEDURE {down_trigger}();
            "#,
            existing_column = &self.column,
            existing_column_real = column.real_name,
            temp_column = self.temporary_column_name(ctx),
            up = up,
            down = down,
            table = self.table,
            up_trigger = self.up_trigger_name(ctx),
            down_trigger = self.down_trigger_name(ctx),
            declarations = declarations.join("\n"),
        );
        db.run(&query)
            .context("failed to create up and down triggers")?;

        // Backfill values in batches by touching the previous column
        common::batch_touch_rows(db, &table.real_name, Some(&column.real_name))
            .context("failed to batch update existing rows")?;

        // Duplicate any indices to the temporary column
        let indices = common::get_indices_for_column(db, &table.real_name, &column.real_name)?;
        for index in indices {
            let index_columns: Vec<String> = common::get_index_columns(db, &index.name)?
                .into_iter()
                .map(|idx_column| {
                    // Replace column with temporary column for new index
                    if idx_column == column.real_name {
                        temporary_column_name.to_string()
                    } else {
                        idx_column
                    }
                })
                .collect();
            let temp_index_name = self.temp_index_name(ctx, index.oid);

            let unique_def = if index.unique { "UNIQUE" } else { "" };

            db.query(&format!(
                r#"
                CREATE {unique_def} INDEX CONCURRENTLY IF NOT EXISTS "{new_index_name}" ON "{table}" USING {index_type} ({columns})
                "#,
                new_index_name = temp_index_name,
                table = table.real_name,
                columns = index_columns.join(", "),
                index_type = index.index_type,
            ))
            .context("failed to create temporary index")?;
        }

        // Add a temporary NOT NULL constraint if the column shouldn't be nullable.
        // This constraint is set as NOT VALID so it doesn't apply to existing rows and
        // the existing rows don't need to be scanned under an exclusive lock.
        // Thanks to this, we can set the full column as NOT NULL later with minimal locking.
        if !self.changes.nullable.unwrap_or(column.nullable) {
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                ADD CONSTRAINT "{constraint_name}"
                CHECK ("{column}" IS NOT NULL) NOT VALID
                "#,
                table = self.table,
                constraint_name = self.not_null_constraint_name(ctx),
                column = self.temporary_column_name(ctx),
            );
            db.run(&query)
                .context("failed to add NOT NULL constraint")?;
        }

        Ok(())
    }

    fn complete<'a>(
        &self,
        ctx: &MigrationContext,
        db: &'a mut dyn Conn,
    ) -> anyhow::Result<Option<Transaction<'a>>> {
        if self.can_short_circuit() {
            if let Some(new_name) = &self.changes.name {
                let query = format!(
                    r#"
                    ALTER TABLE "{table}"
                    RENAME COLUMN "{existing_name}" TO "{new_name}"
                    "#,
                    table = self.table,
                    existing_name = self.column,
                    new_name = new_name,
                );
                db.run(&query).context("failed to rename column")?;
            }
            return Ok(None);
        }

        // Update column to be NOT NULL if necessary
        let has_not_null_constraint = !db
            .query_with_params(
                "
                SELECT constraint_name
                FROM information_schema.constraint_column_usage
                WHERE constraint_name = $1
                ",
                &[&self.not_null_constraint_name(ctx)],
            )
            .context("failed to get any NOT NULL constraint")?
            .is_empty();
        if has_not_null_constraint {
            // Validate the temporary constraint (should always be valid).
            // This performs a sequential scan but does not take an exclusive lock.
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                VALIDATE CONSTRAINT "{constraint_name}"
                "#,
                table = self.table,
                constraint_name = self.not_null_constraint_name(ctx),
            );
            db.run(&query)
                .context("failed to validate NOT NULL constraint")?;

            // Update the column to be NOT NULL.
            // This requires an exclusive lock but since PG 12 it can check
            // the existing constraint for correctness, which makes the lock short-lived.
            // Source: https://dba.stackexchange.com/a/268128
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                ALTER COLUMN "{column}" SET NOT NULL
                "#,
                table = self.table,
                column = self.temporary_column_name(ctx),
            );
            db.run(&query).context("failed to set column as NOT NULL")?;

            // Drop the temporary constraint
            let query = format!(
                r#"
                ALTER TABLE "{table}"
                DROP CONSTRAINT "{constraint_name}"
                "#,
                table = self.table,
                constraint_name = self.not_null_constraint_name(ctx),
            );
            db.run(&query)
                .context("failed to drop NOT NULL constraint")?;
        }

        // Replace old indices with the new temporary ones created for the temporary column
        let indices = common::get_indices_for_column(db, &self.table, &self.column)?;
        for current_index in indices {
            // To keep the index handling idempotent, we need to do the following:
            // 1. Add a prefix to the existing index
            // 2. Rename temporary index to its final name
            // 3. Drop existing index concurrently

            // Add prefix (if not already added) to existing index
            let prefix = "__reshape_old";
            let target_index_name = current_index.name.trim_start_matches(prefix);
            let old_index_name = format!("{}_{}", prefix, target_index_name);
            db.query(&format!(
                r#"
                ALTER INDEX IF EXISTS "{current_name}" RENAME TO "{new_name}"
                "#,
                current_name = target_index_name,
                new_name = old_index_name,
            ))
            .context("failed to rename old index")?;

            // Rename temporary index to real name
            let temp_index_name = self.temp_index_name(ctx, current_index.oid);
            db.query(&format!(
                r#"
                ALTER INDEX IF EXISTS "{temp_index_name}" RENAME TO "{target_index_name}"
                "#,
                temp_index_name = temp_index_name,
                target_index_name = target_index_name,
            ))
            .context("failed to rename temporary index")?;

            // Drop old index concurrently
            db.query(&format!(
                r#"
                DROP INDEX CONCURRENTLY IF EXISTS "{old_index_name}"
                "#,
                old_index_name = old_index_name,
            ))
            .context("failed to drop old index")?;
        }

        // Remove old column
        let query = format!(
            r#"
            ALTER TABLE "{table}" DROP COLUMN IF EXISTS "{column}" CASCADE
            "#,
            table = self.table,
            column = self.column,
        );
        db.run(&query).context("failed to drop old column")?;

        // Rename temporary column
        let column_name = self.changes.name.as_deref().unwrap_or(&self.column);
        let query = format!(
            r#"
            ALTER TABLE "{table}" RENAME COLUMN "{temp_column}" TO "{name}"
            "#,
            table = self.table,
            temp_column = self.temporary_column_name(ctx),
            name = column_name,
        );
        db.run(&query)
            .context("failed to rename temporary column")?;

        // Remove triggers and procedures
        let query = format!(
            r#"
            DROP TRIGGER IF EXISTS "{up_trigger}" ON "{table}";
            DROP FUNCTION IF EXISTS "{up_trigger}";

            DROP TRIGGER IF EXISTS "{down_trigger}" ON "{table}";
            DROP FUNCTION IF EXISTS "{down_trigger}";
            "#,
            table = self.table,
            up_trigger = self.up_trigger_name(ctx),
            down_trigger = self.down_trigger_name(ctx),
        );
        db.run(&query)
361 | .context("failed to drop up and down triggers")?; 362 | 363 | Ok(None) 364 | } 365 | 366 | fn update_schema(&self, ctx: &MigrationContext, schema: &mut Schema) { 367 | // If we are only changing the name of a column, we haven't created a temporary column 368 | // Instead, we rename the schema column but point it to the old column 369 | if self.can_short_circuit() { 370 | if let Some(new_name) = &self.changes.name { 371 | schema.change_table(&self.table, |table_changes| { 372 | table_changes.change_column(&self.column, |column_changes| { 373 | column_changes.set_name(new_name); 374 | }); 375 | }); 376 | } 377 | 378 | return; 379 | } 380 | 381 | schema.change_table(&self.table, |table_changes| { 382 | table_changes.change_column(&self.column, |column_changes| { 383 | column_changes.set_column(&self.temporary_column_name(ctx)); 384 | }); 385 | }); 386 | } 387 | 388 | fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> { 389 | // Safely remove any indices created for the temporary column 390 | let temp_column_name = self.temporary_column_name(ctx); 391 | let indices = common::get_indices_for_column(db, &self.table, &temp_column_name)?; 392 | for index in indices { 393 | let temp_index_name = self.temp_index_name(ctx, index.oid); 394 | db.query(&format!( 395 | r#" 396 | DROP INDEX CONCURRENTLY IF EXISTS "{index_name}" 397 | "#, 398 | index_name = temp_index_name, 399 | ))?; 400 | } 401 | 402 | // Drop temporary column 403 | let query = format!( 404 | r#" 405 | ALTER TABLE "{table}" 406 | DROP COLUMN IF EXISTS "{temp_column}"; 407 | "#, 408 | table = self.table, 409 | temp_column = self.temporary_column_name(ctx), 410 | ); 411 | db.run(&query).context("failed to drop temporary column")?; 412 | 413 | // Remove triggers and procedures 414 | let query = format!( 415 | r#" 416 | DROP TRIGGER IF EXISTS "{up_trigger}" ON "{table}"; 417 | DROP FUNCTION IF EXISTS "{up_trigger}"; 418 | 419 | DROP TRIGGER IF EXISTS "{down_trigger}" ON "{table}"; 420 | DROP FUNCTION IF EXISTS "{down_trigger}"; 421 | "#, 422 | table = self.table, 423 | up_trigger = self.up_trigger_name(ctx), 424 | down_trigger = self.down_trigger_name(ctx), 425 | ); 426 | db.run(&query) 427 | .context("failed to drop up and down triggers")?; 428 | 429 | Ok(()) 430 | } 431 | } 432 | 433 | impl AlterColumn { 434 | fn temporary_column_name(&self, ctx: &MigrationContext) -> String { 435 | format!("{}_new_{}", ctx.prefix(), self.column) 436 | } 437 | 438 | fn up_trigger_name(&self, ctx: &MigrationContext) -> String { 439 | format!("{}_alter_column_up_trigger", ctx.prefix()) 440 | } 441 | 442 | fn down_trigger_name(&self, ctx: &MigrationContext) -> String { 443 | format!("{}_alter_column_down_trigger", ctx.prefix_inverse()) 444 | } 445 | 446 | fn not_null_constraint_name(&self, ctx: &MigrationContext) -> String { 447 | format!("{}_alter_column_temporary", ctx.prefix()) 448 | } 449 | 450 | fn temp_index_name(&self, ctx: &MigrationContext, index_oid: u32) -> String { 451 | format!("{}_alter_column_temp_index_{}", ctx.prefix(), index_oid) 452 | } 453 | 454 | fn can_short_circuit(&self) -> bool { 455 | self.changes.name.is_some() 456 | && self.changes.data_type.is_none() 457 | && self.changes.nullable.is_none() 458 | && self.changes.default.is_none() 459 | } 460 | } 461 | -------------------------------------------------------------------------------- /src/migrations/common.rs: -------------------------------------------------------------------------------- 1 | use anyhow::anyhow; 2 | use postgres::types::{FromSql, 
3 | use serde::{Deserialize, Serialize};
4 | 
5 | use crate::db::Conn;
6 | 
7 | #[derive(Serialize, Deserialize, Clone, Debug)]
8 | pub struct Column {
9 |     pub name: String,
10 |     #[serde(rename = "type")]
11 |     pub data_type: String,
12 |     #[serde(default = "nullable_default")]
13 |     pub nullable: bool,
14 |     pub default: Option<String>,
15 |     pub generated: Option<String>,
16 | }
17 | 
18 | fn nullable_default() -> bool {
19 |     true
20 | }
21 | 
22 | #[derive(Serialize, Deserialize, Clone, Debug)]
23 | pub struct ForeignKey {
24 |     pub columns: Vec<String>,
25 |     pub referenced_table: String,
26 |     pub referenced_columns: Vec<String>,
27 | }
28 | 
29 | #[derive(Debug)]
30 | struct PostgresRawValue {
31 |     bytes: Vec<u8>,
32 | }
33 | 
34 | impl<'a> FromSql<'a> for PostgresRawValue {
35 |     fn from_sql(
36 |         _ty: &postgres::types::Type,
37 |         raw: &'a [u8],
38 |     ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
39 |         Ok(PostgresRawValue {
40 |             bytes: raw.to_vec(),
41 |         })
42 |     }
43 | 
44 |     fn accepts(_ty: &postgres::types::Type) -> bool {
45 |         true
46 |     }
47 | }
48 | 
49 | impl ToSql for PostgresRawValue {
50 |     fn to_sql(
51 |         &self,
52 |         _ty: &postgres::types::Type,
53 |         out: &mut postgres::types::private::BytesMut,
54 |     ) -> Result<postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
55 |     where
56 |         Self: Sized,
57 |     {
58 |         out.extend_from_slice(&self.bytes);
59 |         Ok(postgres::types::IsNull::No)
60 |     }
61 | 
62 |     fn accepts(_ty: &postgres::types::Type) -> bool
63 |     where
64 |         Self: Sized,
65 |     {
66 |         true
67 |     }
68 | 
69 |     postgres::types::to_sql_checked!();
70 | }
71 | 
72 | pub fn batch_touch_rows(
73 |     db: &mut dyn Conn,
74 |     table: &str,
75 |     column: Option<&str>,
76 | ) -> anyhow::Result<()> {
77 |     const BATCH_SIZE: u16 = 1000;
78 | 
79 |     let mut cursor: Option<PostgresRawValue> = None;
80 | 
81 |     loop {
82 |         let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
83 | 
84 |         let primary_key = get_primary_key_columns_for_table(db, table)?;
85 | 
86 |         // If no column to touch is passed, we default to the first primary key column (the "update" is a no-op either way)
87 |         let touched_column = match column {
88 |             Some(column) => column,
89 |             None => primary_key.first().unwrap(),
90 |         };
91 | 
92 |         let primary_key_columns = primary_key.join(", ");
93 | 
94 |         let primary_key_where = primary_key
95 |             .iter()
96 |             .map(|column| {
97 |                 format!(
98 |                     r#"
99 |                     "{table}"."{column}" = rows."{column}"
100 |                     "#,
101 |                     table = table,
102 |                     column = column,
103 |                 )
104 |             })
105 |             .collect::<Vec<String>>()
106 |             .join(" AND ");
107 | 
108 |         let returning_columns = primary_key
109 |             .iter()
110 |             .map(|column| format!("rows.\"{}\"", column))
111 |             .collect::<Vec<String>>()
112 |             .join(", ");
113 | 
114 |         let cursor_where = if let Some(cursor) = &cursor {
115 |             params.push(cursor);
116 | 
117 |             format!(
118 |                 "WHERE ({primary_key_columns}) > $1",
119 |                 primary_key_columns = primary_key_columns
120 |             )
121 |         } else {
122 |             "".to_string()
123 |         };
124 | 
125 |         let query = format!(
126 |             r#"
127 |             WITH rows AS (
128 |                 SELECT {primary_key_columns}
129 |                 FROM public."{table}"
130 |                 {cursor_where}
131 |                 ORDER BY {primary_key_columns}
132 |                 LIMIT {batch_size}
133 |             ), update AS (
134 |                 UPDATE public."{table}" "{table}"
135 |                 SET "{touched_column}" = "{table}"."{touched_column}"
136 |                 FROM rows
137 |                 WHERE {primary_key_where}
138 |                 RETURNING {returning_columns}
139 |             )
140 |             SELECT LAST_VALUE(({primary_key_columns})) OVER () AS last_value
141 |             FROM update
142 |             LIMIT 1
143 |             "#,
144 |             batch_size = BATCH_SIZE,
145 |         );
146 |         let last_value = db
147 |             .query_with_params(&query, &params)?
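    |         // Each iteration selects the next batch of primary keys past the cursor,
    |         // runs a no-op UPDATE on just those rows (enough to fire any insert/update
    |         // triggers) and returns the last key as the new cursor. This is keyset
    |         // pagination: no long-running transaction or table-wide lock is needed.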
148 |             .first()
149 |             .and_then(|row| row.get("last_value"));
150 | 
151 |         if last_value.is_none() {
152 |             break;
153 |         }
154 | 
155 |         cursor = last_value
156 |     }
157 | 
158 |     Ok(())
159 | }
160 | 
161 | fn get_primary_key_columns_for_table(
162 |     db: &mut dyn Conn,
163 |     table: &str,
164 | ) -> anyhow::Result<Vec<String>> {
165 |     // Query from https://wiki.postgresql.org/wiki/Retrieve_primary_key_columns
166 |     let primary_key_columns: Vec<String> = db
167 |         .query(&format!(
168 |             "
169 |             SELECT a.attname AS column_name
170 |             FROM pg_index i
171 |             JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
172 |             WHERE i.indrelid = '{table}'::regclass
173 |             AND i.indisprimary;
174 |             ",
175 |             table = table
176 |         ))?
177 |         .iter()
178 |         .map(|row| row.get("column_name"))
179 |         .collect();
180 | 
181 |     Ok(primary_key_columns)
182 | }
183 | 
184 | pub struct Index {
185 |     pub name: String,
186 |     pub oid: u32,
187 |     pub unique: bool,
188 |     pub index_type: String,
189 | }
190 | 
191 | pub fn get_indices_for_column(
192 |     db: &mut dyn Conn,
193 |     table: &str,
194 |     column: &str,
195 | ) -> anyhow::Result<Vec<Index>> {
196 |     let indices = db
197 |         .query(&format!(
198 |             "
199 |             SELECT
200 |                 i.relname AS name,
201 |                 i.oid AS oid,
202 |                 ix.indisunique AS unique,
203 |                 am.amname AS type
204 |             FROM pg_index ix
205 |             JOIN pg_class t ON t.oid = ix.indrelid
206 |             JOIN pg_class i ON i.oid = ix.indexrelid
207 |             JOIN pg_am am ON i.relam = am.oid
208 |             JOIN pg_attribute a ON
209 |                 a.attrelid = t.oid AND
210 |                 a.attnum = ANY(ix.indkey)
211 |             WHERE
212 |                 t.relname = '{table}' AND
213 |                 a.attname = '{column}'
214 |             ",
215 |             table = table,
216 |             column = column,
217 |         ))?
218 |         .iter()
219 |         .map(|row| Index {
220 |             name: row.get("name"),
221 |             oid: row.get("oid"),
222 |             unique: row.get("unique"),
223 |             index_type: row.get("type"),
224 |         })
225 |         .collect();
226 | 
227 |     Ok(indices)
228 | }
229 | 
230 | pub fn get_index_columns(db: &mut dyn Conn, index_name: &str) -> anyhow::Result<Vec<String>> {
231 |     // Get all columns which are part of the index in order
232 |     let (table_oid, column_nums) = db
233 |         .query(&format!(
234 |             "
235 |             SELECT t.oid AS table_oid, ix.indkey::INTEGER[] AS columns
236 |             FROM pg_index ix
237 |             JOIN pg_class t ON t.oid = ix.indrelid
238 |             JOIN pg_class i ON i.oid = ix.indexrelid
239 |             WHERE
240 |                 i.relname = '{index_name}'
241 |             ",
242 |             index_name = index_name,
243 |         ))?
244 |         .first()
245 |         .map(|row| {
246 |             (
247 |                 row.get::<'_, _, u32>("table_oid"),
248 |                 row.get::<'_, _, Vec<i32>>("columns"),
249 |             )
250 |         })
251 |         .ok_or_else(|| anyhow!("failed to get columns for index"))?;
252 | 
253 |     // Get the name of each of the columns, still in order
254 |     column_nums
255 |         .iter()
256 |         .map(|column_num| -> anyhow::Result<String> {
257 |             let name: String = db
258 |                 .query(&format!(
259 |                     "
260 |                     SELECT attname AS name
261 |                     FROM pg_attribute
262 |                     WHERE attrelid = {table_oid}
263 |                     AND attnum = {column_num};
264 |                     ",
265 |                     table_oid = table_oid,
266 |                     column_num = column_num,
267 |                 ))?
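    |             // indkey holds attribute numbers rather than names; this lookup
    |             // resolves each attnum (e.g. table_oid 16384, attnum 2) to its
    |             // column name through pg_attribute.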
268 |                 .first()
269 |                 .map(|row| row.get("name"))
270 |                 .ok_or_else(|| anyhow!("expected to find column"))?;
271 | 
272 |             Ok(name)
273 |         })
274 |         .collect::<anyhow::Result<Vec<String>>>()
275 | }
276 | 
--------------------------------------------------------------------------------
/src/migrations/create_enum.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::Context;
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct CreateEnum {
11 |     pub name: String,
12 |     pub values: Vec<String>,
13 | }
14 | 
15 | #[typetag::serde(name = "create_enum")]
16 | impl Action for CreateEnum {
17 |     fn describe(&self) -> String {
18 |         format!("Creating enum \"{}\"", self.name)
19 |     }
20 | 
21 |     fn run(
22 |         &self,
23 |         _ctx: &MigrationContext,
24 |         db: &mut dyn Conn,
25 |         _schema: &Schema,
26 |     ) -> anyhow::Result<()> {
27 |         // Check if the enum already exists. CREATE TYPE doesn't have
28 |         // an IF NOT EXISTS option, so we have to check manually.
29 |         let enum_exists = !db
30 |             .query(&format!(
31 |                 "
32 |                 SELECT typname
33 |                 FROM pg_catalog.pg_type
34 |                 WHERE typcategory = 'E'
35 |                 AND typname = '{name}'
36 |                 ",
37 |                 name = self.name,
38 |             ))?
39 |             .is_empty();
40 |         if enum_exists {
41 |             return Ok(());
42 |         }
43 | 
44 |         let values_def: Vec<String> = self
45 |             .values
46 |             .iter()
47 |             .map(|value| format!("'{}'", value))
48 |             .collect();
49 | 
50 |         db.run(&format!(
51 |             r#"
52 |             CREATE TYPE "{name}" AS ENUM ({values})
53 |             "#,
54 |             name = self.name,
55 |             values = values_def.join(", "),
56 |         ))
57 |         .context("failed to create enum")?;
58 | 
59 |         Ok(())
60 |     }
61 | 
62 |     fn complete<'a>(
63 |         &self,
64 |         _ctx: &MigrationContext,
65 |         _db: &'a mut dyn Conn,
66 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
67 |         Ok(None)
68 |     }
69 | 
70 |     fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
71 | 
72 |     fn abort(&self, _ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
73 |         db.run(&format!(
74 |             r#"
75 |             DROP TYPE IF EXISTS {name}
76 |             "#,
77 |             name = self.name,
78 |         ))
79 |         .context("failed to drop enum")?;
80 | 
81 |         Ok(())
82 |     }
83 | }
--------------------------------------------------------------------------------
/src/migrations/create_table.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | 
3 | use super::{common::ForeignKey, Action, Column, MigrationContext};
4 | use crate::{
5 |     db::{Conn, Transaction},
6 |     migrations::common,
7 |     schema::Schema,
8 | };
9 | use anyhow::Context;
10 | use serde::{Deserialize, Serialize};
11 | 
12 | #[derive(Serialize, Deserialize, Debug)]
13 | pub struct CreateTable {
14 |     pub name: String,
15 |     pub columns: Vec<Column>,
16 |     pub primary_key: Vec<String>,
17 | 
18 |     #[serde(default)]
19 |     pub foreign_keys: Vec<ForeignKey>,
20 | 
21 |     pub up: Option<Transformation>,
22 | }
23 | 
24 | #[derive(Serialize, Deserialize, Debug)]
25 | pub struct Transformation {
26 |     table: String,
27 |     values: HashMap<String, String>,
28 |     upsert_constraint: Option<String>,
29 | }
30 | 
31 | impl CreateTable {
32 |     fn trigger_name(&self, ctx: &MigrationContext) -> String {
33 |         format!("{}_create_table_{}", ctx.prefix(), self.name)
34 |     }
35 | }
36 | 
37 | #[typetag::serde(name = "create_table")]
38 | impl Action for CreateTable {
39 |     fn describe(&self) -> String {
40 |         format!("Creating table \"{}\"", self.name)
41 |     }
42 | 
43 |     fn run(
44 |         &self,
45 |         ctx: &MigrationContext,
46 |         db: &mut dyn Conn,
47 |         schema: &Schema,
48 |     ) -> anyhow::Result<()> {
49 |         let mut definition_rows: Vec<String> = self
50 |             .columns
51 |             .iter()
52 |             .map(|column| {
53 |                 let mut parts = vec![format!("\"{}\"", column.name), column.data_type.to_string()];
54 | 
55 |                 if let Some(default) = &column.default {
56 |                     parts.push("DEFAULT".to_string());
57 |                     parts.push(default.to_string());
58 |                 }
59 | 
60 |                 if !column.nullable {
61 |                     parts.push("NOT NULL".to_string());
62 |                 }
63 | 
64 |                 if let Some(generated) = &column.generated {
65 |                     parts.push("GENERATED".to_string());
66 |                     parts.push(generated.to_string());
67 |                 }
68 | 
69 |                 parts.join(" ")
70 |             })
71 |             .collect();
72 | 
73 |         let primary_key_columns = self
74 |             .primary_key
75 |             .iter()
76 |             // Add quotes around all column names
77 |             .map(|col| format!("\"{}\"", col))
78 |             .collect::<Vec<String>>()
79 |             .join(", ");
80 |         definition_rows.push(format!("PRIMARY KEY ({})", primary_key_columns));
81 | 
82 |         for foreign_key in &self.foreign_keys {
83 |             // Add quotes around all column names
84 |             let columns: Vec<String> = foreign_key
85 |                 .columns
86 |                 .iter()
87 |                 .map(|col| format!("\"{}\"", col))
88 |                 .collect();
89 | 
90 |             let referenced_table = schema.get_table(db, &foreign_key.referenced_table)?;
91 |             let referenced_columns: Vec<String> = referenced_table
92 |                 .real_column_names(&foreign_key.referenced_columns)
93 |                 .map(|col| format!("\"{}\"", col))
94 |                 .collect();
95 | 
96 |             definition_rows.push(format!(
97 |                 r#"
98 |                 FOREIGN KEY ({columns}) REFERENCES "{table}" ({referenced_columns})
99 |                 "#,
100 |                 columns = columns.join(", "),
101 |                 table = referenced_table.real_name,
102 |                 referenced_columns = referenced_columns.join(", "),
103 |             ));
104 |         }
105 | 
106 |         let query = &format!(
107 |             r#"
108 |             CREATE TABLE "{name}" (
109 |                 {definition}
110 |             )
111 |             "#,
112 |             name = self.name,
113 |             definition = definition_rows.join(",\n"),
114 |         );
115 |         db.run(query).context("failed to create table")?;
116 | 
117 |         if let Some(Transformation {
118 |             table: from_table,
119 |             values,
120 |             upsert_constraint,
121 |         }) = &self.up
122 |         {
123 |             let from_table = schema.get_table(db, &from_table)?;
124 | 
125 |             let declarations: Vec<String> = from_table
126 |                 .columns
127 |                 .iter()
128 |                 .map(|column| {
129 |                     format!(
130 |                         "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
131 |                         table = from_table.real_name,
132 |                         alias = column.name,
133 |                         real_name = column.real_name,
134 |                     )
135 |                 })
136 |                 .collect();
137 | 
138 |             let (insert_columns, insert_values): (Vec<&str>, Vec<&str>) = values
139 |                 .iter()
140 |                 .map(|(k, v)| -> (&str, &str) { (k, v) }) // Force &String to &str
141 |                 .unzip();
142 | 
143 |             let update_set: Vec<String> = values
144 |                 .iter()
145 |                 .map(|(field, value)| format!("\"{field}\" = {value}"))
146 |                 .collect();
147 | 
148 |             // Constraint to check for conflicts. Defaults to the primary key constraint.
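    |             // For a table named "users" this resolves to the auto-generated
    |             // primary key constraint "users_pkey"; a migration can set
    |             // upsert_constraint to target a different unique constraint instead.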
149 |             let conflict_constraint_name = match upsert_constraint {
150 |                 Some(custom_constraint) => custom_constraint.clone(),
151 |                 _ => format!("{table}_pkey", table = self.name),
152 |             };
153 | 
154 |             // Add triggers to fill in values as they are inserted/updated
155 |             let query = format!(
156 |                 r#"
157 |                 CREATE OR REPLACE FUNCTION {trigger_name}()
158 |                 RETURNS TRIGGER AS $$
159 |                 #variable_conflict use_variable
160 |                 BEGIN
161 |                     IF NOT reshape.is_new_schema() THEN
162 |                         DECLARE
163 |                             {declarations}
164 |                         BEGIN
165 |                             INSERT INTO public."{changed_table_real}" ({columns})
166 |                             VALUES ({values})
167 |                             ON CONFLICT ON CONSTRAINT "{conflict_constraint_name}"
168 |                             DO UPDATE SET
169 |                                 {updates};
170 |                         END;
171 |                     END IF;
172 |                     RETURN NEW;
173 |                 END
174 |                 $$ language 'plpgsql';
175 | 
176 |                 DROP TRIGGER IF EXISTS "{trigger_name}" ON "{from_table_real}";
177 |                 CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{from_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
178 |                 "#,
179 |                 changed_table_real = self.name,
180 |                 from_table_real = from_table.real_name,
181 |                 trigger_name = self.trigger_name(ctx),
182 |                 declarations = declarations.join("\n"),
183 |                 columns = insert_columns.join(", "),
184 |                 values = insert_values.join(", "),
185 |                 updates = update_set.join(",\n"),
186 |             );
187 |             db.run(&query).context("failed to create up trigger")?;
188 | 
189 |             // Backfill values in batches by touching the from table
190 |             common::batch_touch_rows(db, &from_table.real_name, None)
191 |                 .context("failed to batch update existing rows")?;
192 |         }
193 | 
194 |         Ok(())
195 |     }
196 | 
197 |     fn complete<'a>(
198 |         &self,
199 |         ctx: &MigrationContext,
200 |         db: &'a mut dyn Conn,
201 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
202 |         // Remove triggers and procedures
203 |         let query = format!(
204 |             r#"
205 |             DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
206 |             "#,
207 |             trigger_name = self.trigger_name(ctx),
208 |         );
209 |         db.run(&query).context("failed to drop up trigger")?;
210 | 
211 |         Ok(None)
212 |     }
213 | 
214 |     fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
215 | 
216 |     fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
217 |         // Remove triggers and procedures
218 |         let query = format!(
219 |             r#"
220 |             DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
221 |             "#,
222 |             trigger_name = self.trigger_name(ctx),
223 |         );
224 |         db.run(&query).context("failed to drop up trigger")?;
225 | 
226 |         db.run(&format!(
227 |             r#"
228 |             DROP TABLE IF EXISTS "{name}"
229 |             "#,
230 |             name = self.name,
231 |         ))
232 |         .context("failed to drop table")?;
233 | 
234 |         Ok(())
235 |     }
236 | }
--------------------------------------------------------------------------------
/src/migrations/custom.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use serde::{Deserialize, Serialize};
7 | 
8 | #[derive(Serialize, Deserialize, Debug)]
9 | pub struct Custom {
10 |     #[serde(default)]
11 |     pub start: Option<String>,
12 | 
13 |     #[serde(default)]
14 |     pub complete: Option<String>,
15 | 
16 |     #[serde(default)]
17 |     pub abort: Option<String>,
18 | }
19 | 
20 | #[typetag::serde(name = "custom")]
21 | impl Action for Custom {
22 |     fn describe(&self) -> String {
23 |         "Running custom migration".to_string()
24 |     }
25 | 
26 |     fn run(
27 |         &self,
28 |         _ctx: &MigrationContext,
29 |         db: &mut dyn Conn,
30 |         _schema: &Schema,
31 |     ) -> anyhow::Result<()> {
32 |         if let Some(start_query) = &self.start {
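    |             // Custom actions run whatever SQL they are given; there is no
    |             // automatic rollback beyond the optional abort query further down.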
33 |             println!("Running query: {}", start_query);
34 |             db.run(start_query)?;
35 |         }
36 | 
37 |         Ok(())
38 |     }
39 | 
40 |     fn complete<'a>(
41 |         &self,
42 |         _ctx: &MigrationContext,
43 |         db: &'a mut dyn Conn,
44 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
45 |         if let Some(complete_query) = &self.complete {
46 |             db.run(complete_query)?;
47 |         }
48 | 
49 |         Ok(None)
50 |     }
51 | 
52 |     fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
53 | 
54 |     fn abort(&self, _ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
55 |         if let Some(abort_query) = &self.abort {
56 |             db.run(abort_query)?;
57 |         }
58 | 
59 |         Ok(())
60 |     }
61 | }
--------------------------------------------------------------------------------
/src/migrations/mod.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     db::{Conn, Transaction},
3 |     schema::Schema,
4 | };
5 | use core::fmt::Debug;
6 | use serde::{Deserialize, Serialize};
7 | 
8 | // Re-export migration types
9 | mod common;
10 | pub use common::Column;
11 | 
12 | mod create_table;
13 | pub use create_table::CreateTable;
14 | 
15 | mod alter_column;
16 | pub use alter_column::{AlterColumn, ColumnChanges};
17 | 
18 | mod add_column;
19 | pub use add_column::AddColumn;
20 | 
21 | mod remove_column;
22 | pub use remove_column::RemoveColumn;
23 | 
24 | mod add_index;
25 | pub use add_index::{AddIndex, Index};
26 | 
27 | mod remove_index;
28 | pub use remove_index::RemoveIndex;
29 | 
30 | mod remove_table;
31 | pub use remove_table::RemoveTable;
32 | 
33 | mod rename_table;
34 | pub use rename_table::RenameTable;
35 | 
36 | mod create_enum;
37 | pub use create_enum::CreateEnum;
38 | 
39 | mod remove_enum;
40 | pub use remove_enum::RemoveEnum;
41 | 
42 | mod custom;
43 | pub use custom::Custom;
44 | 
45 | mod add_foreign_key;
46 | pub use add_foreign_key::AddForeignKey;
47 | 
48 | mod remove_foreign_key;
49 | pub use remove_foreign_key::RemoveForeignKey;
50 | 
51 | #[derive(Serialize, Deserialize, Debug)]
52 | pub struct Migration {
53 |     pub name: String,
54 |     pub description: Option<String>,
55 |     pub actions: Vec<Box<dyn Action>>,
56 | }
57 | 
58 | impl Migration {
59 |     pub fn new(name: impl Into<String>, description: Option<String>) -> Migration {
60 |         Migration {
61 |             name: name.into(),
62 |             description,
63 |             actions: vec![],
64 |         }
65 |     }
66 | 
67 |     pub fn with_action(mut self, action: impl Action + 'static) -> Self {
68 |         self.actions.push(Box::new(action));
69 |         self
70 |     }
71 | }
72 | 
73 | impl PartialEq for Migration {
74 |     fn eq(&self, other: &Self) -> bool {
75 |         self.name == other.name
76 |     }
77 | }
78 | 
79 | impl Eq for Migration {}
80 | 
81 | impl Clone for Migration {
82 |     fn clone(&self) -> Self {
83 |         let serialized = serde_json::to_string(self).unwrap();
84 |         serde_json::from_str(&serialized).unwrap()
85 |     }
86 | }
87 | 
88 | pub struct MigrationContext {
89 |     migration_index: usize,
90 |     action_index: usize,
91 |     existing_schema_name: Option<String>,
92 | }
93 | 
94 | impl MigrationContext {
95 |     pub fn new(
96 |         migration_index: usize,
97 |         action_index: usize,
98 |         existing_schema_name: Option<String>,
99 |     ) -> Self {
100 |         MigrationContext {
101 |             migration_index,
102 |             action_index,
103 |             existing_schema_name,
104 |         }
105 |     }
106 | 
107 |     fn prefix(&self) -> String {
108 |         format!(
109 |             "__reshape_{:0>4}_{:0>4}",
110 |             self.migration_index, self.action_index
111 |         )
112 |     }
113 | 
114 |     fn prefix_inverse(&self) -> String {
115 |         format!(
116 |             "__reshape_{:0>4}_{:0>4}",
117 |             1000 - self.migration_index,
118 |             1000 - self.action_index
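    |             // For migration index 2 and action index 3 this yields
    |             // "__reshape_0998_0997", whereas prefix() yields "__reshape_0002_0003".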
119 |         )
120 |     }
121 | }
122 | 
123 | #[typetag::serde(tag = "type")]
124 | pub trait Action: Debug {
125 |     fn describe(&self) -> String;
126 |     fn run(&self, ctx: &MigrationContext, db: &mut dyn Conn, schema: &Schema)
127 |         -> anyhow::Result<()>;
128 |     fn complete<'a>(
129 |         &self,
130 |         ctx: &MigrationContext,
131 |         db: &'a mut dyn Conn,
132 |     ) -> anyhow::Result<Option<Transaction<'a>>>;
133 |     fn update_schema(&self, ctx: &MigrationContext, schema: &mut Schema);
134 |     fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()>;
135 | }
--------------------------------------------------------------------------------
/src/migrations/remove_column.rs:
--------------------------------------------------------------------------------
1 | use super::{common, Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::{anyhow, bail, Context};
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct RemoveColumn {
11 |     pub table: String,
12 |     pub column: String,
13 |     pub down: Option<Transformation>,
14 | }
15 | 
16 | #[derive(Serialize, Deserialize, Debug)]
17 | #[serde(untagged)]
18 | pub enum Transformation {
19 |     Simple(String),
20 |     Update {
21 |         table: String,
22 |         value: String,
23 |         r#where: String,
24 |     },
25 | }
26 | 
27 | impl RemoveColumn {
28 |     fn trigger_name(&self, ctx: &MigrationContext) -> String {
29 |         format!(
30 |             "{}_remove_column_{}_{}",
31 |             ctx.prefix(),
32 |             self.table,
33 |             self.column
34 |         )
35 |     }
36 | 
37 |     fn reverse_trigger_name(&self, ctx: &MigrationContext) -> String {
38 |         format!(
39 |             "{}_remove_column_{}_{}_rev",
40 |             ctx.prefix(),
41 |             self.table,
42 |             self.column
43 |         )
44 |     }
45 | 
46 |     fn not_null_constraint_trigger_name(&self, ctx: &MigrationContext) -> String {
47 |         format!(
48 |             "{}_remove_column_{}_{}_nn",
49 |             ctx.prefix(),
50 |             self.table,
51 |             self.column
52 |         )
53 |     }
54 | 
55 |     fn not_null_constraint_name(&self, ctx: &MigrationContext) -> String {
56 |         format!(
57 |             "{}_add_column_not_null_{}_{}",
58 |             ctx.prefix(),
59 |             self.table,
60 |             self.column
61 |         )
62 |     }
63 | }
64 | 
65 | #[typetag::serde(name = "remove_column")]
66 | impl Action for RemoveColumn {
67 |     fn describe(&self) -> String {
68 |         format!(
69 |             "Removing column \"{}\" from \"{}\"",
70 |             self.column, self.table
71 |         )
72 |     }
73 | 
74 |     fn run(
75 |         &self,
76 |         ctx: &MigrationContext,
77 |         db: &mut dyn Conn,
78 |         schema: &Schema,
79 |     ) -> anyhow::Result<()> {
80 |         let table = schema.get_table(db, &self.table)?;
81 |         let column = table
82 |             .get_column(&self.column)
83 |             .ok_or_else(|| anyhow!("no such column {} exists", self.column))?;
84 | 
85 |         // Add down trigger
86 |         if let Some(down) = &self.down {
87 |             let declarations: Vec<String> = table
88 |                 .columns
89 |                 .iter()
90 |                 .map(|column| {
91 |                     format!(
92 |                         "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
93 |                         table = table.real_name,
94 |                         alias = column.name,
95 |                         real_name = column.real_name,
96 |                     )
97 |                 })
98 |                 .collect();
99 | 
100 |             if let Transformation::Simple(down) = down {
101 |                 let query = format!(
102 |                     r#"
103 |                     CREATE OR REPLACE FUNCTION {trigger_name}()
104 |                     RETURNS TRIGGER AS $$
105 |                     BEGIN
106 |                         IF reshape.is_new_schema() THEN
107 |                             DECLARE
108 |                                 {declarations}
109 |                             BEGIN
110 |                                 NEW.{column_name} = {down};
111 |                             END;
112 |                         END IF;
113 |                         RETURN NEW;
114 |                     END
115 |                     $$ language 'plpgsql';
116 | 
117 |                     DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
118 |                     CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{table}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
119 |                     "#,
120 |                     column_name = self.column,
121 |                     trigger_name = self.trigger_name(ctx),
122 |                     down = down,
123 |                     table = self.table,
124 |                     declarations = declarations.join("\n"),
125 |                 );
126 |                 db.run(&query).context("failed to create down trigger")?;
127 |             }
128 | 
129 |             if let Transformation::Update {
130 |                 table: from_table,
131 |                 value,
132 |                 r#where,
133 |             } = down
134 |             {
135 |                 let existing_schema_name = match &ctx.existing_schema_name {
136 |                     Some(name) => name,
137 |                     None => bail!("can't use update without previous migration"),
138 |                 };
139 | 
140 |                 let from_table = schema.get_table(db, &from_table)?;
141 | 
142 |                 let maybe_null_check = if !column.nullable {
143 |                     // Replace the NOT NULL constraint with a constraint trigger that only fires on the old schema.
144 |                     // We also add a null check to the down function on the new schema below, to cover both cases.
145 |                     // As we are using a complex down function, we must remove the NOT NULL check for the new schema.
146 |                     // NOT NULL is not checked at the end of a transaction, but immediately upon update.
147 |                     let query = format!(
148 |                         r#"
149 |                         CREATE OR REPLACE FUNCTION {trigger_name}()
150 |                         RETURNS TRIGGER AS $$
151 |                         BEGIN
152 |                             IF NOT reshape.is_new_schema() THEN
153 |                                 IF NEW.{column} IS NULL THEN
154 |                                     RAISE EXCEPTION '{column} can not be null';
155 |                                 END IF;
156 |                             END IF;
157 |                             RETURN NEW;
158 |                         END
159 |                         $$ language 'plpgsql';
160 | 
161 |                         DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
162 | 
163 |                         CREATE CONSTRAINT TRIGGER "{trigger_name}"
164 |                             AFTER INSERT OR UPDATE
165 |                             ON "{table}"
166 |                             FOR EACH ROW
167 |                             EXECUTE PROCEDURE {trigger_name}();
168 |                         "#,
169 |                         table = self.table,
170 |                         trigger_name = self.not_null_constraint_trigger_name(ctx),
171 |                         column = self.column,
172 |                     );
173 |                     db.run(&query)
174 |                         .context("failed to create null constraint trigger")?;
175 | 
176 |                     db.run(&format!(
177 |                         r#"
178 |                         ALTER TABLE {table}
179 |                         ALTER COLUMN {column}
180 |                         DROP NOT NULL
181 |                         "#,
182 |                         table = self.table,
183 |                         column = self.column
184 |                     ))
185 |                     .context("failed to remove column not null constraint")?;
186 | 
187 |                     format!(
188 |                         r#"
189 |                         IF {value} IS NULL THEN
190 |                             RAISE EXCEPTION '{column_name} can not be null';
191 |                         END IF;
192 |                         "#,
193 |                         column_name = self.column,
194 |                     )
195 |                 } else {
196 |                     "".to_string()
197 |                 };
198 | 
199 |                 let into_variables = from_table
200 |                     .columns
201 |                     .iter()
202 |                     .map(|column| {
203 |                         format!(
204 |                             "NEW.{real_name} AS {alias}",
205 |                             alias = column.name,
206 |                             real_name = column.real_name,
207 |                         )
208 |                     })
209 |                     .collect::<Vec<String>>()
210 |                     .join(", ");
211 | 
212 |                 let query = format!(
213 |                     r#"
214 |                     CREATE OR REPLACE FUNCTION {trigger_name}()
215 |                     RETURNS TRIGGER AS $$
216 |                     #variable_conflict use_variable
217 |                     BEGIN
218 |                         IF reshape.is_new_schema() THEN
219 |                             DECLARE
220 |                                 {from_table} record;
221 |                             BEGIN
222 |                                 SELECT {into_variables}
223 |                                 INTO {from_table};
224 | 
225 |                                 {maybe_null_check}
226 | 
227 |                                 -- Don't trigger reverse trigger when making this update
228 |                                 perform set_config('reshape.disable_triggers', 'TRUE', TRUE);
229 | 
230 |                                 UPDATE "migration_{existing_schema_name}"."{changed_table}" "{changed_table}"
231 |                                 SET "{column_name}" = {value}
232 |                                 WHERE {where};
233 | 
234 |                                 perform set_config('reshape.disable_triggers', '', TRUE);
235 |                             END;
236 |                         END IF;
237 |                         RETURN NEW;
238 |                     END
239 |                     $$ language 'plpgsql';
240 | 
241 |                     DROP TRIGGER IF EXISTS "{trigger_name}" ON "{from_table_real}";
242 |                     CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{from_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
243 |                     "#,
244 |                     changed_table = self.table,
245 |                     from_table = from_table.name,
246 |                     from_table_real = from_table.real_name,
247 |                     column_name = self.column,
248 |                     trigger_name = self.trigger_name(ctx),
249 |                 );
250 |                 db.run(&query).context("failed to create down trigger")?;
251 | 
252 |                 let changed_into_variables = table
253 |                     .columns
254 |                     .iter()
255 |                     .map(|column| {
256 |                         format!(
257 |                             "NEW.{real_name} AS {alias}",
258 |                             alias = column.name,
259 |                             real_name = column.real_name,
260 |                         )
261 |                     })
262 |                     .collect::<Vec<String>>()
263 |                     .join(", ");
264 | 
265 |                 let from_table_columns = from_table
266 |                     .columns
267 |                     .iter()
268 |                     .map(|column| format!("{} as {}", column.real_name, column.name))
269 |                     .collect::<Vec<String>>()
270 |                     .join(", ");
271 | 
272 |                 let query = format!(
273 |                     r#"
274 |                     CREATE OR REPLACE FUNCTION {trigger_name}()
275 |                     RETURNS TRIGGER AS $$
276 |                     #variable_conflict use_variable
277 |                     BEGIN
278 |                         IF reshape.is_new_schema() AND NOT current_setting('reshape.disable_triggers', TRUE) = 'TRUE' THEN
279 |                             DECLARE
280 |                                 {changed_table} record;
281 |                                 __temp_row record;
282 |                             BEGIN
283 |                                 SELECT {changed_into_variables}
284 |                                 INTO {changed_table};
285 | 
286 |                                 SELECT *
287 |                                 INTO __temp_row
288 |                                 FROM (
289 |                                     SELECT {from_table_columns}
290 |                                     FROM public.{from_table_real}
291 |                                 ) {from_table}
292 |                                 WHERE {where};
293 | 
294 |                                 DECLARE
295 |                                     {from_table} record;
296 |                                 BEGIN
297 |                                     {from_table} := __temp_row;
298 |                                     NEW.{column_name_real} = {value};
299 |                                 END;
300 |                             END;
301 |                         END IF;
302 |                         RETURN NEW;
303 |                     END
304 |                     $$ language 'plpgsql';
305 | 
306 |                     DROP TRIGGER IF EXISTS "{trigger_name}" ON "{changed_table_real}";
307 |                     CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{changed_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
308 |                     "#,
309 |                     changed_table = table.name,
310 |                     changed_table_real = table.real_name,
311 |                     from_table = from_table.name,
312 |                     from_table_real = from_table.real_name,
313 |                     column_name_real = column.real_name,
314 |                     trigger_name = self.reverse_trigger_name(ctx),
315 |                     // declarations = declarations.join("\n"),
316 |                 );
317 |                 db.run(&query)
318 |                     .context("failed to create reverse down trigger")?;
319 |             }
320 |         }
321 | 
322 |         Ok(())
323 |     }
324 | 
325 |     fn complete<'a>(
326 |         &self,
327 |         ctx: &MigrationContext,
328 |         db: &'a mut dyn Conn,
329 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
330 |         let indices = common::get_indices_for_column(db, &self.table, &self.column)
331 |             .context("failed getting column indices")?;
332 | 
333 |         for index in indices {
334 |             db.run(&format!(
335 |                 "
336 |                 DROP INDEX CONCURRENTLY IF EXISTS {name}
337 |                 ",
338 |                 name = index.name,
339 |             ))
340 |             .context("failed to drop index")?;
341 |         }
342 | 
343 |         // Remove column, function and trigger
344 |         let query = format!(
345 |             r#"
346 |             ALTER TABLE "{table}"
347 |             DROP COLUMN IF EXISTS "{column}";
348 | 
349 |             DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
350 |             DROP FUNCTION IF EXISTS "{reverse_trigger_name}" CASCADE;
351 |             DROP FUNCTION IF EXISTS "{null_trigger_name}" CASCADE;
352 |             "#,
353 |             table = self.table,
354 |             column = self.column,
355 |             trigger_name = self.trigger_name(ctx),
356 |             reverse_trigger_name = self.reverse_trigger_name(ctx),
357 |             null_trigger_name = self.not_null_constraint_trigger_name(ctx),
358 |         );
359 |         db.run(&query)
360 |             .context("failed to drop column and down trigger")?;
361 | 
362 |         Ok(None)
363 |     }
364 | 
365 |     fn update_schema(&self, _ctx: &MigrationContext, schema: &mut Schema) {
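    |         // The column is only flagged as removed in the tracked schema here;
    |         // the physical column stays in place until complete() drops it.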
366 |         schema.change_table(&self.table, |table_changes| {
367 |             table_changes.change_column(&self.column, |column_changes| {
368 |                 column_changes.set_removed();
369 |             })
370 |         });
371 |     }
372 | 
373 |     fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
374 |         // We might have temporarily removed the NOT NULL check and have to reinstate it
375 |         let has_not_null_function = !db
376 |             .query_with_params(
377 |                 "
378 |                 SELECT routine_name
379 |                 FROM information_schema.routines
380 |                 WHERE routine_schema = 'public'
381 |                 AND routine_name = $1
382 |                 ",
383 |                 &[&self.not_null_constraint_trigger_name(ctx)],
384 |             )
385 |             .context("failed to get any NOT NULL function")?
386 |             .is_empty();
387 | 
388 |         if has_not_null_function {
389 |             // Make the column NOT NULL again without taking any long-lived locks, using a temporary constraint
390 |             let query = format!(
391 |                 r#"
392 |                 ALTER TABLE "{table}"
393 |                 ADD CONSTRAINT "{constraint_name}"
394 |                 CHECK ("{column}" IS NOT NULL) NOT VALID
395 |                 "#,
396 |                 table = self.table,
397 |                 constraint_name = self.not_null_constraint_name(ctx),
398 |                 column = self.column,
399 |             );
400 |             db.run(&query)
401 |                 .context("failed to add NOT NULL constraint")?;
402 | 
403 |             let query = format!(
404 |                 r#"
405 |                 ALTER TABLE "{table}"
406 |                 VALIDATE CONSTRAINT "{constraint_name}"
407 |                 "#,
408 |                 table = self.table,
409 |                 constraint_name = self.not_null_constraint_name(ctx),
410 |             );
411 |             db.run(&query)
412 |                 .context("failed to validate NOT NULL constraint")?;
413 | 
414 |             // This ALTER TABLE call doesn't need a long-lived exclusive lock as it can use the validated constraint from above
415 |             db.run(&format!(
416 |                 r#"
417 |                 ALTER TABLE {table}
418 |                 ALTER COLUMN {column}
419 |                 SET NOT NULL
420 |                 "#,
421 |                 table = self.table,
422 |                 column = self.column
423 |             ))
424 |             .context("failed to reinstate column NOT NULL")?;
425 | 
426 |             // Drop the temporary constraint
427 |             let query = format!(
428 |                 r#"
429 |                 ALTER TABLE "{table}"
430 |                 DROP CONSTRAINT "{constraint_name}"
431 |                 "#,
432 |                 table = self.table,
433 |                 constraint_name = self.not_null_constraint_name(ctx),
434 |             );
435 |             db.run(&query)
436 |                 .context("failed to drop NOT NULL constraint")?;
437 |         }
438 | 
439 |         // Remove function and trigger
440 |         db.run(&format!(
441 |             r#"
442 |             DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
443 |             DROP FUNCTION IF EXISTS "{reverse_trigger_name}" CASCADE;
444 |             DROP FUNCTION IF EXISTS "{null_trigger_name}" CASCADE;
445 |             "#,
446 |             trigger_name = self.trigger_name(ctx),
447 |             reverse_trigger_name = self.reverse_trigger_name(ctx),
448 |             null_trigger_name = self.not_null_constraint_trigger_name(ctx),
449 |         ))
450 |         .context("failed to drop down trigger")?;
451 | 
452 |         Ok(())
453 |     }
454 | }
--------------------------------------------------------------------------------
/src/migrations/remove_enum.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::Context;
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct RemoveEnum {
11 |     #[serde(rename = "enum")]
12 |     pub enum_name: String,
13 | }
14 | 
15 | #[typetag::serde(name = "remove_enum")]
16 | impl Action for RemoveEnum {
17 |     fn describe(&self) -> String {
18 |         format!("Removing enum \"{}\"", self.enum_name)
19 |     }
20 | 
21 |     fn run(
22 |         &self,
23 |         _ctx: &MigrationContext,
24 |         _db: &mut dyn Conn,
25 |         _schema: &Schema,
26 |     ) -> anyhow::Result<()> {
27 |         Ok(())
28 |     }
29 | 
30 |     fn complete<'a>(
31 |         &self,
32 |         _ctx: &MigrationContext,
33 |         db: &'a mut dyn Conn,
34 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
35 |         db.run(&format!(
36 |             r#"
37 |             DROP TYPE IF EXISTS {name}
38 |             "#,
39 |             name = self.enum_name,
40 |         ))
41 |         .context("failed to drop enum")?;
42 | 
43 |         Ok(None)
44 |     }
45 | 
46 |     fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
47 | 
48 |     fn abort(&self, _ctx: &MigrationContext, _db: &mut dyn Conn) -> anyhow::Result<()> {
49 |         Ok(())
50 |     }
51 | }
--------------------------------------------------------------------------------
/src/migrations/remove_foreign_key.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::{anyhow, Context};
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct RemoveForeignKey {
11 |     table: String,
12 |     foreign_key: String,
13 | }
14 | 
15 | #[typetag::serde(name = "remove_foreign_key")]
16 | impl Action for RemoveForeignKey {
17 |     fn describe(&self) -> String {
18 |         format!(
19 |             "Removing foreign key \"{}\" from table \"{}\"",
20 |             self.foreign_key, self.table
21 |         )
22 |     }
23 | 
24 |     fn run(
25 |         &self,
26 |         _ctx: &MigrationContext,
27 |         db: &mut dyn Conn,
28 |         schema: &Schema,
29 |     ) -> anyhow::Result<()> {
30 |         // The foreign key is only removed once the migration is completed.
31 |         // Removing it earlier would be hard/undesirable for several reasons:
32 |         // - Postgres doesn't have an easy way to temporarily disable a foreign key check.
33 |         //   If it did, we could disable the FK for the new schema.
34 |         // - Even if we could, it probably wouldn't be a good idea as it would cause temporary
35 |         //   inconsistencies for the old schema which still expects the FK to hold.
36 |         // - For the same reason, we can't remove the FK when the migration is first applied.
37 |         //   If the migration was to be aborted, then the FK would have to be recreated with
38 |         //   the risk that it would no longer be valid.
39 | 
40 |         // Ensure foreign key exists
41 |         let table = schema.get_table(db, &self.table)?;
42 |         let fk_exists = !db
43 |             .query(&format!(
44 |                 r#"
45 |                 SELECT constraint_name
46 |                 FROM information_schema.table_constraints
47 |                 WHERE
48 |                     constraint_type = 'FOREIGN KEY' AND
49 |                     table_name = '{table_name}' AND
50 |                     constraint_name = '{foreign_key}'
51 |                 "#,
52 |                 table_name = table.real_name,
53 |                 foreign_key = self.foreign_key,
54 |             ))
55 |             .context("failed to check for foreign key")?
56 |             .is_empty();
57 | 
58 |         if !fk_exists {
59 |             return Err(anyhow!(
60 |                 "no foreign key \"{}\" exists on table \"{}\"",
61 |                 self.foreign_key,
62 |                 self.table
63 |             ));
64 |         }
65 | 
66 |         Ok(())
67 |     }
68 | 
69 |     fn complete<'a>(
70 |         &self,
71 |         _ctx: &MigrationContext,
72 |         db: &'a mut dyn Conn,
73 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
74 |         db.run(&format!(
75 |             r#"
76 |             ALTER TABLE {table}
77 |             DROP CONSTRAINT IF EXISTS {foreign_key}
78 |             "#,
79 |             table = self.table,
80 |             foreign_key = self.foreign_key,
81 |         ))
82 |         .context("failed to remove foreign key")?;
83 |         Ok(None)
84 |     }
85 | 
86 |     fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
87 | 
88 |     fn abort(&self, _ctx: &MigrationContext, _db: &mut dyn Conn) -> anyhow::Result<()> {
89 |         Ok(())
90 |     }
91 | }
--------------------------------------------------------------------------------
/src/migrations/remove_index.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::Context;
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct RemoveIndex {
11 |     pub index: String,
12 | }
13 | 
14 | #[typetag::serde(name = "remove_index")]
15 | impl Action for RemoveIndex {
16 |     fn describe(&self) -> String {
17 |         format!("Removing index \"{}\"", self.index)
18 |     }
19 | 
20 |     fn run(
21 |         &self,
22 |         _ctx: &MigrationContext,
23 |         _db: &mut dyn Conn,
24 |         _schema: &Schema,
25 |     ) -> anyhow::Result<()> {
26 |         // Do nothing, the index isn't removed until completion
27 |         Ok(())
28 |     }
29 | 
30 |     fn complete<'a>(
31 |         &self,
32 |         _ctx: &MigrationContext,
33 |         db: &'a mut dyn Conn,
34 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
35 |         db.run(&format!(
36 |             r#"
37 |             DROP INDEX CONCURRENTLY IF EXISTS "{name}"
38 |             "#,
39 |             name = self.index
40 |         ))
41 |         .context("failed to drop index")?;
42 | 
43 |         Ok(None)
44 |     }
45 | 
46 |     fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
47 | 
48 |     fn abort(&self, _ctx: &MigrationContext, _db: &mut dyn Conn) -> anyhow::Result<()> {
49 |         Ok(())
50 |     }
51 | }
--------------------------------------------------------------------------------
/src/migrations/remove_table.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::Context;
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct RemoveTable {
11 |     pub table: String,
12 | }
13 | 
14 | #[typetag::serde(name = "remove_table")]
15 | impl Action for RemoveTable {
16 |     fn describe(&self) -> String {
17 |         format!("Removing table \"{}\"", self.table)
18 |     }
19 | 
20 |     fn run(
21 |         &self,
22 |         _ctx: &MigrationContext,
23 |         _db: &mut dyn Conn,
24 |         _schema: &Schema,
25 |     ) -> anyhow::Result<()> {
26 |         Ok(())
27 |     }
28 | 
29 |     fn complete<'a>(
30 |         &self,
31 |         _ctx: &MigrationContext,
32 |         db: &'a mut dyn Conn,
33 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
34 |         // Remove table
35 |         let query = format!(
36 |             r#"
37 |             DROP TABLE IF EXISTS "{table}";
38 |             "#,
39 |             table = self.table,
40 |         );
41 |         db.run(&query).context("failed to drop table")?;
42 | 
43 |         Ok(None)
44 |     }
45 | 
46 |     fn update_schema(&self, _ctx: &MigrationContext, schema: &mut Schema) {
47 |         schema.change_table(&self.table, |table_changes| {
48 |             table_changes.set_removed();
49 |         });
50 |     }
51 | 
52 |     fn abort(&self, _ctx: &MigrationContext, _db: &mut dyn Conn) -> anyhow::Result<()> {
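    |         // Nothing to undo here: the table is only dropped in complete()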
53 |         Ok(())
54 |     }
55 | }
--------------------------------------------------------------------------------
/src/migrations/rename_table.rs:
--------------------------------------------------------------------------------
1 | use super::{Action, MigrationContext};
2 | use crate::{
3 |     db::{Conn, Transaction},
4 |     schema::Schema,
5 | };
6 | use anyhow::Context;
7 | use serde::{Deserialize, Serialize};
8 | 
9 | #[derive(Serialize, Deserialize, Debug)]
10 | pub struct RenameTable {
11 |     pub table: String,
12 |     pub new_name: String,
13 | }
14 | 
15 | #[typetag::serde(name = "rename_table")]
16 | impl Action for RenameTable {
17 |     fn describe(&self) -> String {
18 |         format!("Renaming table \"{}\" to \"{}\"", self.table, self.new_name)
19 |     }
20 | 
21 |     fn run(
22 |         &self,
23 |         _ctx: &MigrationContext,
24 |         _db: &mut dyn Conn,
25 |         _schema: &Schema,
26 |     ) -> anyhow::Result<()> {
27 |         Ok(())
28 |     }
29 | 
30 |     fn complete<'a>(
31 |         &self,
32 |         _ctx: &MigrationContext,
33 |         db: &'a mut dyn Conn,
34 |     ) -> anyhow::Result<Option<Transaction<'a>>> {
35 |         // Rename table
36 |         let query = format!(
37 |             r#"
38 |             ALTER TABLE IF EXISTS "{table}"
39 |             RENAME TO "{new_name}"
40 |             "#,
41 |             table = self.table,
42 |             new_name = self.new_name,
43 |         );
44 |         db.run(&query).context("failed to rename table")?;
45 | 
46 |         Ok(None)
47 |     }
48 | 
49 |     fn update_schema(&self, _ctx: &MigrationContext, schema: &mut Schema) {
50 |         schema.change_table(&self.table, |table_changes| {
51 |             table_changes.set_name(&self.new_name);
52 |         });
53 |     }
54 | 
55 |     fn abort(&self, _ctx: &MigrationContext, _db: &mut dyn Conn) -> anyhow::Result<()> {
56 |         Ok(())
57 |     }
58 | }
--------------------------------------------------------------------------------
/src/schema.rs:
--------------------------------------------------------------------------------
1 | use crate::db::Conn;
2 | use std::collections::{HashMap, HashSet};
3 | 
4 | // Schema tracks changes made to tables and columns during a migration.
5 | // These changes are not applied until the migration is completed but
6 | // need to be taken into consideration when creating views for a migration
7 | // and when a user references a table or column in a migration.
8 | //
9 | // The changes to a table are tracked by a `TableChanges` struct. The possible
10 | // changes are:
11 | // - Changing the name, which updates `current_name`.
12 | // - Removing, which sets the `removed` flag.
13 | //
14 | // Changes to a column are tracked by a `ColumnChanges` struct which resides in
15 | // the corresponding `TableChanges`. The possible changes are:
16 | // - Changing the name, which updates `current_name`.
17 | // - Changing the backing column, which will add the new column to the end of
18 | //   `backing_columns`. This is used when temporary columns are
19 | //   introduced which will eventually replace the current column.
20 | // - Removing, which sets the `removed` flag.
21 | //
22 | // Schema provides some schema introspection methods, `get_tables` and `get_table`,
23 | // which will retrieve the current schema from the database and apply the changes.
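    | //
    | // For example, renaming column "name" to "full_name" sets `current_name` to
    | // "full_name" while the last entry of `backing_columns` is still "name", so
    | // introspection reports the new name while queries keep hitting the real
    | // column underneath.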
24 | #[derive(Debug)]
25 | pub struct Schema {
26 |     table_changes: Vec<TableChanges>,
27 | }
28 | 
29 | impl Schema {
30 |     pub fn new() -> Schema {
31 |         Schema {
32 |             table_changes: Vec::new(),
33 |         }
34 |     }
35 | 
36 |     pub fn change_table<F>(&mut self, current_name: &str, f: F)
37 |     where
38 |         F: FnOnce(&mut TableChanges),
39 |     {
40 |         let table_change_index = self
41 |             .table_changes
42 |             .iter()
43 |             .position(|table| table.current_name == current_name)
44 |             .unwrap_or_else(|| {
45 |                 let new_changes = TableChanges::new(current_name.to_string());
46 |                 self.table_changes.push(new_changes);
47 |                 self.table_changes.len() - 1
48 |             });
49 | 
50 |         let table_changes = &mut self.table_changes[table_change_index];
51 |         f(table_changes)
52 |     }
53 | }
54 | 
55 | impl Default for Schema {
56 |     fn default() -> Self {
57 |         Self::new()
58 |     }
59 | }
60 | 
61 | #[derive(Debug)]
62 | pub struct TableChanges {
63 |     current_name: String,
64 |     real_name: String,
65 |     column_changes: Vec<ColumnChanges>,
66 |     removed: bool,
67 | }
68 | 
69 | impl TableChanges {
70 |     fn new(name: String) -> Self {
71 |         Self {
72 |             current_name: name.to_string(),
73 |             real_name: name,
74 |             column_changes: Vec::new(),
75 |             removed: false,
76 |         }
77 |     }
78 | 
79 |     pub fn set_name(&mut self, name: &str) {
80 |         self.current_name = name.to_string();
81 |     }
82 | 
83 |     pub fn change_column<F>(&mut self, current_name: &str, f: F)
84 |     where
85 |         F: FnOnce(&mut ColumnChanges),
86 |     {
87 |         let column_change_index = self
88 |             .column_changes
89 |             .iter()
90 |             .position(|column| column.current_name == current_name)
91 |             .unwrap_or_else(|| {
92 |                 let new_changes = ColumnChanges::new(current_name.to_string());
93 |                 self.column_changes.push(new_changes);
94 |                 self.column_changes.len() - 1
95 |             });
96 | 
97 |         let column_changes = &mut self.column_changes[column_change_index];
98 |         f(column_changes)
99 |     }
100 | 
101 |     pub fn set_removed(&mut self) {
102 |         self.removed = true;
103 |     }
104 | }
105 | 
106 | #[derive(Debug)]
107 | pub struct ColumnChanges {
108 |     current_name: String,
109 |     backing_columns: Vec<String>,
110 |     removed: bool,
111 | }
112 | 
113 | impl ColumnChanges {
114 |     fn new(name: String) -> Self {
115 |         Self {
116 |             current_name: name.to_string(),
117 |             backing_columns: vec![name],
118 |             removed: false,
119 |         }
120 |     }
121 | 
122 |     pub fn set_name(&mut self, name: &str) {
123 |         self.current_name = name.to_string();
124 |     }
125 | 
126 |     pub fn set_column(&mut self, column_name: &str) {
127 |         self.backing_columns.push(column_name.to_string())
128 |     }
129 | 
130 |     pub fn set_removed(&mut self) {
131 |         self.removed = true;
132 |     }
133 | 
134 |     fn real_name(&self) -> &str {
135 |         self.backing_columns
136 |             .last()
137 |             .expect("backing_columns should never be empty")
138 |     }
139 | }
140 | 
141 | #[derive(Debug)]
142 | pub struct Table {
143 |     pub name: String,
144 |     pub real_name: String,
145 |     pub columns: Vec<Column>,
146 | }
147 | 
148 | #[derive(Debug)]
149 | pub struct Column {
150 |     pub name: String,
151 |     pub real_name: String,
152 |     pub data_type: String,
153 |     pub nullable: bool,
154 |     pub default: Option<String>,
155 | }
156 | 
157 | impl Schema {
158 |     pub fn get_tables(&self, db: &mut dyn Conn) -> anyhow::Result<Vec<Table>> {
159 |         db.query(
160 |             "
161 |             SELECT table_name
162 |             FROM information_schema.tables
163 |             WHERE table_schema = 'public'
164 |             ",
165 |         )?
166 |         .iter()
167 |         .map(|row| row.get::<'_, _, String>("table_name"))
168 |         .filter_map(|real_name| {
169 |             let table_changes = self
170 |                 .table_changes
171 |                 .iter()
172 |                 .find(|changes| changes.real_name == real_name);
173 | 
174 |             // Skip table if it has been removed
175 |             if let Some(changes) = table_changes {
176 |                 if changes.removed {
177 |                     return None;
178 |                 }
179 |             }
180 | 
181 |             Some(self.get_table_by_real_name(db, &real_name))
182 |         })
183 |         .collect()
184 |     }
185 | 
186 |     pub fn get_table(&self, db: &mut dyn Conn, table_name: &str) -> anyhow::Result<Table> {
187 |         let table_changes = self
188 |             .table_changes
189 |             .iter()
190 |             .find(|changes| changes.current_name == table_name);
191 | 
192 |         let real_table_name = table_changes
193 |             .map(|changes| changes.real_name.to_string())
194 |             .unwrap_or_else(|| table_name.to_string());
195 | 
196 |         self.get_table_by_real_name(db, &real_table_name)
197 |     }
198 | 
199 |     fn get_table_by_real_name(
200 |         &self,
201 |         db: &mut dyn Conn,
202 |         real_table_name: &str,
203 |     ) -> anyhow::Result<Table> {
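    |         // Introspect the real table, then overlay the tracked changes: replaced
    |         // or removed backing columns are hidden via ignore_columns, and renamed
    |         // columns are surfaced under their current alias.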
204 |         let table_changes = self
205 |             .table_changes
206 |             .iter()
207 |             .find(|changes| changes.real_name == real_table_name);
208 | 
209 |         let real_columns: Vec<(String, String, bool, Option<String>)> = db
210 |             .query(&format!(
211 |                 "
212 |                 SELECT column_name, CASE WHEN data_type = 'USER-DEFINED' THEN udt_name ELSE data_type END AS data_type, is_nullable, column_default
213 |                 FROM information_schema.columns
214 |                 WHERE table_name = '{table}' AND table_schema = 'public'
215 |                 ORDER BY ordinal_position
216 |                 ",
217 |                 table = real_table_name,
218 |             ))?
219 |             .iter()
220 |             .map(|row| {
221 |                 (
222 |                     row.get("column_name"),
223 |                     row.get("data_type"),
224 |                     row.get::<'_, _, String>("is_nullable") == "YES",
225 |                     row.get("column_default"),
226 |                 )
227 |             })
228 |             .collect();
229 | 
230 |         let mut ignore_columns: HashSet<String> = HashSet::new();
231 |         let mut aliases: HashMap<String, &String> = HashMap::new();
232 | 
233 |         if let Some(changes) = table_changes {
234 |             for column_changes in &changes.column_changes {
235 |                 if column_changes.removed {
236 |                     ignore_columns.insert(column_changes.real_name().to_string());
237 |                 } else {
238 |                     aliases.insert(
239 |                         column_changes.real_name().to_string(),
240 |                         &column_changes.current_name,
241 |                     );
242 |                 }
243 | 
244 |                 let (_, rest) = column_changes
245 |                     .backing_columns
246 |                     .split_last()
247 |                     .expect("backing_columns should never be empty");
248 | 
249 |                 for column in rest {
250 |                     ignore_columns.insert(column.to_string());
251 |                 }
252 |             }
253 |         }
254 | 
255 |         let mut columns: Vec<Column> = Vec::new();
256 | 
257 |         for (real_name, data_type, nullable, default) in real_columns {
258 |             if ignore_columns.contains(&*real_name) {
259 |                 continue;
260 |             }
261 | 
262 |             let name = aliases
263 |                 .get(&real_name)
264 |                 .map(|alias| alias.to_string())
265 |                 .unwrap_or_else(|| real_name.to_string());
266 | 
267 |             columns.push(Column {
268 |                 name,
269 |                 real_name,
270 |                 data_type,
271 |                 nullable,
272 |                 default,
273 |             });
274 |         }
275 | 
276 |         let current_table_name = table_changes
277 |             .map(|changes| changes.current_name.as_ref())
278 |             .unwrap_or_else(|| real_table_name);
279 | 
280 |         let table = Table {
281 |             name: current_table_name.to_string(),
282 |             real_name: real_table_name.to_string(),
283 |             columns,
284 |         };
285 | 
286 |         Ok(table)
287 |     }
288 | }
289 | 
290 | impl Table {
291 |     pub fn real_column_names<'a>(
292 |         &'a self,
293 |         columns: &'a [String],
294 |     ) -> impl Iterator<Item = &'a String> {
295 |         columns.iter().map(|name| {
296 |             self.get_column(name)
297 |                 .map(|col| &col.real_name)
298 |                 .unwrap_or(name)
299 |         })
300 |     }
301 | 
302 |     pub fn get_column(&self, name: &str) -> Option<&Column> {
303 |         self.columns.iter().find(|column| column.name == name)
304 |     }
305 | }
--------------------------------------------------------------------------------
/src/state.rs:
--------------------------------------------------------------------------------
1 | use crate::{db::Conn, migrations::Migration};
2 | use anyhow::anyhow;
3 | 
4 | use serde::{Deserialize, Serialize};
5 | use version::version;
6 | 
7 | #[derive(Serialize, Deserialize, Clone, Debug)]
8 | #[serde(tag = "state")]
9 | pub enum State {
10 |     #[serde(rename = "idle")]
11 |     Idle,
12 | 
13 |     #[serde(rename = "applying")]
14 |     Applying { migrations: Vec<Migration> },
15 | 
16 |     #[serde(rename = "in_progress")]
17 |     InProgress { migrations: Vec<Migration> },
18 | 
19 |     #[serde(rename = "completing")]
20 |     Completing {
21 |         migrations: Vec<Migration>,
22 |         current_migration_index: usize,
23 |         current_action_index: usize,
24 |     },
25 | 
26 |     #[serde(rename = "aborting")]
27 |     Aborting {
28 |         migrations: Vec<Migration>,
29 |         last_migration_index: usize,
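    |         // Together with last_migration_index above, this marks where an
    |         // interrupted abort should resume unwinding from.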
30 |         last_action_index: usize,
31 |     },
32 | }
33 | 
34 | impl State {
35 |     pub fn load(db: &mut impl Conn) -> anyhow::Result<State> {
36 |         Self::ensure_schema_and_table(db)?;
37 | 
38 |         let results = db.query("SELECT value FROM reshape.data WHERE key = 'state'")?;
39 | 
40 |         let state = match results.first() {
41 |             Some(row) => {
42 |                 let json: serde_json::Value = row.get(0);
43 |                 serde_json::from_value(json)?
44 |             }
45 |             None => Default::default(),
46 |         };
47 |         Ok(state)
48 |     }
49 | 
50 |     pub fn save(&self, db: &mut impl Conn) -> anyhow::Result<()> {
51 |         Self::ensure_schema_and_table(db)?;
52 | 
53 |         let json = serde_json::to_value(self)?;
54 |         db.query_with_params(
55 |             "INSERT INTO reshape.data (key, value) VALUES ('state', $1) ON CONFLICT (key) DO UPDATE SET value = $1",
56 |             &[&json]
57 |         )?;
58 |         Ok(())
59 |     }
60 | 
61 |     pub fn clear(&mut self, db: &mut impl Conn) -> anyhow::Result<()> {
62 |         db.run("DROP SCHEMA reshape CASCADE")?;
63 | 
64 |         *self = Self::default();
65 | 
66 |         Ok(())
67 |     }
68 | 
69 |     // Complete will change the state from Completing to Idle
70 |     pub fn complete(&mut self, db: &mut impl Conn) -> anyhow::Result<()> {
71 |         let current_state = std::mem::replace(self, Self::Idle);
72 | 
73 |         match current_state {
74 |             Self::Completing { migrations, .. } => {
75 |                 // Add migrations and update state in a transaction to ensure atomicity
76 |                 let mut transaction = db.transaction()?;
77 |                 save_migrations(&mut transaction, &migrations)?;
78 |                 self.save(&mut transaction)?;
79 |                 transaction.commit()?;
80 |             }
81 |             _ => {
82 |                 // Move old state back
83 |                 *self = current_state;
84 | 
85 |                 return Err(anyhow!(
86 |                     "couldn't update state to be completed, not in Completing state"
87 |                 ));
88 |             }
89 |         }
90 | 
91 |         Ok(())
92 |     }
93 | 
94 |     pub fn applying(&mut self, new_migrations: Vec<Migration>) {
95 |         *self = Self::Applying {
96 |             migrations: new_migrations,
97 |         };
98 |     }
99 | 
100 |     pub fn in_progress(&mut self, new_migrations: Vec<Migration>) {
101 |         *self = Self::InProgress {
102 |             migrations: new_migrations,
103 |         };
104 |     }
105 | 
106 |     pub fn completing(
107 |         &mut self,
108 |         migrations: Vec<Migration>,
109 |         current_migration_index: usize,
110 |         current_action_index: usize,
111 |     ) {
112 |         *self = Self::Completing {
113 |             migrations,
114 |             current_migration_index,
115 |             current_action_index,
116 |         }
117 |     }
118 | 
119 |     pub fn aborting(
120 |         &mut self,
121 |         migrations: Vec<Migration>,
122 |         last_migration_index: usize,
123 |         last_action_index: usize,
124 |     ) {
125 |         *self = Self::Aborting {
126 |             migrations,
127 |             last_migration_index,
128 |             last_action_index,
129 |         }
130 |     }
131 | 
132 |     fn ensure_schema_and_table(db: &mut impl Conn) -> anyhow::Result<()> {
133 |         db.run("CREATE SCHEMA IF NOT EXISTS reshape")?;
134 | 
135 |         // Create data table which will be a key-value table containing
136 |         // the version and current state.
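    |         // It ends up holding rows like ('state', '{"state": "idle"}') and
    |         // ('version', '"<crate version>"').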
137 | db.run("CREATE TABLE IF NOT EXISTS reshape.data (key TEXT PRIMARY KEY, value JSONB)")?; 138 | 139 | // Create migrations table which will store all completed migrations 140 | db.run( 141 | " 142 | CREATE TABLE IF NOT EXISTS reshape.migrations ( 143 | index INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY, 144 | name TEXT NOT NULL, 145 | description TEXT, 146 | actions JSONB NOT NULL, 147 | completed_at TIMESTAMP DEFAULT NOW() 148 | ) 149 | ", 150 | )?; 151 | 152 | // Update the current version 153 | let encoded_version = serde_json::to_value(version!().to_string())?; 154 | db.query_with_params( 155 | " 156 | INSERT INTO reshape.data (key, value) 157 | VALUES ('version', $1) 158 | ON CONFLICT (key) DO UPDATE SET value = $1 159 | ", 160 | &[&encoded_version], 161 | )?; 162 | 163 | Ok(()) 164 | } 165 | } 166 | 167 | impl Default for State { 168 | fn default() -> Self { 169 | Self::Idle 170 | } 171 | } 172 | 173 | pub fn current_migration(db: &mut dyn Conn) -> anyhow::Result> { 174 | let name: Option = db 175 | .query( 176 | " 177 | SELECT name 178 | FROM reshape.migrations 179 | ORDER BY index DESC 180 | LIMIT 1 181 | ", 182 | )? 183 | .first() 184 | .map(|row| row.get("name")); 185 | Ok(name) 186 | } 187 | 188 | pub fn remaining_migrations( 189 | db: &mut impl Conn, 190 | new_migrations: impl IntoIterator, 191 | ) -> anyhow::Result> { 192 | let mut new_iter = new_migrations.into_iter(); 193 | 194 | // Ensure the new migrations match up with the existing ones 195 | let mut highest_index: Option = None; 196 | loop { 197 | let migrations = get_migrations(db, highest_index)?; 198 | if migrations.is_empty() { 199 | break; 200 | } 201 | 202 | for (index, existing) in migrations { 203 | highest_index = Some(index); 204 | 205 | let new = match new_iter.next() { 206 | Some(migration) => migration, 207 | None => { 208 | return Err(anyhow!( 209 | "existing migration {} doesn't exist in local migrations", 210 | existing 211 | )) 212 | } 213 | }; 214 | 215 | if existing != new.name { 216 | return Err(anyhow!( 217 | "existing migration {} does not match new migration {}", 218 | existing, 219 | new.name 220 | )); 221 | } 222 | } 223 | } 224 | 225 | // Return the remaining migrations 226 | let items: Vec = new_iter.collect(); 227 | Ok(items) 228 | } 229 | 230 | fn get_migrations( 231 | db: &mut impl Conn, 232 | index_larger_than: Option, 233 | ) -> anyhow::Result> { 234 | let rows = if let Some(index_larger_than) = index_larger_than { 235 | db.query_with_params( 236 | " 237 | SELECT index, name 238 | FROM reshape.migrations 239 | WHERE index > $1 240 | ORDER BY index ASC 241 | LIMIT 100 242 | ", 243 | &[&index_larger_than], 244 | )? 245 | } else { 246 | db.query( 247 | " 248 | SELECT index, name 249 | FROM reshape.migrations 250 | LIMIT 100 251 | ", 252 | )? 
253 | }; 254 | 255 | let migrations = rows 256 | .iter() 257 | .map(|row| (row.get("index"), row.get("name"))) 258 | .collect(); 259 | Ok(migrations) 260 | } 261 | 262 | fn save_migrations(db: &mut impl Conn, migrations: &[Migration]) -> anyhow::Result<()> { 263 | for migration in migrations { 264 | let encoded_actions = serde_json::to_value(&migration.actions)?; 265 | db.query_with_params( 266 | "INSERT INTO reshape.migrations(name, description, actions) VALUES ($1, $2, $3)", 267 | &[&migration.name, &migration.description, &encoded_actions], 268 | )?; 269 | } 270 | 271 | Ok(()) 272 | } 273 | -------------------------------------------------------------------------------- /tests/add_column.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn add_column() { 6 | let mut test = Test::new("Add column"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_user_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions.columns]] 22 | name = "name" 23 | type = "TEXT" 24 | "#, 25 | ); 26 | 27 | test.second_migration( 28 | r#" 29 | name = "add_first_and_last_name_columns" 30 | 31 | [[actions]] 32 | type = "add_column" 33 | table = "users" 34 | 35 | up = "(STRING_TO_ARRAY(name, ' '))[1]" 36 | 37 | [actions.column] 38 | name = "first" 39 | type = "TEXT" 40 | nullable = false 41 | 42 | [[actions]] 43 | type = "add_column" 44 | table = "users" 45 | 46 | up = "(STRING_TO_ARRAY(name, ' '))[2]" 47 | 48 | [actions.column] 49 | name = "last" 50 | type = "TEXT" 51 | nullable = false 52 | "#, 53 | ); 54 | 55 | test.after_first(|db| { 56 | // Insert some test users 57 | db.simple_query( 58 | " 59 | INSERT INTO users (id, name) VALUES 60 | (1, 'John Doe'), 61 | (2, 'Jane Doe'); 62 | ", 63 | ) 64 | .unwrap(); 65 | }); 66 | 67 | test.intermediate(|old_db, new_db| { 68 | // Check that the existing users have the new columns populated 69 | let expected = vec![("John", "Doe"), ("Jane", "Doe")]; 70 | assert!(new_db 71 | .query("SELECT first, last FROM users ORDER BY id", &[],) 72 | .unwrap() 73 | .iter() 74 | .map(|row| (row.get("first"), row.get("last"))) 75 | .eq(expected)); 76 | 77 | // Insert data using old schema and make sure the new columns are populated 78 | old_db 79 | .simple_query("INSERT INTO users (id, name) VALUES (3, 'Test Testsson')") 80 | .unwrap(); 81 | let (first_name, last_name): (String, String) = new_db 82 | .query_one("SELECT first, last from users WHERE id = 3", &[]) 83 | .map(|row| (row.get("first"), row.get("last"))) 84 | .unwrap(); 85 | assert_eq!( 86 | ("Test", "Testsson"), 87 | (first_name.as_ref(), last_name.as_ref()) 88 | ); 89 | }); 90 | 91 | test.after_completion(|db| { 92 | let expected = vec![("John", "Doe"), ("Jane", "Doe"), ("Test", "Testsson")]; 93 | assert!(db 94 | .query("SELECT first, last FROM users ORDER BY id", &[],) 95 | .unwrap() 96 | .iter() 97 | .map(|row| (row.get("first"), row.get("last"))) 98 | .eq(expected)); 99 | }); 100 | 101 | test.after_abort(|db| { 102 | let expected = vec![("John Doe"), ("Jane Doe"), ("Test Testsson")]; 103 | assert!(db 104 | .query("SELECT name FROM users ORDER BY id", &[],) 105 | .unwrap() 106 | .iter() 107 | .map(|row| row.get::<'_, _, String>("name")) 108 | .eq(expected)); 109 | }); 110 | 111 | test.run() 112 | } 113 | 114 | #[test] 115 | fn add_column_nullable() { 116 | let mut test = Test::new("Add nullable 
column"); 117 | 118 | test.first_migration( 119 | r#" 120 | name = "create_users_table" 121 | 122 | [[actions]] 123 | type = "create_table" 124 | name = "users" 125 | primary_key = ["id"] 126 | 127 | [[actions.columns]] 128 | name = "id" 129 | type = "INTEGER" 130 | "#, 131 | ); 132 | 133 | test.second_migration( 134 | r#" 135 | name = "add_nullable_name_column" 136 | 137 | [[actions]] 138 | type = "add_column" 139 | table = "users" 140 | 141 | [actions.column] 142 | name = "name" 143 | type = "TEXT" 144 | "#, 145 | ); 146 | 147 | test.after_first(|db| { 148 | // Insert some test values 149 | db.simple_query( 150 | " 151 | INSERT INTO users (id) VALUES (1), (2); 152 | ", 153 | ) 154 | .unwrap(); 155 | }); 156 | 157 | test.intermediate(|old_db, new_db| { 158 | // Ensure existing data got updated 159 | let expected: Vec> = vec![None, None]; 160 | assert!(new_db 161 | .query("SELECT name FROM users ORDER BY id", &[],) 162 | .unwrap() 163 | .iter() 164 | .map(|row| row.get::<_, Option>("name")) 165 | .eq(expected)); 166 | 167 | // Insert data using old schema and ensure new column is NULL 168 | old_db 169 | .simple_query("INSERT INTO users (id) VALUES (3)") 170 | .unwrap(); 171 | let name: Option = new_db 172 | .query_one("SELECT name from users WHERE id = 3", &[]) 173 | .map(|row| (row.get("name"))) 174 | .unwrap(); 175 | assert_eq!(None, name); 176 | 177 | // Ensure data can be inserted against new schema 178 | new_db 179 | .simple_query("INSERT INTO users (id, name) VALUES (4, 'Test Testsson'), (5, NULL)") 180 | .unwrap(); 181 | }); 182 | 183 | test.after_completion(|db| { 184 | let expected: Vec> = 185 | vec![None, None, None, Some("Test Testsson".to_string()), None]; 186 | let result: Vec> = db 187 | .query("SELECT id, name FROM users ORDER BY id", &[]) 188 | .unwrap() 189 | .iter() 190 | .map(|row| row.get("name")) 191 | .collect(); 192 | 193 | assert_eq!(result, expected); 194 | }); 195 | 196 | test.run(); 197 | } 198 | 199 | #[test] 200 | fn add_column_with_default() { 201 | let mut test = Test::new("Add column with default value"); 202 | 203 | test.first_migration( 204 | r#" 205 | name = "create_users_table" 206 | 207 | [[actions]] 208 | type = "create_table" 209 | name = "users" 210 | primary_key = ["id"] 211 | 212 | [[actions.columns]] 213 | name = "id" 214 | type = "INTEGER" 215 | "#, 216 | ); 217 | 218 | test.second_migration( 219 | r#" 220 | name = "add_name_column_with_default" 221 | 222 | [[actions]] 223 | type = "add_column" 224 | table = "users" 225 | 226 | [actions.column] 227 | name = "name" 228 | type = "TEXT" 229 | nullable = false 230 | default = "'DEFAULT'" 231 | "#, 232 | ); 233 | 234 | test.after_first(|db| { 235 | // Insert some test values 236 | db.simple_query("INSERT INTO users (id) VALUES (1), (2)") 237 | .unwrap(); 238 | }); 239 | 240 | test.intermediate(|old_db, new_db| { 241 | // Ensure existing data got updated with defaults 242 | let expected = vec!["DEFAULT".to_string(), "DEFAULT".to_string()]; 243 | assert!(new_db 244 | .query("SELECT name FROM users ORDER BY id", &[],) 245 | .unwrap() 246 | .iter() 247 | .map(|row| row.get::<_, String>("name")) 248 | .eq(expected)); 249 | 250 | // Insert data using old schema and ensure new column gets the default value 251 | old_db 252 | .simple_query("INSERT INTO users (id) VALUES (3)") 253 | .unwrap(); 254 | let name: String = new_db 255 | .query_one("SELECT name from users WHERE id = 3", &[]) 256 | .map(|row| row.get("name")) 257 | .unwrap(); 258 | assert_eq!("DEFAULT", name); 259 | }); 260 | 261 | test.run(); 262 | } 
263 | 264 | #[test] 265 | fn add_column_with_complex_up() { 266 | let mut test = Test::new("Add column complex"); 267 | 268 | test.first_migration( 269 | r#" 270 | name = "create_tables" 271 | 272 | [[actions]] 273 | type = "create_table" 274 | name = "users" 275 | primary_key = ["id"] 276 | 277 | [[actions.columns]] 278 | name = "id" 279 | type = "INTEGER" 280 | 281 | [[actions.columns]] 282 | name = "email" 283 | type = "TEXT" 284 | 285 | [[actions]] 286 | type = "create_table" 287 | name = "profiles" 288 | primary_key = ["user_id"] 289 | 290 | [[actions.columns]] 291 | name = "user_id" 292 | type = "INTEGER" 293 | "#, 294 | ); 295 | 296 | test.second_migration( 297 | r#" 298 | name = "add_profiles_email_column" 299 | 300 | [[actions]] 301 | type = "add_column" 302 | table = "profiles" 303 | 304 | [actions.column] 305 | name = "email" 306 | type = "TEXT" 307 | nullable = false 308 | 309 | [actions.up] 310 | table = "users" 311 | value = "users.email" 312 | where = "profiles.user_id = users.id" 313 | "#, 314 | ); 315 | 316 | test.after_first(|db| { 317 | db.simple_query("INSERT INTO users (id, email) VALUES (1, 'test@example.com')") 318 | .unwrap(); 319 | db.simple_query("INSERT INTO profiles (user_id) VALUES (1)") 320 | .unwrap(); 321 | }); 322 | 323 | test.intermediate(|old_db, new_db| { 324 | // Ensure email was backfilled on profiles 325 | let email: String = new_db 326 | .query( 327 | " 328 | SELECT email 329 | FROM profiles 330 | WHERE user_id = 1 331 | ", 332 | &[], 333 | ) 334 | .unwrap() 335 | .first() 336 | .map(|row| row.get("email")) 337 | .unwrap(); 338 | assert_eq!("test@example.com", email); 339 | 340 | // Ensure email change in old schema is propagated to profiles table in new schema 341 | old_db 342 | .simple_query("UPDATE users SET email = 'test2@example.com' WHERE id = 1") 343 | .unwrap(); 344 | let email: String = new_db 345 | .query( 346 | " 347 | SELECT email 348 | FROM profiles 349 | WHERE user_id = 1 350 | ", 351 | &[], 352 | ) 353 | .unwrap() 354 | .first() 355 | .map(|row| row.get("email")) 356 | .unwrap(); 357 | assert_eq!("test2@example.com", email); 358 | }); 359 | 360 | test.run(); 361 | } 362 | -------------------------------------------------------------------------------- /tests/add_foreign_key.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn add_foreign_key() { 6 | let mut test = Test::new("Add foreign key"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_user_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions]] 22 | type = "create_table" 23 | name = "items" 24 | primary_key = ["id"] 25 | 26 | [[actions.columns]] 27 | name = "id" 28 | type = "INTEGER" 29 | 30 | [[actions.columns]] 31 | name = "user_id" 32 | type = "INTEGER" 33 | "#, 34 | ); 35 | 36 | test.second_migration( 37 | r#" 38 | name = "add_foreign_key" 39 | 40 | [[actions]] 41 | type = "add_foreign_key" 42 | table = "items" 43 | 44 | [actions.foreign_key] 45 | columns = ["user_id"] 46 | referenced_table = "users" 47 | referenced_columns = ["id"] 48 | "#, 49 | ); 50 | 51 | test.after_first(|db| { 52 | // Insert some test users 53 | db.simple_query("INSERT INTO users (id) VALUES (1), (2)") 54 | .unwrap(); 55 | }); 56 | 57 | test.intermediate(|db, _| { 58 | // Ensure items can be inserted if they reference valid users 59 | db.simple_query("INSERT INTO 
items (id, user_id) VALUES (1, 1), (2, 2)") 60 | .unwrap(); 61 | 62 | // Ensure items can't be inserted if they don't reference valid users 63 | let result = db.simple_query("INSERT INTO items (id, user_id) VALUES (3, 3)"); 64 | assert!(result.is_err(), "expected insert to fail"); 65 | }); 66 | 67 | test.after_completion(|db| { 68 | // Ensure items can be inserted if they reference valid users 69 | db.simple_query("INSERT INTO items (id, user_id) VALUES (3, 1), (4, 2)") 70 | .unwrap(); 71 | 72 | // Ensure items can't be inserted if they don't reference valid users 73 | let result = db.simple_query("INSERT INTO items (id, user_id) VALUES (5, 3)"); 74 | assert!(result.is_err(), "expected insert to fail"); 75 | 76 | // Ensure foreign key exists with the right name 77 | let foreign_key_name: Option<String> = db 78 | .query( 79 | " 80 | SELECT tc.constraint_name 81 | FROM information_schema.table_constraints AS tc 82 | WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='items'; 83 | ", 84 | &[], 85 | ) 86 | .unwrap() 87 | .first() 88 | .map(|row| row.get(0)); 89 | assert_eq!(Some("items_user_id_fkey".to_string()), foreign_key_name); 90 | }); 91 | 92 | test.after_abort(|db| { 93 | // Ensure foreign key doesn't exist 94 | let fk_does_not_exist = db 95 | .query( 96 | " 97 | SELECT tc.constraint_name 98 | FROM information_schema.table_constraints AS tc 99 | WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='items'; 100 | ", 101 | &[], 102 | ) 103 | .unwrap() 104 | .is_empty(); 105 | assert!(fk_does_not_exist); 106 | }); 107 | 108 | test.run() 109 | } 110 | 111 | #[test] 112 | fn add_invalid_foreign_key() { 113 | let mut test = Test::new("Add invalid foreign key"); 114 | 115 | test.first_migration( 116 | r#" 117 | name = "create_user_table" 118 | 119 | [[actions]] 120 | type = "create_table" 121 | name = "users" 122 | primary_key = ["id"] 123 | 124 | [[actions.columns]] 125 | name = "id" 126 | type = "INTEGER" 127 | 128 | [[actions]] 129 | type = "create_table" 130 | name = "items" 131 | primary_key = ["id"] 132 | 133 | [[actions.columns]] 134 | name = "id" 135 | type = "INTEGER" 136 | 137 | [[actions.columns]] 138 | name = "user_id" 139 | type = "INTEGER" 140 | "#, 141 | ); 142 | 143 | test.second_migration( 144 | r#" 145 | name = "add_foreign_key" 146 | 147 | [[actions]] 148 | type = "add_foreign_key" 149 | table = "items" 150 | 151 | [actions.foreign_key] 152 | columns = ["user_id"] 153 | referenced_table = "users" 154 | referenced_columns = ["id"] 155 | "#, 156 | ); 157 | 158 | test.after_first(|db| { 159 | // Insert some items which don't reference a valid user 160 | db.simple_query("INSERT INTO items (id, user_id) VALUES (1, 1), (2, 2)") 161 | .unwrap(); 162 | }); 163 | 164 | test.expect_failure(); 165 | test.run() 166 | } 167 | -------------------------------------------------------------------------------- /tests/add_index.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn add_index() { 6 | let mut test = Test::new("Add index"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_users_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions.columns]] 22 | name = "name" 23 | type = "TEXT" 24 | "#, 25 | ); 26 | 27 | test.second_migration( 28 | r#" 29 | name = "add_users_name_index" 30 | 31 | [[actions]] 32 | type = "add_index" 33 | table = "users" 34 |
35 | [actions.index] 36 | name = "name_idx" 37 | columns = ["name"] 38 | "#, 39 | ); 40 | 41 | test.intermediate(|db, _| { 42 | // Ensure index is valid and ready 43 | let (is_ready, is_valid): (bool, bool) = db 44 | .query( 45 | " 46 | SELECT pg_index.indisready, pg_index.indisvalid 47 | FROM pg_catalog.pg_index 48 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 49 | WHERE pg_class.relname = 'name_idx' 50 | ", 51 | &[], 52 | ) 53 | .unwrap() 54 | .first() 55 | .map(|row| (row.get("indisready"), row.get("indisvalid"))) 56 | .unwrap(); 57 | 58 | assert!(is_ready, "expected index to be ready"); 59 | assert!(is_valid, "expected index to be valid"); 60 | }); 61 | 62 | test.after_completion(|db| { 63 | // Ensure index is valid and ready 64 | let (is_ready, is_valid): (bool, bool) = db 65 | .query( 66 | " 67 | SELECT pg_index.indisready, pg_index.indisvalid 68 | FROM pg_catalog.pg_index 69 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 70 | WHERE pg_class.relname = 'name_idx' 71 | ", 72 | &[], 73 | ) 74 | .unwrap() 75 | .first() 76 | .map(|row| (row.get("indisready"), row.get("indisvalid"))) 77 | .unwrap(); 78 | 79 | assert!(is_ready, "expected index to be ready"); 80 | assert!(is_valid, "expected index to be valid"); 81 | }); 82 | 83 | test.run(); 84 | } 85 | 86 | #[test] 87 | fn add_index_unique() { 88 | let mut test = Test::new("Add unique index"); 89 | 90 | test.first_migration( 91 | r#" 92 | name = "create_users_table" 93 | 94 | [[actions]] 95 | type = "create_table" 96 | name = "users" 97 | primary_key = ["id"] 98 | 99 | [[actions.columns]] 100 | name = "id" 101 | type = "INTEGER" 102 | 103 | [[actions.columns]] 104 | name = "name" 105 | type = "TEXT" 106 | "#, 107 | ); 108 | 109 | test.second_migration( 110 | r#" 111 | name = "add_name_index" 112 | 113 | [[actions]] 114 | type = "add_index" 115 | table = "users" 116 | 117 | [actions.index] 118 | name = "name_idx" 119 | columns = ["name"] 120 | unique = true 121 | "#, 122 | ); 123 | 124 | test.intermediate(|db, _| { 125 | // Ensure index is valid, ready and unique 126 | let (is_ready, is_valid, is_unique): (bool, bool, bool) = db 127 | .query( 128 | " 129 | SELECT pg_index.indisready, pg_index.indisvalid, pg_index.indisunique 130 | FROM pg_catalog.pg_index 131 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 132 | WHERE pg_class.relname = 'name_idx' 133 | ", 134 | &[], 135 | ) 136 | .unwrap() 137 | .first() 138 | .map(|row| { 139 | ( 140 | row.get("indisready"), 141 | row.get("indisvalid"), 142 | row.get("indisunique"), 143 | ) 144 | }) 145 | .unwrap(); 146 | 147 | assert!(is_ready, "expected index to be ready"); 148 | assert!(is_valid, "expected index to be valid"); 149 | assert!(is_unique, "expected index to be unique"); 150 | }); 151 | 152 | test.run(); 153 | } 154 | 155 | #[test] 156 | fn add_index_with_type() { 157 | let mut test = Test::new("Add GIN index"); 158 | 159 | test.first_migration( 160 | r#" 161 | name = "create_users_table" 162 | 163 | [[actions]] 164 | type = "create_table" 165 | name = "users" 166 | primary_key = ["id"] 167 | 168 | [[actions.columns]] 169 | name = "id" 170 | type = "INTEGER" 171 | 172 | [[actions.columns]] 173 | name = "data" 174 | type = "JSONB" 175 | "#, 176 | ); 177 | 178 | test.second_migration( 179 | r#" 180 | name = "add_data_index" 181 | 182 | [[actions]] 183 | type = "add_index" 184 | table = "users" 185 | 186 | [actions.index] 187 | name = "data_idx" 188 | columns = ["data"] 189 | type = "gin" 190 | "#, 191 | ); 192 | 193 | 
test.intermediate(|db, _| { 194 | // Ensure index is valid, ready and has the right type 195 | let (is_ready, is_valid, index_type): (bool, bool, String) = db 196 | .query( 197 | " 198 | SELECT pg_index.indisready, pg_index.indisvalid, pg_am.amname 199 | FROM pg_catalog.pg_index 200 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 201 | JOIN pg_catalog.pg_am ON pg_class.relam = pg_am.oid 202 | WHERE pg_class.relname = 'data_idx' 203 | ", 204 | &[], 205 | ) 206 | .unwrap() 207 | .first() 208 | .map(|row| { 209 | ( 210 | row.get("indisready"), 211 | row.get("indisvalid"), 212 | row.get("amname"), 213 | ) 214 | }) 215 | .unwrap(); 216 | 217 | assert!(is_ready, "expected index to be ready"); 218 | assert!(is_valid, "expected index to be valid"); 219 | assert_eq!("gin", index_type, "expected index type to be GIN"); 220 | }); 221 | 222 | test.run(); 223 | } 224 | -------------------------------------------------------------------------------- /tests/common.rs: -------------------------------------------------------------------------------- 1 | use colored::Colorize; 2 | use postgres::{Client, NoTls}; 3 | use reshape::{migrations::Migration, Reshape}; 4 | 5 | pub struct Test<'a> { 6 | name: &'a str, 7 | reshape: Reshape, 8 | old_db: Client, 9 | new_db: Client, 10 | 11 | first_migration: Option<Migration>, 12 | second_migration: Option<Migration>, 13 | expect_failure: bool, 14 | 15 | clear_fn: Option<fn(&mut Client) -> ()>, 16 | after_first_fn: Option<fn(&mut Client) -> ()>, 17 | intermediate_fn: Option<fn(&mut Client, &mut Client) -> ()>, 18 | after_completion_fn: Option<fn(&mut Client) -> ()>, 19 | after_abort_fn: Option<fn(&mut Client) -> ()>, 20 | } 21 | 22 | impl Test<'_> { 23 | pub fn new<'a>(name: &'a str) -> Test<'a> { 24 | let connection_string = std::env::var("POSTGRES_CONNECTION_STRING") 25 | .unwrap_or("postgres://postgres:postgres@localhost/reshape_test".to_string()); 26 | 27 | let old_db = Client::connect(&connection_string, NoTls).unwrap(); 28 | let new_db = Client::connect(&connection_string, NoTls).unwrap(); 29 | 30 | let reshape = Reshape::new(&connection_string).unwrap(); 31 | 32 | Test { 33 | name, 34 | reshape, 35 | old_db, 36 | new_db, 37 | first_migration: None, 38 | second_migration: None, 39 | expect_failure: false, 40 | clear_fn: None, 41 | after_first_fn: None, 42 | intermediate_fn: None, 43 | after_completion_fn: None, 44 | after_abort_fn: None, 45 | } 46 | } 47 | 48 | pub fn first_migration(&mut self, migration: &str) -> &mut Self { 49 | self.first_migration = Some(Self::parse_migration(migration)); 50 | self 51 | } 52 | 53 | #[allow(dead_code)] 54 | pub fn second_migration(&mut self, migration: &str) -> &mut Self { 55 | self.second_migration = Some(Self::parse_migration(migration)); 56 | self 57 | } 58 | 59 | #[allow(dead_code)] 60 | pub fn clear(&mut self, f: fn(&mut Client) -> ()) -> &mut Self { 61 | self.clear_fn = Some(f); 62 | self 63 | } 64 | 65 | #[allow(dead_code)] 66 | pub fn after_first(&mut self, f: fn(&mut Client) -> ()) -> &mut Self { 67 | self.after_first_fn = Some(f); 68 | self 69 | } 70 | 71 | #[allow(dead_code)] 72 | pub fn intermediate(&mut self, f: fn(&mut Client, &mut Client) -> ()) -> &mut Self { 73 | self.intermediate_fn = Some(f); 74 | self 75 | } 76 | 77 | #[allow(dead_code)] 78 | pub fn after_completion(&mut self, f: fn(&mut Client) -> ()) -> &mut Self { 79 | self.after_completion_fn = Some(f); 80 | self 81 | } 82 | 83 | #[allow(dead_code)] 84 | pub fn after_abort(&mut self, f: fn(&mut Client) -> ()) -> &mut Self { 85 | self.after_abort_fn = Some(f); 86 | self 87 | } 88 | 89 | #[allow(dead_code)] 90 | pub fn expect_failure(&mut self) { 91 |
self.expect_failure = true; 92 | } 93 | 94 | fn parse_migration(encoded: &str) -> Migration { 95 | toml::from_str(encoded).unwrap() 96 | } 97 | } 98 | 99 | enum RunType { 100 | Simple, 101 | Completion, 102 | Abort, 103 | } 104 | 105 | impl Test<'_> { 106 | #[allow(dead_code)] 107 | pub fn run(&mut self) { 108 | if self.second_migration.is_some() { 109 | // Run to completion 110 | print_heading(&format!("Test completion: {}", self.name)); 111 | self.run_internal(RunType::Completion); 112 | 113 | // Run and abort 114 | print_heading(&format!("Test abort: {}", self.name)); 115 | self.run_internal(RunType::Abort); 116 | } else { 117 | print_heading(&format!("Test: {}", self.name)); 118 | self.run_internal(RunType::Simple); 119 | } 120 | } 121 | 122 | fn run_internal(&mut self, run_type: RunType) { 123 | print_subheading("Clearing database"); 124 | self.reshape.remove().unwrap(); 125 | 126 | if let Some(clear_fn) = self.clear_fn { 127 | clear_fn(&mut self.old_db); 128 | } 129 | 130 | // Apply first migration, will automatically complete 131 | print_subheading("Applying first migration"); 132 | let first_migration = self 133 | .first_migration 134 | .as_ref() 135 | .expect("no starting migration set"); 136 | self.reshape.migrate(vec![first_migration.clone()]).unwrap(); 137 | 138 | // Update search path 139 | self.old_db 140 | .simple_query(&reshape::schema_query_for_migration(&first_migration.name)) 141 | .unwrap(); 142 | 143 | // Automatically complete first migration 144 | self.reshape.complete().unwrap(); 145 | 146 | // Run setup function 147 | if let Some(after_first_fn) = self.after_first_fn { 148 | print_subheading("Running setup and first checks"); 149 | after_first_fn(&mut self.old_db); 150 | print_success(); 151 | } 152 | 153 | // Apply second migration 154 | if let Some(second_migration) = &self.second_migration { 155 | if self.expect_failure { 156 | print_subheading("Applying second migration (expecting failure)"); 157 | let result = self 158 | .reshape 159 | .migrate(vec![first_migration.clone(), second_migration.clone()]); 160 | 161 | if result.is_ok() { 162 | panic!("expected second migration to fail"); 163 | } 164 | } else { 165 | print_subheading("Applying second migration"); 166 | self.reshape 167 | .migrate(vec![first_migration.clone(), second_migration.clone()]) 168 | .unwrap(); 169 | } 170 | 171 | // Update search path 172 | self.new_db 173 | .simple_query(&reshape::schema_query_for_migration(&second_migration.name)) 174 | .unwrap(); 175 | 176 | if let Some(intermediate_fn) = self.intermediate_fn { 177 | print_subheading("Running intermediate checks"); 178 | intermediate_fn(&mut self.old_db, &mut self.new_db); 179 | print_success(); 180 | } 181 | 182 | match run_type { 183 | RunType::Completion => { 184 | print_subheading("Completing"); 185 | self.reshape.complete().unwrap(); 186 | 187 | if let Some(after_completion_fn) = self.after_completion_fn { 188 | print_subheading("Running post-completion checks"); 189 | after_completion_fn(&mut self.new_db); 190 | print_success(); 191 | } 192 | } 193 | RunType::Abort => { 194 | print_subheading("Aborting"); 195 | self.reshape.abort().unwrap(); 196 | 197 | if let Some(after_abort_fn) = self.after_abort_fn { 198 | print_subheading("Running post-abort checks"); 199 | after_abort_fn(&mut self.old_db); 200 | print_success(); 201 | } 202 | } 203 | _ => {} 204 | } 205 | } 206 | 207 | print_subheading("Checking cleanup"); 208 | assert_cleaned_up(&mut self.new_db); 209 | print_success(); 210 | } 211 | } 212 | 213 | fn print_heading(text: 
&str) { 214 | let delimiter = std::iter::repeat("=").take(80).collect::(); 215 | 216 | println!(); 217 | println!(); 218 | println!("{}", delimiter.blue().bold()); 219 | println!("{}", add_spacer(text, "=").blue().bold()); 220 | println!("{}", delimiter.blue().bold()); 221 | } 222 | 223 | fn print_subheading(text: &str) { 224 | println!(); 225 | println!("{}", add_spacer(text, "=").blue()); 226 | } 227 | 228 | fn print_success() { 229 | println!("{}", add_spacer("Success", "=").green()); 230 | } 231 | 232 | fn add_spacer(text: &str, char: &str) -> String { 233 | const TARGET_WIDTH: usize = 80; 234 | let num_of_chars = (TARGET_WIDTH - text.len() - 2) / 2; 235 | let spacer = std::iter::repeat(char) 236 | .take(num_of_chars) 237 | .collect::(); 238 | 239 | let extra = if text.len() % 2 == 0 { "" } else { char }; 240 | 241 | format!("{spacer} {text} {spacer}{extra}", spacer = spacer) 242 | } 243 | 244 | pub fn assert_cleaned_up(db: &mut Client) { 245 | // Make sure no temporary columns remain 246 | let temp_columns: Vec = db 247 | .query( 248 | " 249 | SELECT column_name 250 | FROM information_schema.columns 251 | WHERE table_schema = 'public' 252 | AND column_name LIKE '__reshape%' 253 | ", 254 | &[], 255 | ) 256 | .unwrap() 257 | .iter() 258 | .map(|row| row.get(0)) 259 | .collect(); 260 | 261 | assert!( 262 | temp_columns.is_empty(), 263 | "expected no temporary columns to exist, found: {}", 264 | temp_columns.join(", ") 265 | ); 266 | 267 | // Make sure no triggers remain 268 | let triggers: Vec = db 269 | .query( 270 | " 271 | SELECT trigger_name 272 | FROM information_schema.triggers 273 | WHERE trigger_schema = 'public' 274 | AND trigger_name LIKE '__reshape%' 275 | ", 276 | &[], 277 | ) 278 | .unwrap() 279 | .iter() 280 | .map(|row| row.get(0)) 281 | .collect(); 282 | 283 | assert!( 284 | triggers.is_empty(), 285 | "expected no triggers to exist, found: {}", 286 | triggers.join(", ") 287 | ); 288 | 289 | // Make sure no functions remain 290 | let functions: Vec = db 291 | .query( 292 | " 293 | SELECT routine_name 294 | FROM information_schema.routines 295 | WHERE routine_schema = 'public' 296 | AND routine_name LIKE '__reshape%' 297 | ", 298 | &[], 299 | ) 300 | .unwrap() 301 | .iter() 302 | .map(|row| row.get(0)) 303 | .collect(); 304 | 305 | assert!( 306 | functions.is_empty(), 307 | "expected no functions to exist, found: {}", 308 | functions.join(", ") 309 | ); 310 | } 311 | -------------------------------------------------------------------------------- /tests/complex.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn move_column_between_tables() { 6 | let mut test = Test::new("Move column between tables"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_tables" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions.columns]] 22 | name = "email" 23 | type = "TEXT" 24 | 25 | [[actions]] 26 | type = "create_table" 27 | name = "profiles" 28 | primary_key = ["id"] 29 | 30 | [[actions.columns]] 31 | name = "id" 32 | type = "INTEGER" 33 | 34 | [[actions.columns]] 35 | name = "user_id" 36 | type = "INTEGER" 37 | nullable = false 38 | "#, 39 | ); 40 | 41 | test.second_migration( 42 | r#" 43 | name = "move_email_column" 44 | 45 | [[actions]] 46 | type = "add_column" 47 | table = "profiles" 48 | 49 | [actions.column] 50 | name = "email" 51 | type = "TEXT" 52 | 
nullable = false 53 | 54 | # When `users` is updated in the old schema, we write the email value to `profiles` 55 | # When `profiles` is updated in the old schema, the equivalent `users.email` will also be updated 56 | [actions.up] 57 | table = "users" 58 | value = "users.email" 59 | where = "profiles.user_id = users.id" 60 | 61 | [[actions]] 62 | type = "remove_column" 63 | table = "users" 64 | column = "email" 65 | 66 | # When `profiles` is changed in the new schema, we write the email address back to the removed column 67 | [actions.down] 68 | table = "profiles" 69 | value = "profiles.email" 70 | where = "users.id = profiles.user_id" 71 | "#, 72 | ); 73 | 74 | test.after_first(|db| { 75 | db.simple_query("INSERT INTO users (id, email) VALUES (1, 'test1@test.com')") 76 | .unwrap(); 77 | db.simple_query("INSERT INTO users (id, email) VALUES (2, 'test2@test.com')") 78 | .unwrap(); 79 | 80 | db.simple_query("INSERT INTO profiles (id, user_id) VALUES (1, 1)") 81 | .unwrap(); 82 | db.simple_query("INSERT INTO profiles (id, user_id) VALUES (2, 2)") 83 | .unwrap(); 84 | }); 85 | 86 | test.intermediate(|old_db, new_db| { 87 | // Ensure emails were backfilled into profiles 88 | let profiles_emails: Vec<String> = new_db 89 | .query( 90 | r#" 91 | SELECT email 92 | FROM profiles 93 | ORDER BY id 94 | "#, 95 | &[], 96 | ) 97 | .unwrap() 98 | .iter() 99 | .map(|row| row.get("email")) 100 | .collect(); 101 | assert_eq!(vec!("test1@test.com", "test2@test.com"), profiles_emails); 102 | 103 | // Ensure insert in old schema updates new 104 | old_db 105 | .simple_query("INSERT INTO users (id, email) VALUES (3, 'test3@test.com')") 106 | .unwrap(); 107 | old_db 108 | .simple_query("INSERT INTO profiles (id, user_id) VALUES (3, 3)") 109 | .unwrap(); 110 | let new_email: String = new_db 111 | .query("SELECT email FROM profiles WHERE id = 3", &[]) 112 | .unwrap() 113 | .first() 114 | .map(|row| row.get("email")) 115 | .unwrap(); 116 | assert_eq!("test3@test.com", new_email); 117 | 118 | // Ensure updates in old schema update new 119 | old_db 120 | .simple_query("UPDATE users SET email = 'test3+updated@test.com' WHERE id = 3") 121 | .unwrap(); 122 | let new_email: String = new_db 123 | .query("SELECT email FROM profiles WHERE id = 3", &[]) 124 | .unwrap() 125 | .first() 126 | .map(|row| row.get("email")) 127 | .unwrap(); 128 | assert_eq!("test3+updated@test.com", new_email); 129 | 130 | // Ensure insert in new schema updates old 131 | new_db 132 | .simple_query( 133 | "INSERT INTO profiles (id, user_id, email) VALUES (4, 4, 'test4@test.com')", 134 | ) 135 | .unwrap(); 136 | new_db 137 | .simple_query("INSERT INTO users (id) VALUES (4)") 138 | .unwrap(); 139 | let old_email: String = old_db 140 | .query("SELECT email FROM users WHERE id = 4", &[]) 141 | .unwrap() 142 | .first() 143 | .map(|row| row.get("email")) 144 | .unwrap(); 145 | assert_eq!("test4@test.com", old_email); 146 | 147 | // Ensure update in new schema updates old 148 | new_db 149 | .simple_query("UPDATE profiles SET email = 'test4+updated@test.com' WHERE id = 4") 150 | .unwrap(); 151 | let old_email: String = old_db 152 | .query("SELECT email FROM users WHERE id = 4", &[]) 153 | .unwrap() 154 | .first() 155 | .map(|row| row.get("email")) 156 | .unwrap(); 157 | assert_eq!("test4+updated@test.com", old_email); 158 | }); 159 | 160 | test.run(); 161 | } 162 | 163 | #[test] 164 | fn extract_relation_into_new_table() { 165 | let mut test = Test::new("Extract relation into new table"); 166 | 167 | test.first_migration( 168 | r#" 169 | name = "create_tables" 170
| 171 | [[actions]] 172 | type = "create_table" 173 | name = "accounts" 174 | primary_key = ["id"] 175 | 176 | [[actions.columns]] 177 | name = "id" 178 | type = "INTEGER" 179 | 180 | [[actions]] 181 | type = "create_table" 182 | name = "users" 183 | primary_key = ["id"] 184 | 185 | [[actions.columns]] 186 | name = "id" 187 | type = "INTEGER" 188 | 189 | [[actions.columns]] 190 | name = "account_id" 191 | type = "INTEGER" 192 | nullable = false 193 | 194 | [[actions.columns]] 195 | name = "account_role" 196 | type = "TEXT" 197 | nullable = false 198 | "#, 199 | ); 200 | 201 | test.second_migration( 202 | r#" 203 | name = "add_account_user_connection" 204 | 205 | [[actions]] 206 | type = "create_table" 207 | name = "user_account_connections" 208 | primary_key = ["account_id", "user_id"] 209 | 210 | [[actions.columns]] 211 | name = "account_id" 212 | type = "INTEGER" 213 | 214 | [[actions.columns]] 215 | name = "user_id" 216 | type = "INTEGER" 217 | 218 | [[actions.columns]] 219 | name = "role" 220 | type = "TEXT" 221 | nullable = false 222 | 223 | [actions.up] 224 | table = "users" 225 | values = { user_id = "id", account_id = "account_id", role = "UPPER(account_role)" } 226 | where = "user_account_connections.user_id = users.id" 227 | 228 | [[actions]] 229 | type = "remove_column" 230 | table = "users" 231 | column = "account_id" 232 | 233 | [actions.down] 234 | table = "user_account_connections" 235 | value = "user_account_connections.account_id" 236 | where = "users.id = user_account_connections.user_id" 237 | 238 | [[actions]] 239 | type = "remove_column" 240 | table = "users" 241 | column = "account_role" 242 | 243 | [actions.down] 244 | table = "user_account_connections" 245 | value = "LOWER(user_account_connections.role)" 246 | where = "users.id = user_account_connections.user_id" 247 | "#, 248 | ); 249 | 250 | test.after_first(|db| { 251 | db.simple_query("INSERT INTO accounts (id) VALUES (1)") 252 | .unwrap(); 253 | db.simple_query("INSERT INTO users (id, account_id, account_role) VALUES (1, 1, 'admin')") 254 | .unwrap(); 255 | }); 256 | 257 | test.intermediate(|old_db, new_db| { 258 | // Ensure connections were backfilled 259 | let rows: Vec<(i32, i32, String)> = new_db 260 | .query( 261 | " 262 | SELECT account_id, user_id, role 263 | FROM user_account_connections 264 | ", 265 | &[], 266 | ) 267 | .unwrap() 268 | .iter() 269 | .map(|row| (row.get("account_id"), row.get("user_id"), row.get("role"))) 270 | .collect(); 271 | assert_eq!(1, rows.len()); 272 | 273 | let row = rows.first().unwrap(); 274 | assert_eq!(1, row.0); 275 | assert_eq!(1, row.1); 276 | assert_eq!("ADMIN", row.2); 277 | 278 | // Ensure inserted user in old schema creates a new connection 279 | old_db 280 | .simple_query( 281 | "INSERT INTO users (id, account_id, account_role) VALUES (2, 1, 'developer')", 282 | ) 283 | .unwrap(); 284 | assert!( 285 | new_db 286 | .query( 287 | " 288 | SELECT account_id, user_id, role 289 | FROM user_account_connections 290 | WHERE account_id = 1 AND user_id = 2 AND role = 'DEVELOPER' 291 | ", 292 | &[], 293 | ) 294 | .unwrap() 295 | .len() 296 | == 1 297 | ); 298 | 299 | // Ensure NOT NULL constraint still applies to old schema 300 | let result = old_db 301 | .simple_query( 302 | "INSERT INTO users (id, account_id, account_role) VALUES (2, NULL, 'developer')", 303 | ); 304 | assert!(result.is_err()); 305 | 306 | // Ensure updated user role in old schema updates connection in new schema 307 | old_db 308 | .simple_query("UPDATE users SET account_role = 'admin' WHERE id = 2") 309 |
.unwrap(); 310 | assert!( 311 | new_db 312 | .query( 313 | " 314 | SELECT account_id, user_id, role 315 | FROM user_account_connections 316 | WHERE account_id = 1 AND user_id = 2 AND role = 'ADMIN' 317 | ", 318 | &[], 319 | ) 320 | .unwrap() 321 | .len() 322 | == 1 323 | ); 324 | 325 | // Ensure updated connection in new schema updates old schema user 326 | new_db 327 | .simple_query( 328 | "UPDATE user_account_connections SET role = 'DEVELOPER' WHERE account_id = 1 AND user_id = 2", 329 | ) 330 | .unwrap(); 331 | assert!( 332 | old_db 333 | .query( 334 | " 335 | SELECT id 336 | FROM users 337 | WHERE id = 2 AND account_id = 1 AND account_role = 'developer' 338 | ", 339 | &[], 340 | ) 341 | .unwrap() 342 | .len() 343 | == 1 344 | ); 345 | 346 | // Ensure insert of user with connection through new schema updates user in old schema 347 | new_db 348 | .simple_query( 349 | r#" 350 | BEGIN; 351 | INSERT INTO users (id) VALUES (3); 352 | INSERT INTO user_account_connections (user_id, account_id, role) VALUES (3, 1, 'DEVELOPER'); 353 | COMMIT; 354 | "#, 355 | ) 356 | .unwrap(); 357 | new_db 358 | .simple_query( 359 | "", 360 | ) 361 | .unwrap(); 362 | assert!( 363 | old_db 364 | .query( 365 | " 366 | SELECT id 367 | FROM users 368 | WHERE id = 3 AND account_id = 1 AND account_role = 'developer' 369 | ", 370 | &[], 371 | ) 372 | .unwrap() 373 | .len() 374 | == 1 375 | ); 376 | }); 377 | 378 | test.run(); 379 | } 380 | -------------------------------------------------------------------------------- /tests/create_enum.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn create_enum() { 6 | let mut test = Test::new("Create enum"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_enum_and_table" 11 | 12 | [[actions]] 13 | type = "create_enum" 14 | name = "mood" 15 | values = ["happy", "ok", "sad"] 16 | 17 | [[actions]] 18 | type = "create_table" 19 | name = "updates" 20 | primary_key = ["id"] 21 | 22 | [[actions.columns]] 23 | name = "id" 24 | type = "INTEGER" 25 | 26 | [[actions.columns]] 27 | name = "status" 28 | type = "mood" 29 | "#, 30 | ); 31 | 32 | test.after_first(|db| { 33 | // Valid enum values should succeed 34 | db.simple_query( 35 | "INSERT INTO updates (id, status) VALUES (1, 'happy'), (2, 'ok'), (3, 'sad')", 36 | ) 37 | .unwrap(); 38 | }); 39 | 40 | test.run(); 41 | } 42 | -------------------------------------------------------------------------------- /tests/create_table.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn create_table() { 6 | let mut test = Test::new("Create table"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_users_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | generated = "ALWAYS AS IDENTITY" 21 | 22 | [[actions.columns]] 23 | name = "name" 24 | type = "TEXT" 25 | 26 | [[actions.columns]] 27 | name = "created_at" 28 | type = "TIMESTAMP" 29 | nullable = false 30 | default = "NOW()" 31 | "#, 32 | ); 33 | 34 | test.after_first(|db| { 35 | // Ensure table was created 36 | let result = db 37 | .query_opt( 38 | " 39 | SELECT table_name 40 | FROM information_schema.tables 41 | WHERE table_name = 'users' AND table_schema = 'public'", 42 | &[], 43 | ) 44 | .unwrap(); 45 | assert!(result.is_some()); 46 | 47 | // Ensure table has the right 
columns 48 | let result = db 49 | .query( 50 | " 51 | SELECT column_name, column_default, is_nullable, data_type 52 | FROM information_schema.columns 53 | WHERE table_name = 'users' AND table_schema = 'public' 54 | ORDER BY ordinal_position", 55 | &[], 56 | ) 57 | .unwrap(); 58 | 59 | // id column 60 | let id_row = &result[0]; 61 | assert_eq!("id", id_row.get::<_, String>("column_name")); 62 | assert!(id_row.get::<_, Option<String>>("column_default").is_none()); 63 | assert_eq!("NO", id_row.get::<_, String>("is_nullable")); 64 | assert_eq!("integer", id_row.get::<_, String>("data_type")); 65 | 66 | // name column 67 | let name_row = &result[1]; 68 | assert_eq!("name", name_row.get::<_, String>("column_name")); 69 | assert!(name_row 70 | .get::<_, Option<String>>("column_default") 71 | .is_none()); 72 | assert_eq!("YES", name_row.get::<_, String>("is_nullable")); 73 | assert_eq!("text", name_row.get::<_, String>("data_type")); 74 | 75 | // created_at column 76 | let created_at_column = &result[2]; 77 | assert_eq!( 78 | "created_at", 79 | created_at_column.get::<_, String>("column_name") 80 | ); 81 | assert!(created_at_column 82 | .get::<_, Option<String>>("column_default") 83 | .is_some()); 84 | assert_eq!("NO", created_at_column.get::<_, String>("is_nullable")); 85 | assert_eq!( 86 | "timestamp without time zone", 87 | created_at_column.get::<_, String>("data_type") 88 | ); 89 | 90 | // Ensure the primary key has the right columns 91 | let primary_key_columns: Vec<String> = db 92 | .query( 93 | " 94 | SELECT a.attname AS column 95 | FROM pg_index i 96 | JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) 97 | JOIN pg_class t ON t.oid = i.indrelid 98 | WHERE t.relname = 'users' AND i.indisprimary 99 | ", 100 | &[], 101 | ) 102 | .unwrap() 103 | .iter() 104 | .map(|row| row.get("column")) 105 | .collect(); 106 | 107 | assert_eq!(vec!["id"], primary_key_columns); 108 | }); 109 | 110 | test.run(); 111 | } 112 | 113 | #[test] 114 | fn create_table_with_foreign_keys() { 115 | let mut test = Test::new("Create table"); 116 | 117 | test.first_migration( 118 | r#" 119 | name = "create_users_table" 120 | 121 | [[actions]] 122 | type = "create_table" 123 | name = "users" 124 | primary_key = ["id"] 125 | 126 | [[actions.columns]] 127 | name = "id" 128 | type = "INTEGER" 129 | generated = "ALWAYS AS IDENTITY" 130 | 131 | [[actions.columns]] 132 | name = "name" 133 | type = "TEXT" 134 | 135 | [[actions.columns]] 136 | name = "created_at" 137 | type = "TIMESTAMP" 138 | nullable = false 139 | default = "NOW()" 140 | 141 | [[actions]] 142 | type = "create_table" 143 | name = "items" 144 | primary_key = ["id"] 145 | 146 | [[actions.columns]] 147 | name = "id" 148 | type = "INTEGER" 149 | 150 | [[actions.columns]] 151 | name = "user_id" 152 | type = "INTEGER" 153 | nullable = false 154 | 155 | [[actions.foreign_keys]] 156 | columns = ["user_id"] 157 | referenced_table = "users" 158 | referenced_columns = ["id"] 159 | "#, 160 | ); 161 | 162 | test.after_first(|db| { 163 | let foreign_key_columns: Vec<(String, String, String)> = db 164 | .query( 165 | " 166 | SELECT 167 | kcu.column_name, 168 | ccu.table_name AS foreign_table_name, 169 | ccu.column_name AS foreign_column_name 170 | FROM 171 | information_schema.table_constraints AS tc 172 | JOIN information_schema.key_column_usage AS kcu 173 | ON tc.constraint_name = kcu.constraint_name 174 | AND tc.table_schema = kcu.table_schema 175 | JOIN information_schema.constraint_column_usage AS ccu 176 | ON ccu.constraint_name = tc.constraint_name 177 | AND ccu.table_schema =
tc.table_schema 178 | WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='items'; 179 | ", 180 | &[], 181 | ) 182 | .unwrap() 183 | .iter() 184 | .map(|row| { 185 | ( 186 | row.get("column_name"), 187 | row.get("foreign_table_name"), 188 | row.get("foreign_column_name"), 189 | ) 190 | }) 191 | .collect(); 192 | 193 | assert_eq!( 194 | vec![("user_id".to_string(), "users".to_string(), "id".to_string())], 195 | foreign_key_columns 196 | ); 197 | }); 198 | 199 | test.run(); 200 | } 201 | -------------------------------------------------------------------------------- /tests/custom.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn custom_enable_extension() { 6 | let mut test = Test::new("Custom migration"); 7 | 8 | test.clear(|db| { 9 | db.simple_query( 10 | " 11 | DROP EXTENSION IF EXISTS bloom; 12 | DROP EXTENSION IF EXISTS btree_gin; 13 | DROP EXTENSION IF EXISTS btree_gist; 14 | ", 15 | ) 16 | .unwrap(); 17 | }); 18 | 19 | test.first_migration( 20 | r#" 21 | name = "empty_migration" 22 | 23 | [[actions]] 24 | type = "custom" 25 | "#, 26 | ); 27 | 28 | test.second_migration( 29 | r#" 30 | name = "enable_extensions" 31 | 32 | [[actions]] 33 | type = "custom" 34 | 35 | start = """ 36 | CREATE EXTENSION IF NOT EXISTS bloom; 37 | CREATE EXTENSION IF NOT EXISTS btree_gin; 38 | """ 39 | 40 | complete = "CREATE EXTENSION IF NOT EXISTS btree_gist" 41 | 42 | abort = """ 43 | DROP EXTENSION IF EXISTS bloom; 44 | DROP EXTENSION IF EXISTS btree_gin; 45 | """ 46 | "#, 47 | ); 48 | 49 | test.intermediate(|db, _| { 50 | let bloom_activated = !db 51 | .query("SELECT * FROM pg_extension WHERE extname = 'bloom'", &[]) 52 | .unwrap() 53 | .is_empty(); 54 | assert!(bloom_activated); 55 | 56 | let btree_gin_activated = !db 57 | .query( 58 | "SELECT * FROM pg_extension WHERE extname = 'btree_gin'", 59 | &[], 60 | ) 61 | .unwrap() 62 | .is_empty(); 63 | assert!(btree_gin_activated); 64 | 65 | let btree_gist_activated = !db 66 | .query( 67 | "SELECT * FROM pg_extension WHERE extname = 'btree_gist'", 68 | &[], 69 | ) 70 | .unwrap() 71 | .is_empty(); 72 | assert!(!btree_gist_activated); 73 | }); 74 | 75 | test.after_completion(|db| { 76 | let bloom_activated = !db 77 | .query("SELECT * FROM pg_extension WHERE extname = 'bloom'", &[]) 78 | .unwrap() 79 | .is_empty(); 80 | assert!(bloom_activated); 81 | 82 | let btree_gin_activated = !db 83 | .query( 84 | "SELECT * FROM pg_extension WHERE extname = 'btree_gin'", 85 | &[], 86 | ) 87 | .unwrap() 88 | .is_empty(); 89 | assert!(btree_gin_activated); 90 | 91 | let btree_gist_activated = !db 92 | .query( 93 | "SELECT * FROM pg_extension WHERE extname = 'btree_gist'", 94 | &[], 95 | ) 96 | .unwrap() 97 | .is_empty(); 98 | assert!(btree_gist_activated); 99 | }); 100 | 101 | test.after_abort(|db| { 102 | let bloom_activated = !db 103 | .query("SELECT * FROM pg_extension WHERE extname = 'bloom'", &[]) 104 | .unwrap() 105 | .is_empty(); 106 | assert!(!bloom_activated); 107 | 108 | let btree_gin_activated = !db 109 | .query( 110 | "SELECT * FROM pg_extension WHERE extname = 'btree_gin'", 111 | &[], 112 | ) 113 | .unwrap() 114 | .is_empty(); 115 | assert!(!btree_gin_activated); 116 | 117 | let btree_gist_activated = !db 118 | .query( 119 | "SELECT * FROM pg_extension WHERE extname = 'btree_gist'", 120 | &[], 121 | ) 122 | .unwrap() 123 | .is_empty(); 124 | assert!(!btree_gist_activated); 125 | }); 126 | 127 | test.run(); 128 | } 129 | 
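Note on the custom action test above: its three hooks follow the lifecycle the harness drives, with start running when the migration is applied, complete when it is completed, and abort when it is rolled back, which is exactly what the three groups of extension checks assert. Those checks repeat the same pg_extension lookup; a minimal sketch of a helper that would collapse them, assuming only the postgres crate API already used in these tests (the helper itself is hypothetical, not part of the repository):

    use postgres::Client;

    // Hypothetical helper mirroring the repeated
    // `SELECT * FROM pg_extension WHERE extname = ...` checks in
    // tests/custom.rs, with the extension name passed as a parameter.
    fn extension_enabled(db: &mut Client, name: &str) -> bool {
        !db.query("SELECT 1 FROM pg_extension WHERE extname = $1", &[&name])
            .unwrap()
            .is_empty()
    }

    // Usage inside a check: assert!(extension_enabled(db, "bloom"));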
-------------------------------------------------------------------------------- /tests/failure.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn invalid_migration() { 6 | let mut test = Test::new("Invalid migration"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "invalid_migration" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | "#, 21 | ); 22 | 23 | test.second_migration( 24 | r#" 25 | name = "add_invalid_column" 26 | 27 | [[actions]] 28 | type = "add_column" 29 | table = "users" 30 | 31 | up = "INVALID SQL" 32 | 33 | [actions.column] 34 | name = "first" 35 | type = "TEXT" 36 | "#, 37 | ); 38 | 39 | // Insert a test user 40 | test.after_first(|db| { 41 | db.simple_query( 42 | " 43 | INSERT INTO users (id) VALUES (1) 44 | ", 45 | ) 46 | .unwrap(); 47 | }); 48 | 49 | test.expect_failure(); 50 | test.run(); 51 | } 52 | -------------------------------------------------------------------------------- /tests/remove_column.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn remove_column() { 6 | let mut test = Test::new("Remove column"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_user_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions.columns]] 22 | name = "name" 23 | type = "TEXT" 24 | "#, 25 | ); 26 | 27 | test.second_migration( 28 | r#" 29 | name = "remove_name_column" 30 | 31 | [[actions]] 32 | type = "remove_column" 33 | table = "users" 34 | column = "name" 35 | down = "'TEST_DOWN_VALUE'" 36 | "#, 37 | ); 38 | 39 | test.intermediate(|old_db, new_db| { 40 | // Insert using old schema and ensure it can be retrieved through new schema 41 | old_db 42 | .simple_query("INSERT INTO users(id, name) VALUES (1, 'John Doe')") 43 | .unwrap(); 44 | let results = new_db 45 | .query("SELECT id FROM users WHERE id = 1", &[]) 46 | .unwrap(); 47 | assert_eq!(1, results.len()); 48 | assert_eq!(1, results[0].get::<_, i32>("id")); 49 | 50 | // Ensure the name column is not accessible through the new schema 51 | assert!(new_db.query("SELECT id, name FROM users", &[]).is_err()); 52 | 53 | // Insert using new schema and ensure the down function is correctly applied 54 | new_db 55 | .simple_query("INSERT INTO users(id) VALUES (2)") 56 | .unwrap(); 57 | let result = old_db 58 | .query_opt("SELECT name FROM users WHERE id = 2", &[]) 59 | .unwrap(); 60 | assert_eq!( 61 | Some("TEST_DOWN_VALUE"), 62 | result.as_ref().map(|row| row.get("name")) 63 | ); 64 | }); 65 | 66 | test.run(); 67 | } 68 | 69 | #[test] 70 | fn remove_column_with_index() { 71 | let mut test = Test::new("Remove column"); 72 | 73 | test.first_migration( 74 | r#" 75 | name = "create_user_table" 76 | 77 | [[actions]] 78 | type = "create_table" 79 | name = "users" 80 | primary_key = ["id"] 81 | 82 | [[actions.columns]] 83 | name = "id" 84 | type = "INTEGER" 85 | 86 | [[actions.columns]] 87 | name = "name" 88 | type = "TEXT" 89 | 90 | [[actions]] 91 | type = "add_index" 92 | table = "users" 93 | 94 | [actions.index] 95 | name = "name_idx" 96 | columns = ["name"] 97 | "#, 98 | ); 99 | 100 | test.second_migration( 101 | r#" 102 | name = "remove_name_column" 103 | 104 | [[actions]] 105 | type =
"remove_column" 106 | table = "users" 107 | column = "name" 108 | down = "'TEST_DOWN_VALUE'" 109 | "#, 110 | ); 111 | 112 | test.after_completion(|db| { 113 | // Ensure index has been removed after the migration is complete 114 | let count: i64 = db 115 | .query( 116 | " 117 | SELECT COUNT(*) 118 | FROM pg_catalog.pg_index 119 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 120 | WHERE pg_class.relname = 'name_idx' 121 | ", 122 | &[], 123 | ) 124 | .unwrap() 125 | .first() 126 | .map(|row| row.get(0)) 127 | .unwrap(); 128 | 129 | assert_eq!(0, count, "expected index to not exist"); 130 | }); 131 | 132 | test.run(); 133 | } 134 | 135 | #[test] 136 | fn remove_column_with_complex_down() { 137 | let mut test = Test::new("Remove column complex"); 138 | 139 | test.first_migration( 140 | r#" 141 | name = "create_tables" 142 | 143 | [[actions]] 144 | type = "create_table" 145 | name = "users" 146 | primary_key = ["id"] 147 | 148 | [[actions.columns]] 149 | name = "id" 150 | type = "INTEGER" 151 | 152 | [[actions.columns]] 153 | name = "email" 154 | type = "TEXT" 155 | 156 | [[actions]] 157 | type = "create_table" 158 | name = "profiles" 159 | primary_key = ["user_id"] 160 | 161 | [[actions.columns]] 162 | name = "user_id" 163 | type = "INTEGER" 164 | 165 | [[actions.columns]] 166 | name = "email" 167 | type = "TEXT" 168 | "#, 169 | ); 170 | 171 | test.second_migration( 172 | r#" 173 | name = "remove_users_email_column" 174 | 175 | [[actions]] 176 | type = "remove_column" 177 | table = "users" 178 | column = "email" 179 | 180 | [actions.down] 181 | table = "profiles" 182 | value = "profiles.email" 183 | where = "users.id = profiles.user_id" 184 | "#, 185 | ); 186 | 187 | test.after_first(|db| { 188 | db.simple_query("INSERT INTO users (id, email) VALUES (1, 'test@example.com')") 189 | .unwrap(); 190 | db.simple_query("INSERT INTO profiles (user_id, email) VALUES (1, 'test@example.com')") 191 | .unwrap(); 192 | }); 193 | 194 | test.intermediate(|old_db, new_db| { 195 | new_db 196 | .simple_query("UPDATE profiles SET email = 'test2@example.com' WHERE user_id = 1") 197 | .unwrap(); 198 | 199 | // Ensure new email was propagated to users table in old schema 200 | let email: String = old_db 201 | .query( 202 | " 203 | SELECT email 204 | FROM users 205 | WHERE id = 1 206 | ", 207 | &[], 208 | ) 209 | .unwrap() 210 | .first() 211 | .map(|row| row.get("email")) 212 | .unwrap(); 213 | assert_eq!("test2@example.com", email); 214 | }); 215 | 216 | test.run(); 217 | } 218 | -------------------------------------------------------------------------------- /tests/remove_enum.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn remove_enum() { 6 | let mut test = Test::new("Remove enum"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_enum" 11 | 12 | [[actions]] 13 | type = "create_enum" 14 | name = "mood" 15 | values = ["happy", "ok", "sad"] 16 | "#, 17 | ); 18 | 19 | test.second_migration( 20 | r#" 21 | name = "remove_enum" 22 | 23 | [[actions]] 24 | type = "remove_enum" 25 | enum = "mood" 26 | "#, 27 | ); 28 | 29 | test.after_first(|db| { 30 | // Ensure enum was created 31 | let enum_exists = !db 32 | .query( 33 | " 34 | SELECT typname 35 | FROM pg_catalog.pg_type 36 | WHERE typcategory = 'E' 37 | AND typname = 'mood' 38 | ", 39 | &[], 40 | ) 41 | .unwrap() 42 | .is_empty(); 43 | 44 | assert!(enum_exists, "expected mood enum to have been created"); 45 | }); 46 | 47 | 
test.after_completion(|db| { 48 | // Ensure enum was removed after completion 49 | let enum_does_not_exist = db 50 | .query( 51 | " 52 | SELECT typname 53 | FROM pg_catalog.pg_type 54 | WHERE typcategory = 'E' 55 | AND typname = 'mood' 56 | ", 57 | &[], 58 | ) 59 | .unwrap() 60 | .is_empty(); 61 | 62 | assert!( 63 | enum_does_not_exist, 64 | "expected mood enum to have been removed" 65 | ); 66 | }); 67 | 68 | test.run(); 69 | } 70 | -------------------------------------------------------------------------------- /tests/remove_foreign_key.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn remove_foreign_key() { 6 | let mut test = Test::new("Remove foreign key"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_tables" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions]] 22 | type = "create_table" 23 | name = "items" 24 | primary_key = ["id"] 25 | 26 | [[actions.columns]] 27 | name = "id" 28 | type = "INTEGER" 29 | 30 | [[actions.columns]] 31 | name = "user_id" 32 | type = "INTEGER" 33 | 34 | [[actions.foreign_keys]] 35 | columns = ["user_id"] 36 | referenced_table = "users" 37 | referenced_columns = ["id"] 38 | "#, 39 | ); 40 | 41 | test.second_migration( 42 | r#" 43 | name = "remove_foreign_key" 44 | 45 | [[actions]] 46 | type = "remove_foreign_key" 47 | table = "items" 48 | foreign_key = "items_user_id_fkey" 49 | "#, 50 | ); 51 | 52 | test.after_first(|db| { 53 | // Insert some test users 54 | db.simple_query("INSERT INTO users (id) VALUES (1), (2)") 55 | .unwrap(); 56 | }); 57 | 58 | test.intermediate(|old_db, new_db| { 59 | // Ensure items can't be inserted if they don't reference valid users 60 | // The foreign key is only removed when the migration is completed so 61 | // it should still be enforced for the new and old schema. 
62 | let result = old_db.simple_query("INSERT INTO items (id, user_id) VALUES (3, 3)"); 63 | assert!( 64 | result.is_err(), 65 | "expected insert against old schema to fail" 66 | ); 67 | 68 | let result = new_db.simple_query("INSERT INTO items (id, user_id) VALUES (3, 3)"); 69 | assert!( 70 | result.is_err(), 71 | "expected insert against new schema to fail" 72 | ); 73 | }); 74 | 75 | test.after_completion(|db| { 76 | // Ensure items can be inserted even if they don't reference valid users 77 | let result = db 78 | .simple_query("INSERT INTO items (id, user_id) VALUES (5, 3)") 79 | .map(|_| ()); 80 | assert!( 81 | result.is_ok(), 82 | "expected insert to not fail, got {:?}", 83 | result 84 | ); 85 | 86 | // Ensure foreign key doesn't exist 87 | let foreign_keys = db 88 | .query( 89 | " 90 | SELECT tc.constraint_name 91 | FROM information_schema.table_constraints AS tc 92 | WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='items'; 93 | ", 94 | &[], 95 | ) 96 | .unwrap(); 97 | assert!( 98 | foreign_keys.is_empty(), 99 | "expected no foreign keys to exist on items table" 100 | ); 101 | }); 102 | 103 | test.after_abort(|db| { 104 | // Ensure items can't be inserted if they don't reference valid users 105 | let result = db.simple_query("INSERT INTO items (id, user_id) VALUES (3, 3)"); 106 | assert!(result.is_err(), "expected insert to fail"); 107 | 108 | // Ensure foreign key still exists 109 | let fk_exists = !db 110 | .query( 111 | " 112 | SELECT tc.constraint_name 113 | FROM information_schema.table_constraints AS tc 114 | WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='items'; 115 | ", 116 | &[], 117 | ) 118 | .unwrap() 119 | .is_empty(); 120 | assert!(fk_exists, "expected foreign key to still exist"); 121 | }); 122 | 123 | test.run() 124 | } 125 | -------------------------------------------------------------------------------- /tests/remove_index.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn remove_index() { 6 | let mut test = Test::new("Remove index"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_users_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | 21 | [[actions.columns]] 22 | name = "name" 23 | type = "TEXT" 24 | 25 | [[actions]] 26 | type = "add_index" 27 | table = "users" 28 | 29 | [actions.index] 30 | name = "name_idx" 31 | columns = ["name"] 32 | "#, 33 | ); 34 | 35 | test.second_migration( 36 | r#" 37 | name = "remove_name_index" 38 | 39 | [[actions]] 40 | type = "remove_index" 41 | index = "name_idx" 42 | "#, 43 | ); 44 | 45 | test.intermediate(|db, _| { 46 | // Ensure index is still valid and ready during the migration 47 | let result: Vec<(bool, bool)> = db 48 | .query( 49 | " 50 | SELECT pg_index.indisready, pg_index.indisvalid 51 | FROM pg_catalog.pg_index 52 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 53 | WHERE pg_class.relname = 'name_idx' 54 | ", 55 | &[], 56 | ) 57 | .unwrap() 58 | .iter() 59 | .map(|row| (row.get("indisready"), row.get("indisvalid"))) 60 | .collect(); 61 | 62 | assert_eq!(vec![(true, true)], result); 63 | }); 64 | 65 | test.after_completion(|db| { 66 | // Ensure index has been removed after the migration is complete 67 | let count: i64 = db 68 | .query( 69 | " 70 | SELECT COUNT(*) 71 | FROM pg_catalog.pg_index 72 | JOIN pg_catalog.pg_class ON pg_index.indexrelid = pg_class.oid 73 | WHERE pg_class.relname =
'name_idx' 74 | ", 75 | &[], 76 | ) 77 | .unwrap() 78 | .first() 79 | .map(|row| row.get(0)) 80 | .unwrap(); 81 | 82 | assert_eq!(0, count, "expected index to not exist"); 83 | }); 84 | 85 | test.run(); 86 | } 87 | -------------------------------------------------------------------------------- /tests/remove_table.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn remove_table() { 6 | let mut test = Test::new("Remove table"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_users_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | "#, 21 | ); 22 | 23 | test.second_migration( 24 | r#" 25 | name = "remove_users_table" 26 | 27 | [[actions]] 28 | type = "remove_table" 29 | table = "users" 30 | "#, 31 | ); 32 | 33 | test.intermediate(|old_db, new_db| { 34 | // Make sure inserts work against the old schema 35 | old_db 36 | .simple_query("INSERT INTO users(id) VALUES (1)") 37 | .unwrap(); 38 | 39 | // Ensure the table is not accessible through the new schema 40 | assert!(new_db.query("SELECT id FROM users", &[]).is_err()); 41 | }); 42 | 43 | test.run(); 44 | } 45 | -------------------------------------------------------------------------------- /tests/rename_table.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::Test; 3 | 4 | #[test] 5 | fn rename_table() { 6 | let mut test = Test::new("Rename table"); 7 | 8 | test.first_migration( 9 | r#" 10 | name = "create_users_table" 11 | 12 | [[actions]] 13 | type = "create_table" 14 | name = "users" 15 | primary_key = ["id"] 16 | 17 | [[actions.columns]] 18 | name = "id" 19 | type = "INTEGER" 20 | "#, 21 | ); 22 | 23 | test.second_migration( 24 | r#" 25 | name = "rename_users_table_to_customers" 26 | 27 | [[actions]] 28 | type = "rename_table" 29 | table = "users" 30 | new_name = "customers" 31 | "#, 32 | ); 33 | 34 | test.intermediate(|old_db, new_db| { 35 | // Make sure inserts work using both the old and new name 36 | old_db 37 | .simple_query("INSERT INTO users(id) VALUES (1)") 38 | .unwrap(); 39 | new_db 40 | .simple_query("INSERT INTO customers(id) VALUES (2)") 41 | .unwrap(); 42 | 43 | // Ensure the table can be queried using both the old and new name 44 | let expected: Vec<i32> = vec![1, 2]; 45 | assert_eq!( 46 | expected, 47 | old_db 48 | .query("SELECT id FROM users ORDER BY id", &[]) 49 | .unwrap() 50 | .iter() 51 | .map(|row| row.get::<_, i32>("id")) 52 | .collect::<Vec<i32>>() 53 | ); 54 | assert_eq!( 55 | expected, 56 | new_db 57 | .query("SELECT id FROM customers ORDER BY id", &[]) 58 | .unwrap() 59 | .iter() 60 | .map(|row| row.get::<_, i32>("id")) 61 | .collect::<Vec<i32>>() 62 | ); 63 | 64 | // Ensure the table can't be queried using the wrong name for the schema 65 | assert!(old_db.simple_query("SELECT id FROM customers").is_err()); 66 | assert!(new_db.simple_query("SELECT id FROM users").is_err()); 67 | }); 68 | 69 | test.run(); 70 | } 71 | --------------------------------------------------------------------------------