├── .gitignore
├── rustfmt.toml
├── Cargo.toml
├── convergence-arrow
│   ├── src
│   │   ├── lib.rs
│   │   ├── datafusion.rs
│   │   └── table.rs
│   ├── data
│   │   ├── generate.R
│   │   └── 100_4buckets.csv
│   ├── Cargo.toml
│   └── tests
│       ├── test_datafusion.rs
│       └── test_arrow.rs
├── .github
│   ├── dependabot.yml
│   └── workflows
│       └── test.yml
├── convergence
│   ├── src
│   │   ├── lib.rs
│   │   ├── engine.rs
│   │   ├── server.rs
│   │   ├── protocol_ext.rs
│   │   ├── connection.rs
│   │   └── protocol.rs
│   ├── Cargo.toml
│   └── tests
│       └── test_connection.rs
├── readme.md
└── licence

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 | Cargo.lock
--------------------------------------------------------------------------------
/rustfmt.toml:
--------------------------------------------------------------------------------
1 | hard_tabs = true
2 | max_width = 120
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [workspace]
2 | members = [
3 | 	"convergence",
4 | 	"convergence-arrow",
5 | ]
6 | 
--------------------------------------------------------------------------------
/convergence-arrow/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! Utils for bridging Apache Arrow and PostgreSQL's wire protocol.
2 | 
3 | #![warn(missing_docs)]
4 | 
5 | pub mod datafusion;
6 | pub mod table;
7 | 
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   - package-ecosystem: cargo
4 |     directory: /
5 |     schedule:
6 |       interval: daily
7 |   - package-ecosystem: github-actions
8 |     directory: /
9 |     schedule:
10 |       interval: daily
11 | 
--------------------------------------------------------------------------------
/convergence/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! Convergence is a crate for writing servers that speak PostgreSQL's wire protocol.
2 | 
3 | #![warn(missing_docs)]
4 | 
5 | pub mod connection;
6 | pub mod engine;
7 | pub mod protocol;
8 | pub mod protocol_ext;
9 | pub mod server;
10 | 
--------------------------------------------------------------------------------
/convergence-arrow/data/generate.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | 
3 | csv <- function(file, ...) {
4 | 	write.csv(data.frame(...), paste0("convergence-arrow/data/", file, ".csv"), row.names = F)
5 | }
6 | 
7 | csv("100_4buckets",
8 | 	id = seq(1, 100),
9 | 	bucket = c("a", "b", "c", "d")
10 | )
11 | 
--------------------------------------------------------------------------------
/convergence-arrow/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "convergence-arrow"
3 | version = "0.5.0"
4 | authors = ["Ruan Pearce-Authers"]
5 | edition = "2018"
6 | description = "Utils for bridging Apache Arrow and PostgreSQL's wire protocol"
7 | license = "MIT"
8 | repository = "https://github.com/returnString/convergence"
9 | 
10 | [dependencies]
11 | tokio = { version = "1" }
12 | sqlparser = "0.18"
13 | async-trait = "0.1"
14 | datafusion = "10"
15 | convergence = { path = "../convergence", version = "0.5.0" }
16 | chrono = "0.4"
17 | 
18 | [dev-dependencies]
19 | tokio-postgres = { version = "0.7", features = [ "with-chrono-0_4" ] }
20 | 
--------------------------------------------------------------------------------
/convergence/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "convergence"
3 | version = "0.5.0"
4 | authors = ["Ruan Pearce-Authers"]
5 | edition = "2018"
6 | description = "Write servers that speak PostgreSQL's wire protocol"
7 | license = "MIT"
8 | repository = "https://github.com/returnString/convergence"
9 | 
10 | [dependencies]
11 | tokio = { version = "1", features = [ "net", "rt-multi-thread", "macros", "io-util", "io-std" ] }
12 | tokio-util = { version = "0.7", features = [ "codec" ] }
13 | thiserror = "1"
14 | bytes = "1"
15 | futures = "0.3"
16 | sqlparser = "0.18"
17 | async-trait = "0.1"
18 | chrono = "0.4"
19 | 
20 | [dev-dependencies]
21 | tokio-postgres = "0.7"
22 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Convergence
2 | ![Tests](https://github.com/returnString/convergence/workflows/Test/badge.svg) ![Crates.io](https://img.shields.io/crates/v/convergence)
3 | 
4 | A set of tools for writing servers that speak PostgreSQL's wire protocol.
5 | 
6 | 🚧 This project is _extremely_ WIP at this stage.
7 | 
8 | ## Crates
9 | `convergence` contains the core traits, protocol handling and connection state machine for emulating a Postgres server.
10 | 
11 | `convergence-arrow` enables translation of [Apache Arrow](https://arrow.apache.org) dataframes into Postgres result sets, allowing you to access your Arrow-powered data services via standard Postgres drivers. It also provides a reusable `Engine` implementation using [DataFusion](https://github.com/apache/arrow-datafusion) for execution.
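
## Example

A minimal sketch of serving a DataFusion-backed table over the Postgres protocol, adapted from this repo's `convergence-arrow` tests — the table name and CSV path are placeholders, and your own binary crate would also need `tokio` with the `macros` and `rt-multi-thread` features:

```rust
use convergence::server::{self, BindOptions};
use convergence_arrow::datafusion::DataFusionEngine;
use datafusion::prelude::*;
use std::sync::Arc;

// each connection is allocated its own engine; register whatever tables you need here
async fn new_engine() -> DataFusionEngine {
	let ctx = SessionContext::new();
	ctx.register_csv("my_table", "path/to/my_table.csv", CsvReadOptions::new())
		.await
		.expect("failed to register csv");
	DataFusionEngine::new(ctx)
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
	// binds to 127.0.0.1:5432 by default; does not return unless the server terminates
	server::run(BindOptions::new(), Arc::new(|| Box::pin(new_engine()))).await
}
```

Any stock Postgres client should then be able to connect and query, e.g. `psql -h 127.0.0.1 -p 5432` followed by `select count(*) from my_table`.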
12 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 | 
3 | on:
4 |   - push
5 |   - pull_request
6 | 
7 | jobs:
8 |   test:
9 |     strategy:
10 |       fail-fast: false
11 |       matrix:
12 |         os:
13 |           - ubuntu-20.04
14 |         toolchain:
15 |           - 1.51.0
16 | 
17 |     runs-on: ${{ matrix.os }}
18 | 
19 |     steps:
20 |       - name: Checkout
21 |         uses: actions/checkout@v3.0.2
22 | 
23 |       - name: Install toolchain
24 |         uses: actions-rs/toolchain@v1
25 |         with:
26 |           toolchain: ${{ matrix.toolchain }}
27 | 
28 |       - name: Run tests
29 |         uses: actions-rs/cargo@v1.0.3
30 |         with:
31 |           command: test
32 | 
33 |       - name: Run clippy
34 |         uses: actions-rs/cargo@v1.0.3
35 |         with:
36 |           command: clippy
37 |           args: -- -D warnings
38 | 
--------------------------------------------------------------------------------
/licence:
--------------------------------------------------------------------------------
1 | MIT Licence
2 | 
3 | Copyright (c) 2021 Ruan Pearce-Authers
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 | 
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 | 
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 | 
--------------------------------------------------------------------------------
/convergence-arrow/data/100_4buckets.csv:
--------------------------------------------------------------------------------
1 | "id","bucket"
2 | 1,"a"
3 | 2,"b"
4 | 3,"c"
5 | 4,"d"
6 | 5,"a"
7 | 6,"b"
8 | 7,"c"
9 | 8,"d"
10 | 9,"a"
11 | 10,"b"
12 | 11,"c"
13 | 12,"d"
14 | 13,"a"
15 | 14,"b"
16 | 15,"c"
17 | 16,"d"
18 | 17,"a"
19 | 18,"b"
20 | 19,"c"
21 | 20,"d"
22 | 21,"a"
23 | 22,"b"
24 | 23,"c"
25 | 24,"d"
26 | 25,"a"
27 | 26,"b"
28 | 27,"c"
29 | 28,"d"
30 | 29,"a"
31 | 30,"b"
32 | 31,"c"
33 | 32,"d"
34 | 33,"a"
35 | 34,"b"
36 | 35,"c"
37 | 36,"d"
38 | 37,"a"
39 | 38,"b"
40 | 39,"c"
41 | 40,"d"
42 | 41,"a"
43 | 42,"b"
44 | 43,"c"
45 | 44,"d"
46 | 45,"a"
47 | 46,"b"
48 | 47,"c"
49 | 48,"d"
50 | 49,"a"
51 | 50,"b"
52 | 51,"c"
53 | 52,"d"
54 | 53,"a"
55 | 54,"b"
56 | 55,"c"
57 | 56,"d"
58 | 57,"a"
59 | 58,"b"
60 | 59,"c"
61 | 60,"d"
62 | 61,"a"
63 | 62,"b"
64 | 63,"c"
65 | 64,"d"
66 | 65,"a"
67 | 66,"b"
68 | 67,"c"
69 | 68,"d"
70 | 69,"a"
71 | 70,"b"
72 | 71,"c"
73 | 72,"d"
74 | 73,"a"
75 | 74,"b"
76 | 75,"c"
77 | 76,"d"
78 | 77,"a"
79 | 78,"b"
80 | 79,"c"
81 | 80,"d"
82 | 81,"a"
83 | 82,"b"
84 | 83,"c"
85 | 84,"d"
86 | 85,"a"
87 | 86,"b"
88 | 87,"c"
89 | 88,"d"
90 | 89,"a"
91 | 90,"b"
92 | 91,"c"
93 | 92,"d"
94 | 93,"a"
95 | 94,"b"
96 | 95,"c"
97 | 96,"d"
98 | 97,"a"
99 | 98,"b"
100 | 99,"c"
101 | 100,"d"
102 | 
--------------------------------------------------------------------------------
/convergence/src/engine.rs:
--------------------------------------------------------------------------------
1 | //! Contains core interface definitions for custom SQL engines.
2 | 
3 | use crate::protocol::{ErrorResponse, FieldDescription};
4 | use crate::protocol_ext::DataRowBatch;
5 | use async_trait::async_trait;
6 | use sqlparser::ast::Statement;
7 | 
8 | /// A Postgres portal. Portals represent a prepared statement with all parameters specified.
9 | ///
10 | /// See Postgres' protocol docs regarding the [extended query overview](https://www.postgresql.org/docs/current/protocol-overview.html#PROTOCOL-QUERY-CONCEPTS)
11 | /// for more details.
12 | #[async_trait]
13 | pub trait Portal: Send + Sync {
14 | 	/// Fetches the contents of the portal into a [DataRowBatch].
15 | 	async fn fetch(&mut self, batch: &mut DataRowBatch) -> Result<(), ErrorResponse>;
16 | }
17 | 
18 | /// The engine trait is the core of the `convergence` crate, and is responsible for dispatching most SQL operations.
19 | ///
20 | /// Each connection is allocated an [Engine] instance, which it uses to prepare statements, create portals, etc.
21 | #[async_trait]
22 | pub trait Engine: Send + Sync + 'static {
23 | 	/// The [Portal] implementation used by [Engine::create_portal].
24 | 	type PortalType: Portal;
25 | 
26 | 	/// Prepares a statement, returning a vector of field descriptions for the final statement result.
27 | 	async fn prepare(&mut self, stmt: &Statement) -> Result<Vec<FieldDescription>, ErrorResponse>;
28 | 
29 | 	/// Creates a new portal for the given statement.
30 | 	async fn create_portal(&mut self, stmt: &Statement) -> Result<Self::PortalType, ErrorResponse>;
31 | }
32 | 
--------------------------------------------------------------------------------
/convergence-arrow/tests/test_datafusion.rs:
--------------------------------------------------------------------------------
1 | use convergence::server::{self, BindOptions};
2 | use convergence_arrow::datafusion::DataFusionEngine;
3 | use datafusion::prelude::*;
4 | use std::sync::Arc;
5 | use tokio_postgres::{connect, NoTls};
6 | 
7 | async fn new_engine() -> DataFusionEngine {
8 | 	let ctx = SessionContext::new();
9 | 	ctx.register_csv("test_100_4buckets", "data/100_4buckets.csv", CsvReadOptions::new())
10 | 		.await
11 | 		.expect("failed to register csv");
12 | 
13 | 	DataFusionEngine::new(ctx)
14 | }
15 | 
16 | async fn setup() -> tokio_postgres::Client {
17 | 	let port = server::run_background(BindOptions::new().with_port(0), Arc::new(|| Box::pin(new_engine())))
18 | 		.await
19 | 		.unwrap();
20 | 
21 | 	let (client, conn) = connect(&format!("postgres://localhost:{}/test", port), NoTls)
22 | 		.await
23 | 		.expect("failed to init client");
24 | 
25 | 	tokio::spawn(async move { conn.await.unwrap() });
26 | 
27 | 	client
28 | }
29 | 
30 | #[tokio::test]
31 | async fn count_rows() {
32 | 	let client = setup().await;
33 | 
34 | 	let row = client
35 | 		.query_one("select count(*) from test_100_4buckets", &[])
36 | 		.await
37 | 		.unwrap();
38 | 
39 | 	let count: i64 = row.get(0);
40 | 	assert_eq!(count, 100);
41 | }
42 | 
43 | #[tokio::test]
44 | async fn grouped_counts() {
45 | 	let client = setup().await;
46 | 
47 | 	let rows = client
48 | 		.query(
49 | 			"select bucket, count(*) from test_100_4buckets group by bucket order by bucket",
50 | 			&[],
51 | 		)
52 | 		.await
53 | 		.unwrap();
54 | 
55 | 	assert_eq!(rows.len(), 4);
56 | 
57 | 	let get_row = |idx: usize| {
58 | 		let row = &rows[idx];
59 | 		let cols: (&str, i64) = (row.get(0), row.get(1));
60 | 		cols
61 | 	};
62 | 
63 | 	assert_eq!(get_row(0), ("a", 25));
64 | 	assert_eq!(get_row(1), ("b", 25));
65 | 	assert_eq!(get_row(2), ("c", 25));
66 | 	assert_eq!(get_row(3), ("d", 25));
67 | }
68 | 
--------------------------------------------------------------------------------
/convergence-arrow/src/datafusion.rs:
--------------------------------------------------------------------------------
1 | //! Provides a DataFusion-powered implementation of the [Engine] trait.
2 | 
3 | use crate::table::{record_batch_to_rows, schema_to_field_desc};
4 | use async_trait::async_trait;
5 | use convergence::engine::{Engine, Portal};
6 | use convergence::protocol::{ErrorResponse, FieldDescription, SqlState};
7 | use convergence::protocol_ext::DataRowBatch;
8 | use datafusion::error::DataFusionError;
9 | use datafusion::prelude::*;
10 | use sqlparser::ast::Statement;
11 | use std::sync::Arc;
12 | 
13 | fn df_err_to_sql(err: DataFusionError) -> ErrorResponse {
14 | 	ErrorResponse::error(SqlState::DATA_EXCEPTION, err.to_string())
15 | }
16 | 
17 | /// A portal built using a logical DataFusion query plan.
18 | pub struct DataFusionPortal {
19 | 	df: Arc<DataFrame>,
20 | }
21 | 
22 | #[async_trait]
23 | impl Portal for DataFusionPortal {
24 | 	async fn fetch(&mut self, batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
25 | 		for arrow_batch in self.df.collect().await.map_err(df_err_to_sql)? {
26 | 			record_batch_to_rows(&arrow_batch, batch)?;
27 | 		}
28 | 		Ok(())
29 | 	}
30 | }
31 | 
32 | /// An engine instance using DataFusion for catalogue management and queries.
33 | pub struct DataFusionEngine {
34 | 	ctx: SessionContext,
35 | }
36 | 
37 | impl DataFusionEngine {
38 | 	/// Creates a new engine instance using the given DataFusion execution context.
39 | 	pub fn new(ctx: SessionContext) -> Self {
40 | 		Self { ctx }
41 | 	}
42 | }
43 | 
44 | #[async_trait]
45 | impl Engine for DataFusionEngine {
46 | 	type PortalType = DataFusionPortal;
47 | 
48 | 	async fn prepare(&mut self, statement: &Statement) -> Result<Vec<FieldDescription>, ErrorResponse> {
49 | 		let plan = self.ctx.sql(&statement.to_string()).await.map_err(df_err_to_sql)?;
50 | 		schema_to_field_desc(&plan.schema().clone().into())
51 | 	}
52 | 
53 | 	async fn create_portal(&mut self, statement: &Statement) -> Result<Self::PortalType, ErrorResponse> {
54 | 		let df = self.ctx.sql(&statement.to_string()).await.map_err(df_err_to_sql)?;
55 | 		Ok(DataFusionPortal { df })
56 | 	}
57 | }
58 | 
--------------------------------------------------------------------------------
/convergence/src/server.rs:
--------------------------------------------------------------------------------
1 | //! Contains utility types and functions for starting and running servers.
2 | 
3 | use crate::connection::Connection;
4 | use crate::engine::Engine;
5 | use std::pin::Pin;
6 | use std::sync::Arc;
7 | use tokio::net::TcpListener;
8 | 
9 | /// Controls how servers bind to local network resources.
10 | #[derive(Default)]
11 | pub struct BindOptions {
12 | 	addr: String,
13 | 	port: u16,
14 | }
15 | 
16 | impl BindOptions {
17 | 	/// Creates a default set of options listening only on the loopback address
18 | 	/// using the default Postgres port of 5432.
19 | 	pub fn new() -> Self {
20 | 		Self {
21 | 			addr: "127.0.0.1".to_owned(),
22 | 			port: 5432,
23 | 		}
24 | 	}
25 | 
26 | 	/// Sets the port to be used.
27 | 	pub fn with_port(mut self, port: u16) -> Self {
28 | 		self.port = port;
29 | 		self
30 | 	}
31 | 
32 | 	/// Sets the address to be used.
33 | 	pub fn with_addr(mut self, addr: impl Into<String>) -> Self {
34 | 		self.addr = addr.into();
35 | 		self
36 | 	}
37 | 
38 | 	/// Configures the server to listen on all interfaces rather than any specific address.
39 | 	pub fn use_all_interfaces(self) -> Self {
40 | 		self.with_addr("0.0.0.0")
41 | 	}
42 | }
43 | 
44 | type EngineFunc<E> = Arc<dyn Fn() -> Pin<Box<dyn std::future::Future<Output = E> + Send>> + Send + Sync>;
45 | 
46 | async fn run_with_listener<E: Engine>(listener: TcpListener, engine_func: EngineFunc<E>) -> std::io::Result<()> {
47 | 	loop {
48 | 		let (stream, _) = listener.accept().await?;
49 | 		let engine_func = engine_func.clone();
50 | 		tokio::spawn(async move {
51 | 			let mut conn = Connection::new(engine_func().await);
52 | 			conn.run(stream).await.unwrap();
53 | 		});
54 | 	}
55 | }
56 | 
57 | /// Starts a server using a function responsible for producing engine instances and a set of bind options.
58 | ///
59 | /// Does not return unless the server terminates entirely.
60 | pub async fn run<E: Engine>(bind: BindOptions, engine_func: EngineFunc<E>) -> std::io::Result<()> {
61 | 	let listener = TcpListener::bind((bind.addr, bind.port)).await?;
62 | 	run_with_listener(listener, engine_func).await
63 | }
64 | 
65 | /// Starts a server using a function responsible for producing engine instances and a set of bind options.
66 | ///
67 | /// Spawns the accept loop as a background task and returns the listener's
68 | /// local port once the server is ready to accept connections.
69 | ///
70 | /// Useful for creating test harnesses binding to port 0 to select a random port.
71 | pub async fn run_background<E: Engine>(bind: BindOptions, engine_func: EngineFunc<E>) -> std::io::Result<u16> {
72 | 	let listener = TcpListener::bind((bind.addr, bind.port)).await?;
73 | 	let port = listener.local_addr()?.port();
74 | 
75 | 	tokio::spawn(async move { run_with_listener(listener, engine_func).await });
76 | 
77 | 	Ok(port)
78 | }
79 | 
--------------------------------------------------------------------------------
/convergence/tests/test_connection.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 | use convergence::engine::{Engine, Portal};
3 | use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
4 | use convergence::protocol_ext::DataRowBatch;
5 | use convergence::server::{self, BindOptions};
6 | use sqlparser::ast::{Expr, SelectItem, SetExpr, Statement};
7 | use std::sync::Arc;
8 | use tokio_postgres::{connect, NoTls, SimpleQueryMessage};
9 | 
10 | struct ReturnSingleScalarPortal;
11 | 
12 | #[async_trait]
13 | impl Portal for ReturnSingleScalarPortal {
14 | 	async fn fetch(&mut self, batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
15 | 		let mut row = batch.create_row();
16 | 		row.write_int4(1);
17 | 		Ok(())
18 | 	}
19 | }
20 | 
21 | struct ReturnSingleScalarEngine;
22 | 
23 | #[async_trait]
24 | impl Engine for ReturnSingleScalarEngine {
25 | 	type PortalType = ReturnSingleScalarPortal;
26 | 
27 | 	async fn prepare(&mut self, statement: &Statement) -> Result<Vec<FieldDescription>, ErrorResponse> {
28 | 		if let Statement::Query(query) = &statement {
29 | 			if let SetExpr::Select(select) = &query.body {
30 | 				if select.projection.len() == 1 {
31 | 					if let SelectItem::UnnamedExpr(Expr::Identifier(column_name)) = &select.projection[0] {
32 | 						match column_name.value.as_str() {
33 | 							"test_error" => return Err(ErrorResponse::error(SqlState::DATA_EXCEPTION, "test error")),
34 | 							"test_fatal" => return Err(ErrorResponse::fatal(SqlState::DATA_EXCEPTION, "fatal error")),
35 | 							_ => (),
36 | 						}
37 | 					}
38 | 				}
39 | 			}
40 | 		}
41 | 
42 | 		Ok(vec![FieldDescription {
43 | 			name: "test".to_owned(),
44 | 			data_type: DataTypeOid::Int4,
45 | 		}])
46 | 	}
47 | 
48 | 	async fn create_portal(&mut self, _: &Statement) -> Result<Self::PortalType, ErrorResponse> {
49 | 		Ok(ReturnSingleScalarPortal)
50 | 	}
51 | }
52 | 
53 | async fn setup() -> tokio_postgres::Client {
54 | 	let port = server::run_background(
55 | 		BindOptions::new().with_port(0),
56 | 		Arc::new(|| Box::pin(async { ReturnSingleScalarEngine })),
57 | 	)
58 | 	.await
59 | 	.unwrap();
60 | 
61 | 	let (client, conn) = connect(&format!("postgres://localhost:{}/test", port), NoTls)
62 | 		.await
63 | 		.expect("failed to init client");
64 | 
65 | 	tokio::spawn(async move { conn.await.unwrap() });
66 | 
67 | 	client
68 | }
69 | 
70 | #[tokio::test]
71 | async fn extended_query_flow() {
72 | 	let client = setup().await;
73 | 	let row = client.query_one("select 1", &[]).await.unwrap();
74 | 	let value: i32 = row.get(0);
75 | 	assert_eq!(value, 1);
76 | }
77 | 
78 | #[tokio::test]
79 | async fn simple_query_flow() {
80 | 	let client = setup().await;
81 | 	let messages = client.simple_query("select 1").await.unwrap();
82 | 	assert_eq!(messages.len(), 2);
83 | 
84 | 	let row = match &messages[0] {
85 | 		SimpleQueryMessage::Row(row) => row,
86 | 		_ => panic!("expected row"),
87 | 	};
88 | 
89 | 	assert_eq!(row.get(0), Some("1"));
90 | 
91 | 	let num_rows = match &messages[1] {
92 | 		SimpleQueryMessage::CommandComplete(rows) => *rows,
93 | 		_ => panic!("expected command complete"),
94 | 	};
95 | 
96 | 	assert_eq!(num_rows, 1);
97 | }
98 | 
99 | #[tokio::test]
100 | async fn error_handling() {
101 | 	let client = setup().await;
102 | 	let err = client
103 | 		.query_one("select test_error from blah", &[])
104 | 		.await
105 | 		.expect_err("expected error in query");
106 | 
107 | 	assert_eq!(err.code().unwrap().code(), SqlState::DATA_EXCEPTION.0);
108 | }
109 | 
110 | #[tokio::test]
111 | async fn set_variable_noop() {
112 | 	let client = setup().await;
113 | 	client
114 | 		.simple_query("set somevar to 'my_val'")
115 | 		.await
116 | 		.expect("failed to set var");
117 | }
118 | 
119 | #[tokio::test]
120 | async fn empty_simple_query() {
121 | 	let client = setup().await;
122 | 	client.simple_query("").await.unwrap();
123 | }
124 | 
125 | #[tokio::test]
126 | async fn empty_extended_query() {
127 | 	let client = setup().await;
128 | 	client.query("", &[]).await.unwrap();
129 | }
130 | 
--------------------------------------------------------------------------------
/convergence-arrow/tests/test_arrow.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 | use chrono::{NaiveDate, NaiveDateTime};
3 | use convergence::engine::{Engine, Portal};
4 | use convergence::protocol::{ErrorResponse, FieldDescription};
5 | use convergence::protocol_ext::DataRowBatch;
6 | use convergence::server::{self, BindOptions};
7 | use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
8 | use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, TimestampSecondArray};
9 | use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
10 | use datafusion::arrow::record_batch::RecordBatch;
11 | use sqlparser::ast::Statement;
12 | use std::sync::Arc;
13 | use tokio_postgres::{connect, NoTls};
14 | 
15 | struct ArrowPortal {
16 | 	batch: RecordBatch,
17 | }
18 | 
19 | #[async_trait]
20 | impl Portal for ArrowPortal {
21 | 	async fn fetch(&mut self, batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
22 | 		record_batch_to_rows(&self.batch, batch)
23 | 	}
24 | }
25 | 
26 | struct ArrowEngine {
27 | 	batch: RecordBatch,
28 | }
29 | 
30 | impl ArrowEngine {
31 | 	fn new() -> Self {
32 | 		let int_col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
33 | 		let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef;
34 | 		let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
35 | 		let ts_col = Arc::new(TimestampSecondArray::from_vec(
36 | 			vec![1577836800, 1580515200, 1583020800],
37 | 			None,
38 | 		)) as ArrayRef;
39 | 		let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef;
40 | 
41 | 		let schema = Schema::new(vec![
42 | 			Field::new("int_col", DataType::Int32, true),
43 | 			Field::new("float_col", DataType::Float32, true),
44 | 			Field::new("string_col", DataType::Utf8, true),
45 | 			Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true),
46 | 			Field::new("date_col", DataType::Date32, true),
47 | 		]);
48 | 
49 | 		Self {
50 | 			batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, ts_col, date_col])
51 | 				.expect("failed to create batch"),
52 | 		}
53 | 	}
54 | }
55 | 
56 | #[async_trait]
57 | impl Engine for ArrowEngine {
58 | 	type PortalType = ArrowPortal;
59 | 
60 | 	async fn prepare(&mut self, _: &Statement) -> Result<Vec<FieldDescription>, ErrorResponse> {
61 | 		schema_to_field_desc(&self.batch.schema())
62 | 	}
63 | 
64 | 	async fn create_portal(&mut self, _: &Statement) -> Result<Self::PortalType, ErrorResponse> {
65 | 		Ok(ArrowPortal {
66 | 			batch: self.batch.clone(),
67 | 		})
68 | 	}
69 | }
70 | 
71 | async fn setup() -> tokio_postgres::Client {
72 | 	let port = server::run_background(
73 | 		BindOptions::new().with_port(0),
74 | 		Arc::new(|| Box::pin(async { ArrowEngine::new() })),
75 | 	)
76 | 	.await
77 | 	.unwrap();
78 | 
79 | 	let (client, conn) = connect(&format!("postgres://localhost:{}/test", port), NoTls)
80 | 		.await
81 | 		.expect("failed to init client");
82 | 
83 | 	tokio::spawn(async move { conn.await.unwrap() });
84 | 
85 | 	client
86 | }
87 | 
88 | #[tokio::test]
89 | async fn basic_data_types() {
90 | 	let client = setup().await;
91 | 
92 | 	let rows = client.query("select 1", &[]).await.unwrap();
93 | 	let get_row = |idx: usize| {
94 | 		let row = &rows[idx];
95 | 		let cols: (i32, f32, &str, NaiveDateTime, NaiveDate) =
96 | 			(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
97 | 		cols
98 | 	};
99 | 
100 | 	assert_eq!(
101 | 		get_row(0),
102 | 		(
103 | 			1,
104 | 			1.5,
105 | 			"a",
106 | 			NaiveDate::from_ymd(2020, 1, 1).and_hms(0, 0, 0),
107 | 			NaiveDate::from_ymd(1970, 1, 1),
108 | 		)
109 | 	);
110 | 	assert_eq!(
111 | 		get_row(1),
112 | 		(
113 | 			2,
114 | 			2.5,
115 | 			"b",
116 | 			NaiveDate::from_ymd(2020, 2, 1).and_hms(0, 0, 0),
117 | 			NaiveDate::from_ymd(1970, 1, 2)
118 | 		)
119 | 	);
120 | 	assert_eq!(
121 | 		get_row(2),
122 | 		(
123 | 			3,
124 | 			3.5,
125 | 			"c",
126 | 			NaiveDate::from_ymd(2020, 3, 1).and_hms(0, 0, 0),
127 | 			NaiveDate::from_ymd(1970, 1, 3)
128 | 		)
129 | 	);
130 | }
131 | 
--------------------------------------------------------------------------------
/convergence/src/protocol_ext.rs:
--------------------------------------------------------------------------------
1 | //! Contains extensions that make working with the Postgres protocol simpler or more efficient.
2 | 
3 | use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
4 | use bytes::{BufMut, BytesMut};
5 | use chrono::{NaiveDate, NaiveDateTime};
6 | use tokio_util::codec::Encoder;
7 | 
8 | /// Supports batched rows for e.g. returning portal result sets.
9 | ///
10 | /// NB: this struct only performs limited validation of column consistency across rows.
11 | pub struct DataRowBatch {
12 | 	format_code: FormatCode,
13 | 	num_cols: usize,
14 | 	num_rows: usize,
15 | 	data: BytesMut,
16 | 	row: BytesMut,
17 | }
18 | 
19 | impl DataRowBatch {
20 | 	/// Creates a new row batch using the given format code, requiring a certain number of columns per row.
21 | 	pub fn new(format_code: FormatCode, num_cols: usize) -> Self {
22 | 		Self {
23 | 			format_code,
24 | 			num_cols,
25 | 			num_rows: 0,
26 | 			data: BytesMut::new(),
27 | 			row: BytesMut::new(),
28 | 		}
29 | 	}
30 | 
31 | 	/// Creates a [DataRowBatch] from the given [RowDescription].
32 | 	pub fn from_row_desc(desc: &RowDescription) -> Self {
33 | 		Self::new(desc.format_code, desc.fields.len())
34 | 	}
35 | 
36 | 	/// Starts writing a new row.
37 | 	///
38 | 	/// Returns a [DataRowWriter] that is responsible for the actual value encoding.
39 | 	pub fn create_row(&mut self) -> DataRowWriter {
40 | 		self.num_rows += 1;
41 | 		DataRowWriter::new(self)
42 | 	}
43 | 
44 | 	/// Returns the number of rows currently written to this batch.
45 | 	pub fn num_rows(&self) -> usize {
46 | 		self.num_rows
47 | 	}
48 | }
49 | 
50 | macro_rules! primitive_write {
51 | 	($name: ident, $type: ident) => {
52 | 		#[allow(missing_docs)]
53 | 		pub fn $name(&mut self, val: $type) {
54 | 			match self.parent.format_code {
55 | 				FormatCode::Text => self.write_value(&val.to_string().into_bytes()),
56 | 				FormatCode::Binary => self.write_value(&val.to_be_bytes()),
57 | 			};
58 | 		}
59 | 	};
60 | }
61 | 
62 | /// Temporarily leased from a [DataRowBatch] to encode a single row.
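/// The encoded row is flushed into the parent batch when the writer is dropped;
/// dropping a writer that wrote the wrong number of columns will panic.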
63 | pub struct DataRowWriter<'a> {
64 | 	current_col: usize,
65 | 	parent: &'a mut DataRowBatch,
66 | }
67 | 
68 | impl<'a> DataRowWriter<'a> {
69 | 	fn new(parent: &'a mut DataRowBatch) -> Self {
70 | 		parent.row.put_i16(parent.num_cols as i16);
71 | 		Self { current_col: 0, parent }
72 | 	}
73 | 
74 | 	fn write_value(&mut self, data: &[u8]) {
75 | 		self.current_col += 1;
76 | 		self.parent.row.put_i32(data.len() as i32);
77 | 		self.parent.row.put_slice(data);
78 | 	}
79 | 
80 | 	/// Writes a null value for the next column.
81 | 	pub fn write_null(&mut self) {
82 | 		self.current_col += 1;
83 | 		self.parent.row.put_i32(-1);
84 | 	}
85 | 
86 | 	/// Writes a string value for the next column.
87 | 	pub fn write_string(&mut self, val: &str) {
88 | 		self.write_value(val.as_bytes());
89 | 	}
90 | 
91 | 	/// Writes a bool value for the next column.
92 | 	pub fn write_bool(&mut self, val: bool) {
93 | 		match self.parent.format_code {
94 | 			FormatCode::Text => self.write_value(if val { "t" } else { "f" }.as_bytes()),
95 | 			FormatCode::Binary => {
96 | 				self.current_col += 1;
97 | 				self.parent.row.put_u8(val as u8);
98 | 			}
99 | 		};
100 | 	}
101 | 
102 | 	fn pg_date_epoch() -> NaiveDate {
103 | 		NaiveDate::from_ymd(2000, 1, 1)
104 | 	}
105 | 
106 | 	/// Writes a date value for the next column.
107 | 	pub fn write_date(&mut self, val: NaiveDate) {
108 | 		match self.parent.format_code {
109 | 			FormatCode::Binary => self.write_int4(val.signed_duration_since(Self::pg_date_epoch()).num_days() as i32),
110 | 			FormatCode::Text => self.write_string(&val.to_string()),
111 | 		}
112 | 	}
113 | 
114 | 	/// Writes a timestamp value for the next column.
115 | 	pub fn write_timestamp(&mut self, val: NaiveDateTime) {
116 | 		match self.parent.format_code {
117 | 			FormatCode::Binary => {
118 | 				self.write_int8(
119 | 					val.signed_duration_since(Self::pg_date_epoch().and_hms(0, 0, 0))
120 | 						.num_microseconds()
121 | 						.unwrap(),
122 | 				);
123 | 			}
124 | 			FormatCode::Text => self.write_string(&val.to_string()),
125 | 		}
126 | 	}
127 | 
128 | 	primitive_write!(write_int2, i16);
129 | 	primitive_write!(write_int4, i32);
130 | 	primitive_write!(write_int8, i64);
131 | 	primitive_write!(write_float4, f32);
132 | 	primitive_write!(write_float8, f64);
133 | }
134 | 
135 | impl<'a> Drop for DataRowWriter<'a> {
136 | 	fn drop(&mut self) {
137 | 		assert_eq!(
138 | 			self.parent.num_cols, self.current_col,
139 | 			"dropped a row writer with an invalid number of columns"
140 | 		);
141 | 
142 | 		self.parent.data.put_u8(b'D');
143 | 		self.parent.data.put_i32((self.parent.row.len() + 4) as i32);
144 | 		self.parent.data.extend(self.parent.row.split());
145 | 	}
146 | }
147 | 
148 | impl Encoder<DataRowBatch> for ConnectionCodec {
149 | 	type Error = ProtocolError;
150 | 
151 | 	fn encode(&mut self, item: DataRowBatch, dst: &mut BytesMut) -> Result<(), Self::Error> {
152 | 		dst.extend(item.data);
153 | 		Ok(())
154 | 	}
155 | }
156 | 
--------------------------------------------------------------------------------
/convergence-arrow/src/table.rs:
--------------------------------------------------------------------------------
1 | //! Utilities for converting between Arrow and Postgres formats.
2 | 
3 | use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
4 | use convergence::protocol_ext::DataRowBatch;
5 | use datafusion::arrow::array::{
6 | 	BooleanArray, Date32Array, Date64Array, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
7 | 	Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
8 | 	TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
9 | };
10 | use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
11 | use datafusion::arrow::record_batch::RecordBatch;
12 | 
13 | macro_rules! array_cast {
14 | 	($arrtype: ident, $arr: expr) => {
15 | 		$arr.as_any().downcast_ref::<$arrtype>().expect("array cast failed")
16 | 	};
17 | }
18 | 
19 | macro_rules! array_val {
20 | 	($arrtype: ident, $arr: expr, $idx: expr, $func: ident) => {
21 | 		array_cast!($arrtype, $arr).$func($idx)
22 | 	};
23 | 	($arrtype: ident, $arr: expr, $idx: expr) => {
24 | 		array_val!($arrtype, $arr, $idx, value)
25 | 	};
26 | }
27 | 
28 | /// Writes the contents of an Arrow [RecordBatch] into a Postgres [DataRowBatch].
29 | pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
30 | 	for row_idx in 0..arrow_batch.num_rows() {
31 | 		let mut row = pg_batch.create_row();
32 | 		for col_idx in 0..arrow_batch.num_columns() {
33 | 			let col = arrow_batch.column(col_idx);
34 | 			if col.is_null(row_idx) {
35 | 				row.write_null();
36 | 			} else {
37 | 				match col.data_type() {
38 | 					DataType::Boolean => row.write_bool(array_val!(BooleanArray, col, row_idx) as bool),
39 | 					DataType::Int8 => row.write_int2(array_val!(Int8Array, col, row_idx) as i16),
40 | 					DataType::Int16 => row.write_int2(array_val!(Int16Array, col, row_idx)),
41 | 					DataType::Int32 => row.write_int4(array_val!(Int32Array, col, row_idx)),
42 | 					DataType::Int64 => row.write_int8(array_val!(Int64Array, col, row_idx)),
43 | 					DataType::UInt8 => row.write_int2(array_val!(UInt8Array, col, row_idx) as i16),
44 | 					DataType::UInt16 => row.write_int2(array_val!(UInt16Array, col, row_idx) as i16),
45 | 					DataType::UInt32 => row.write_int4(array_val!(UInt32Array, col, row_idx) as i32),
46 | 					DataType::UInt64 => row.write_int8(array_val!(UInt64Array, col, row_idx) as i64),
47 | 					DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
48 | 					DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
49 | 					DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
50 | 					DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
51 | 					DataType::Date32 => {
52 | 						row.write_date(array_val!(Date32Array, col, row_idx, value_as_date).ok_or_else(|| {
53 | 							ErrorResponse::error(SqlState::INVALID_DATETIME_FORMAT, "unsupported date type")
54 | 						})?)
55 | 					}
56 | 					DataType::Date64 => {
57 | 						row.write_date(array_val!(Date64Array, col, row_idx, value_as_date).ok_or_else(|| {
58 | 							ErrorResponse::error(SqlState::INVALID_DATETIME_FORMAT, "unsupported date type")
59 | 						})?)
60 | 					}
61 | 					DataType::Timestamp(unit, None) => row.write_timestamp(
62 | 						match unit {
63 | 							TimeUnit::Second => array_val!(TimestampSecondArray, col, row_idx, value_as_datetime),
64 | 							TimeUnit::Millisecond => {
65 | 								array_val!(TimestampMillisecondArray, col, row_idx, value_as_datetime)
66 | 							}
67 | 							TimeUnit::Microsecond => {
68 | 								array_val!(TimestampMicrosecondArray, col, row_idx, value_as_datetime)
69 | 							}
70 | 							TimeUnit::Nanosecond => {
71 | 								array_val!(TimestampNanosecondArray, col, row_idx, value_as_datetime)
72 | 							}
73 | 						}
74 | 						.ok_or_else(|| {
75 | 							ErrorResponse::error(SqlState::INVALID_DATETIME_FORMAT, "unsupported timestamp type")
76 | 						})?,
77 | 					),
78 | 					other => {
79 | 						return Err(ErrorResponse::error(
80 | 							SqlState::FEATURE_NOT_SUPPORTED,
81 | 							format!("arrow to pg conversion not implemented for {}", other),
82 | 						))
83 | 					}
84 | 				};
85 | 			}
86 | 		}
87 | 	}
88 | 
89 | 	Ok(())
90 | }
91 | 
92 | /// Converts an Arrow [DataType] into a Postgres [DataTypeOid].
93 | pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
94 | 	Ok(match ty {
95 | 		DataType::Boolean => DataTypeOid::Bool,
96 | 		DataType::Int8 | DataType::Int16 => DataTypeOid::Int2,
97 | 		DataType::Int32 => DataTypeOid::Int4,
98 | 		DataType::Int64 => DataTypeOid::Int8,
99 | 		// TODO: need to figure out a sensible mapping for unsigned
100 | 		DataType::UInt8 | DataType::UInt16 => DataTypeOid::Int2,
101 | 		DataType::UInt32 => DataTypeOid::Int4,
102 | 		DataType::UInt64 => DataTypeOid::Int8,
103 | 		DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
104 | 		DataType::Float64 => DataTypeOid::Float8,
105 | 		DataType::Utf8 => DataTypeOid::Text,
106 | 		DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
107 | 		DataType::Timestamp(_, None) => DataTypeOid::Timestamp,
108 | 		other => {
109 | 			return Err(ErrorResponse::error(
110 | 				SqlState::FEATURE_NOT_SUPPORTED,
111 | 				format!("arrow to pg conversion not implemented for {}", other),
112 | 			))
113 | 		}
114 | 	})
115 | }
116 | 
117 | /// Converts an Arrow [Schema] into a vector of Postgres [FieldDescription] instances.
118 | pub fn schema_to_field_desc(schema: &Schema) -> Result<Vec<FieldDescription>, ErrorResponse> {
119 | 	schema
120 | 		.fields()
121 | 		.iter()
122 | 		.map(|f| {
123 | 			Ok(FieldDescription {
124 | 				name: f.name().clone(),
125 | 				data_type: data_type_to_oid(f.data_type())?,
126 | 			})
127 | 		})
128 | 		.collect()
129 | }
130 | 
--------------------------------------------------------------------------------
/convergence/src/connection.rs:
--------------------------------------------------------------------------------
1 | //! Contains the [Connection] struct, which represents an individual Postgres session, and related types.
2 | 
3 | use crate::engine::{Engine, Portal};
4 | use crate::protocol::*;
5 | use crate::protocol_ext::DataRowBatch;
6 | use futures::{SinkExt, StreamExt};
7 | use sqlparser::ast::Statement;
8 | use sqlparser::dialect::PostgreSqlDialect;
9 | use sqlparser::parser::Parser;
10 | use std::collections::HashMap;
11 | use tokio::io::{AsyncRead, AsyncWrite};
12 | use tokio_util::codec::Framed;
13 | 
14 | /// Describes an error that may or may not result in the termination of a connection.
15 | #[derive(thiserror::Error, Debug)]
16 | pub enum ConnectionError {
17 | 	/// A protocol error was encountered, e.g. an invalid message for a connection's current state.
18 | 	#[error("protocol error: {0}")]
19 | 	Protocol(#[from] ProtocolError),
20 | 	/// A Postgres error containing a SqlState code and message occurred.
21 | 	/// May result in connection termination depending on the severity.
22 | 	#[error("error response: {0}")]
23 | 	ErrorResponse(#[from] ErrorResponse),
24 | 	/// The connection was closed.
25 | 	/// This always implies connection termination.
26 | 	#[error("connection closed")]
27 | 	ConnectionClosed,
28 | }
29 | 
30 | #[derive(Debug)]
31 | enum ConnectionState {
32 | 	Startup,
33 | 	Idle,
34 | }
35 | 
36 | #[derive(Debug, Clone)]
37 | struct PreparedStatement {
38 | 	pub statement: Option<Statement>,
39 | 	pub fields: Vec<FieldDescription>,
40 | }
41 | 
42 | struct BoundPortal<E: Engine> {
43 | 	pub portal: E::PortalType,
44 | 	pub row_desc: RowDescription,
45 | }
46 | 
47 | /// Describes a connection using a specific engine.
48 | /// Contains connection state including prepared statements and portals.
49 | pub struct Connection<E: Engine> {
50 | 	engine: E,
51 | 	state: ConnectionState,
52 | 	statements: HashMap<String, PreparedStatement>,
53 | 	portals: HashMap<String, Option<BoundPortal<E>>>,
54 | }
55 | 
56 | impl<E: Engine> Connection<E> {
57 | 	/// Create a new connection from an engine instance.
58 | 	pub fn new(engine: E) -> Self {
59 | 		Self {
60 | 			state: ConnectionState::Startup,
61 | 			statements: HashMap::new(),
62 | 			portals: HashMap::new(),
63 | 			engine,
64 | 		}
65 | 	}
66 | 
67 | 	fn prepared_statement(&self, name: &str) -> Result<&PreparedStatement, ConnectionError> {
68 | 		Ok(self
69 | 			.statements
70 | 			.get(name)
71 | 			.ok_or_else(|| ErrorResponse::error(SqlState::INVALID_SQL_STATEMENT_NAME, "missing statement"))?)
72 | 	}
73 | 
74 | 	fn portal(&self, name: &str) -> Result<&Option<BoundPortal<E>>, ConnectionError> {
75 | 		Ok(self
76 | 			.portals
77 | 			.get(name)
78 | 			.ok_or_else(|| ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, "missing portal"))?)
79 | 	}
80 | 
81 | 	fn portal_mut(&mut self, name: &str) -> Result<&mut Option<BoundPortal<E>>, ConnectionError> {
82 | 		Ok(self
83 | 			.portals
84 | 			.get_mut(name)
85 | 			.ok_or_else(|| ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, "missing portal"))?)
86 | 	}
87 | 
88 | 	fn parse_statement(&mut self, text: &str) -> Result<Option<Statement>, ErrorResponse> {
89 | 		let statements = Parser::parse_sql(&PostgreSqlDialect {}, text)
90 | 			.map_err(|err| ErrorResponse::error(SqlState::SYNTAX_ERROR, err.to_string()))?;
91 | 
92 | 		match statements.len() {
93 | 			0 => Ok(None),
94 | 			1 => Ok(Some(statements[0].clone())),
95 | 			_ => Err(ErrorResponse::error(
96 | 				SqlState::SYNTAX_ERROR,
97 | 				"expected zero or one statements",
98 | 			)),
99 | 		}
100 | 	}
101 | 
102 | 	async fn step(
103 | 		&mut self,
104 | 		framed: &mut Framed<impl AsyncRead + AsyncWrite + Unpin, ConnectionCodec>,
105 | 	) -> Result<Option<ConnectionState>, ConnectionError> {
106 | 		match self.state {
107 | 			ConnectionState::Startup => {
108 | 				match framed.next().await.ok_or(ConnectionError::ConnectionClosed)?? {
109 | 					ClientMessage::Startup(_startup) => {
110 | 						// do startup stuff
111 | 					}
112 | 					ClientMessage::SSLRequest => {
113 | 						// we don't support SSL for now
114 | 						// client will retry with startup packet
115 | 						framed.send('N').await?;
116 | 						return Ok(Some(ConnectionState::Startup));
117 | 					}
118 | 					_ => {
119 | 						return Err(
120 | 							ErrorResponse::fatal(SqlState::PROTOCOL_VIOLATION, "expected startup message").into(),
121 | 						)
122 | 					}
123 | 				}
124 | 
125 | 				framed.send(AuthenticationOk).await?;
126 | 
127 | 				let param_statuses = &[
128 | 					("server_version", "13"),
129 | 					("server_encoding", "UTF8"),
130 | 					("client_encoding", "UTF8"),
131 | 					("DateStyle", "ISO"),
132 | 					("TimeZone", "UTC"),
133 | 					("integer_datetimes", "on"),
134 | 				];
135 | 
136 | 				for &(param, status) in param_statuses {
137 | 					framed.send(ParameterStatus::new(param, status)).await?;
138 | 				}
139 | 
140 | 				framed.send(ReadyForQuery).await?;
141 | 				Ok(Some(ConnectionState::Idle))
142 | 			}
143 | 			ConnectionState::Idle => {
144 | 				match framed.next().await.ok_or(ConnectionError::ConnectionClosed)?? {
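					// extended-query messages (Parse/Bind/Describe/Execute/Sync) and simple-query messages (Query) both land here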
145 | 					ClientMessage::Parse(parse) => {
146 | 						let parsed_statement = self.parse_statement(&parse.query)?;
147 | 
148 | 						self.statements.insert(
149 | 							parse.prepared_statement_name,
150 | 							PreparedStatement {
151 | 								fields: match &parsed_statement {
152 | 									Some(statement) => self.engine.prepare(statement).await?,
153 | 									None => vec![],
154 | 								},
155 | 								statement: parsed_statement,
156 | 							},
157 | 						);
158 | 						framed.send(ParseComplete).await?;
159 | 					}
160 | 					ClientMessage::Bind(bind) => {
161 | 						let format_code = match bind.result_format {
162 | 							BindFormat::All(format) => format,
163 | 							BindFormat::PerColumn(_) => {
164 | 								return Err(ErrorResponse::error(
165 | 									SqlState::FEATURE_NOT_SUPPORTED,
166 | 									"per-column format codes not supported",
167 | 								)
168 | 								.into());
169 | 							}
170 | 						};
171 | 
172 | 						let prepared = self.prepared_statement(&bind.prepared_statement_name)?.clone();
173 | 						let portal = match prepared.statement {
174 | 							Some(statement) => {
175 | 								let portal = self.engine.create_portal(&statement).await?;
176 | 								let row_desc = RowDescription {
177 | 									fields: prepared.fields.clone(),
178 | 									format_code,
179 | 								};
180 | 
181 | 								Some(BoundPortal { portal, row_desc })
182 | 							}
183 | 							None => None,
184 | 						};
185 | 
186 | 						self.portals.insert(bind.portal, portal);
187 | 
188 | 						framed.send(BindComplete).await?;
189 | 					}
190 | 					ClientMessage::Describe(Describe::PreparedStatement(ref statement_name)) => {
191 | 						let fields = self.prepared_statement(statement_name)?.fields.clone();
192 | 						framed.send(ParameterDescription {}).await?;
193 | 						framed
194 | 							.send(RowDescription {
195 | 								fields,
196 | 								format_code: FormatCode::Text,
197 | 							})
198 | 							.await?;
199 | 					}
200 | 					ClientMessage::Describe(Describe::Portal(ref portal_name)) => match self.portal(portal_name)? {
201 | 						Some(portal) => framed.send(portal.row_desc.clone()).await?,
202 | 						None => framed.send(NoData).await?,
203 | 					},
204 | 					ClientMessage::Sync => {
205 | 						framed.send(ReadyForQuery).await?;
206 | 					}
207 | 					ClientMessage::Execute(exec) => match self.portal_mut(&exec.portal)? {
208 | 						Some(bound) => {
209 | 							let mut batch_writer = DataRowBatch::from_row_desc(&bound.row_desc);
210 | 							bound.portal.fetch(&mut batch_writer).await?;
211 | 							let num_rows = batch_writer.num_rows();
212 | 
213 | 							framed.send(batch_writer).await?;
214 | 
215 | 							framed
216 | 								.send(CommandComplete {
217 | 									command_tag: format!("SELECT {}", num_rows),
218 | 								})
219 | 								.await?;
220 | 						}
221 | 						None => {
222 | 							framed.send(EmptyQueryResponse).await?;
223 | 						}
224 | 					},
225 | 					ClientMessage::Query(query) => {
226 | 						if let Some(parsed) = self.parse_statement(&query)? {
227 | 							let fields = self.engine.prepare(&parsed).await?;
228 | 							let row_desc = RowDescription {
229 | 								fields,
230 | 								format_code: FormatCode::Text,
231 | 							};
232 | 							let mut portal = self.engine.create_portal(&parsed).await?;
233 | 
234 | 							let mut batch_writer = DataRowBatch::from_row_desc(&row_desc);
235 | 							portal.fetch(&mut batch_writer).await?;
236 | 							let num_rows = batch_writer.num_rows();
237 | 
238 | 							framed.send(row_desc).await?;
239 | 							framed.send(batch_writer).await?;
240 | 
241 | 							framed
242 | 								.send(CommandComplete {
243 | 									command_tag: format!("SELECT {}", num_rows),
244 | 								})
245 | 								.await?;
246 | 						} else {
247 | 							framed.send(EmptyQueryResponse).await?;
248 | 						}
249 | 						framed.send(ReadyForQuery).await?;
250 | 					}
251 | 					ClientMessage::Terminate => return Ok(None),
252 | 					_ => return Err(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, "unexpected message").into()),
253 | 				};
254 | 
255 | 				Ok(Some(ConnectionState::Idle))
256 | 			}
257 | 		}
258 | 	}
259 | 
260 | 	/// Given a stream (typically TCP), extract Postgres protocol messages and respond accordingly.
261 | 	/// This function only returns when the connection is closed (either gracefully or due to an error).
262 | 	pub async fn run(&mut self, stream: impl AsyncRead + AsyncWrite + Unpin) -> Result<(), ConnectionError> {
263 | 		let mut framed = Framed::new(stream, ConnectionCodec::new());
264 | 		loop {
265 | 			let new_state = match self.step(&mut framed).await {
266 | 				Ok(Some(state)) => state,
267 | 				Ok(None) => return Ok(()),
268 | 				Err(ConnectionError::ErrorResponse(err_info)) => {
269 | 					framed.send(err_info.clone()).await?;
270 | 
271 | 					if err_info.severity == Severity::FATAL {
272 | 						return Err(err_info.into());
273 | 					}
274 | 
275 | 					framed.send(ReadyForQuery).await?;
276 | 					ConnectionState::Idle
277 | 				}
278 | 				Err(err) => {
279 | 					framed
280 | 						.send(ErrorResponse::fatal(SqlState::CONNECTION_EXCEPTION, "connection error"))
281 | 						.await?;
282 | 					return Err(err);
283 | 				}
284 | 			};
285 | 
286 | 			self.state = new_state;
287 | 		}
288 | 	}
289 | }
290 | 
--------------------------------------------------------------------------------
/convergence/src/protocol.rs:
--------------------------------------------------------------------------------
1 | //! Contains types that represent the core Postgres wire protocol.
2 | 
3 | // this module requires a lot more work to document
4 | // may want to build this automatically from Postgres docs if possible
5 | #![allow(missing_docs)]
6 | 
7 | use bytes::{Buf, BufMut, BytesMut};
8 | use std::convert::TryFrom;
9 | use std::fmt::Display;
10 | use std::mem::size_of;
11 | use std::{collections::HashMap, convert::TryInto};
12 | use tokio_util::codec::{Decoder, Encoder};
13 | 
14 | macro_rules! data_types {
15 | 	($($name:ident = $oid:expr, $size: expr)*) => {
16 | 		#[derive(Debug, Copy, Clone)]
17 | 		/// Describes a Postgres data type.
18 | 		pub enum DataTypeOid {
19 | 			$(
20 | 				#[allow(missing_docs)]
21 | 				$name,
22 | 			)*
23 | 			/// A type which is not known to this crate.
24 | 			Unknown(u32),
25 | 		}
26 | 
27 | 		impl DataTypeOid {
28 | 			/// Fetch the size in bytes for this data type.
29 | 			/// Variably-sized types return -1.
30 | 			pub fn size_bytes(&self) -> i16 {
31 | 				match self {
32 | 					$(
33 | 						Self::$name => $size,
34 | 					)*
35 | 					Self::Unknown(_) => unimplemented!(),
36 | 				}
37 | 			}
38 | 		}
39 | 
40 | 		impl From<u32> for DataTypeOid {
41 | 			fn from(value: u32) -> Self {
42 | 				match value {
43 | 					$(
44 | 						$oid => Self::$name,
45 | 					)*
46 | 					other => Self::Unknown(other),
47 | 				}
48 | 			}
49 | 		}
50 | 
51 | 		impl From<DataTypeOid> for u32 {
52 | 			fn from(value: DataTypeOid) -> Self {
53 | 				match value {
54 | 					$(
55 | 						DataTypeOid::$name => $oid,
56 | 					)*
57 | 					DataTypeOid::Unknown(other) => other,
58 | 				}
59 | 			}
60 | 		}
61 | 	};
62 | }
63 | 
64 | // For oid see:
65 | // https://github.com/sfackler/rust-postgres/blob/master/postgres-types/src/type_gen.rs
66 | data_types! {
67 | 	Unspecified = 0, 0
68 | 
69 | 	Bool = 16, 1
70 | 
71 | 	Int2 = 21, 2
72 | 	Int4 = 23, 4
73 | 	Int8 = 20, 8
74 | 
75 | 	Float4 = 700, 4
76 | 	Float8 = 701, 8
77 | 
78 | 	Date = 1082, 4
79 | 	Timestamp = 1114, 8
80 | 
81 | 	Text = 25, -1
82 | }
83 | 
84 | /// Describes how to format a given value or set of values.
85 | #[derive(Debug, Copy, Clone)]
86 | pub enum FormatCode {
87 | 	/// Use the stable text representation.
88 | 	Text = 0,
89 | 	/// Use the less-stable binary representation.
90 | 	Binary = 1,
91 | }
92 | 
93 | impl TryFrom<i16> for FormatCode {
94 | 	type Error = ProtocolError;
95 | 
96 | 	fn try_from(value: i16) -> Result<Self, Self::Error> {
97 | 		match value {
98 | 			0 => Ok(FormatCode::Text),
99 | 			1 => Ok(FormatCode::Binary),
100 | 			other => Err(ProtocolError::InvalidFormatCode(other)),
101 | 		}
102 | 	}
103 | }
104 | 
105 | #[derive(Debug)]
106 | pub struct Startup {
107 | 	pub requested_protocol_version: (i16, i16),
108 | 	pub parameters: HashMap<String, String>,
109 | }
110 | 
111 | #[derive(Debug)]
112 | pub enum Describe {
113 | 	Portal(String),
114 | 	PreparedStatement(String),
115 | }
116 | 
117 | #[derive(Debug)]
118 | pub struct Parse {
119 | 	pub prepared_statement_name: String,
120 | 	pub query: String,
121 | 	pub parameter_types: Vec<DataTypeOid>,
122 | }
123 | 
124 | #[derive(Debug)]
125 | pub enum BindFormat {
126 | 	All(FormatCode),
127 | 	PerColumn(Vec<FormatCode>),
128 | }
129 | 
130 | #[derive(Debug)]
131 | pub struct Bind {
132 | 	pub portal: String,
133 | 	pub prepared_statement_name: String,
134 | 	pub result_format: BindFormat,
135 | }
136 | 
137 | #[derive(Debug)]
138 | pub struct Execute {
139 | 	pub portal: String,
140 | 	pub max_rows: Option<i32>,
141 | }
142 | 
143 | #[derive(Debug)]
144 | pub enum ClientMessage {
145 | 	SSLRequest, // for SSL negotiation
146 | 	Startup(Startup),
147 | 	Parse(Parse),
148 | 	Describe(Describe),
149 | 	Bind(Bind),
150 | 	Sync,
151 | 	Execute(Execute),
152 | 	Query(String),
153 | 	Terminate,
154 | }
155 | 
156 | pub trait BackendMessage: std::fmt::Debug {
157 | 	const TAG: u8;
158 | 
159 | 	fn encode(&self, dst: &mut BytesMut);
160 | }
161 | 
162 | #[derive(Debug, Clone, PartialEq, Eq)]
163 | pub struct SqlState(pub &'static str);
164 | 
165 | impl SqlState {
166 | 	pub const SUCCESSFUL_COMPLETION: SqlState = SqlState("00000");
167 | 	pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState("0A000");
168 | 	pub const INVALID_CURSOR_NAME: SqlState = SqlState("34000");
169 | 	pub const CONNECTION_EXCEPTION: SqlState = SqlState("08000");
170 | 	pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState("26000");
171 | 	pub const DATA_EXCEPTION: SqlState = SqlState("22000");
172 | 	pub const PROTOCOL_VIOLATION: SqlState = SqlState("08P01");
173 | 	pub const SYNTAX_ERROR: SqlState = SqlState("42601");
174 | 	pub const INVALID_DATETIME_FORMAT: SqlState = SqlState("22007");
175 | }
176 | 
177 | #[derive(Debug, Clone, PartialEq, Eq)]
178 | pub struct Severity(pub &'static str);
179 | 
180 | impl Severity {
181 | 	pub const ERROR: Severity = Severity("ERROR");
182 | 	pub const FATAL: Severity = Severity("FATAL");
183 | }
184 | 
185 | #[derive(thiserror::Error, Debug, Clone)]
186 | pub struct ErrorResponse {
187 | 	pub sql_state: SqlState,
188 | 	pub severity: Severity,
189 | 	pub message: String,
190 | }
191 | 
192 | impl ErrorResponse {
193 | 	pub fn new(sql_state: SqlState, severity: Severity, message: impl Into<String>) -> Self {
194 | 		ErrorResponse {
195 | 			sql_state,
196 | 			severity,
197 | 			message: message.into(),
198 | 		}
199 | 	}
200 | 
201 | 	pub fn error(sql_state: SqlState, message: impl Into<String>) -> Self {
202 | 		Self::new(sql_state, Severity::ERROR, message)
203 | 	}
204 | 
205 | 	pub fn fatal(sql_state: SqlState, message: impl Into<String>) -> Self {
206 | 		Self::new(sql_state, Severity::FATAL, message)
207 | 	}
208 | }
209 | 
210 | impl Display for ErrorResponse {
211 | 	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 | 		write!(f, "error")
213 | 	}
214 | }
215 | 
216 | impl BackendMessage for ErrorResponse {
217 | 	const TAG: u8 = b'E';
218 | 
219 | 	fn encode(&self, dst: &mut BytesMut) {
220 | 		dst.put_u8(b'C');
221 | 		dst.put_slice(self.sql_state.0.as_bytes());
222 | 		dst.put_u8(0);
223 | 		dst.put_u8(b'S');
224 | 		dst.put_slice(self.severity.0.as_bytes());
225 | 		dst.put_u8(0);
226 | 		dst.put_u8(b'M');
227 | 		dst.put_slice(self.message.as_bytes());
228 | 		dst.put_u8(0);
229 | 
230 | 		dst.put_u8(0); // tag
231 | 	}
232 | }
233 | 
234 | #[derive(Debug)]
235 | pub struct ParameterDescription {}
236 | 
237 | impl BackendMessage for ParameterDescription {
238 | 	const TAG: u8 = b't';
239 | 
240 | 	fn encode(&self, dst: &mut BytesMut) {
241 | 		dst.put_i16(0);
242 | 	}
243 | }
244 | 
245 | #[derive(Debug, Clone)]
246 | pub struct FieldDescription {
247 | 	pub name: String,
248 | 	pub data_type: DataTypeOid,
249 | }
250 | 
251 | #[derive(Debug, Clone)]
252 | pub struct RowDescription {
253 | 	pub fields: Vec<FieldDescription>,
254 | 	pub format_code: FormatCode,
255 | }
256 | 
257 | impl BackendMessage for RowDescription {
258 | 	const TAG: u8 = b'T';
259 | 
260 | 	fn encode(&self, dst: &mut BytesMut) {
261 | 		dst.put_i16(self.fields.len() as i16);
262 | 		for field in &self.fields {
263 | 			dst.put_slice(field.name.as_bytes());
264 | 			dst.put_u8(0);
265 | 			dst.put_i32(0); // table oid
266 | 			dst.put_i16(0); // column attr number
267 | 			dst.put_u32(field.data_type.into());
268 | 			dst.put_i16(field.data_type.size_bytes());
269 | 			dst.put_i32(-1); // data type modifier
270 | 			dst.put_i16(self.format_code as i16);
271 | 		}
272 | 	}
273 | }
274 | 
275 | #[derive(Debug)]
276 | pub struct AuthenticationOk;
277 | 
278 | impl BackendMessage for AuthenticationOk {
279 | 	const TAG: u8 = b'R';
280 | 
281 | 	fn encode(&self, dst: &mut BytesMut) {
282 | 		dst.put_i32(0);
283 | 	}
284 | }
285 | 
286 | #[derive(Debug)]
287 | pub struct ReadyForQuery;
288 | 
289 | impl BackendMessage for ReadyForQuery {
290 | 	const TAG: u8 = b'Z';
291 | 
292 | 	fn encode(&self, dst: &mut BytesMut) {
293 | 		dst.put_u8(b'I');
294 | 	}
295 | }
296 | 
297 | #[derive(Debug)]
298 | pub struct ParseComplete;
299 | 
300 | impl BackendMessage for ParseComplete {
301 | 	const TAG: u8 = b'1';
302 | 
303 | 	fn encode(&self, _dst: &mut BytesMut) {}
304 | }
305 | 
306 | #[derive(Debug)]
307 | pub struct BindComplete;
308 | 
309 | impl BackendMessage for BindComplete {
310 | 	const TAG: u8 = b'2';
311 | 
312 | 	fn encode(&self, _dst: &mut BytesMut) {}
313 | }
314 | 
315 | #[derive(Debug)]
316 | pub struct NoData;
317 | 
318 | impl BackendMessage for NoData {
319 | 	const TAG: u8 = b'n';
320 | 
321 | 	fn encode(&self, _dst: &mut BytesMut) {}
322 | }
323 | 
324 | #[derive(Debug)]
325 | pub struct EmptyQueryResponse;
326 | 
327 | impl BackendMessage for EmptyQueryResponse {
328 | 	const TAG: u8 = b'I';
329 | 
330 | 	fn encode(&self, _dst: &mut BytesMut) {}
331 | }
332 | 
333 | #[derive(Debug)]
334 | pub struct CommandComplete {
335 | 	pub command_tag: String,
336 | }
337 | 
338 | impl BackendMessage for CommandComplete {
339 | 	const TAG: u8 = b'C';
340 | 
341 | 	fn encode(&self, dst: &mut BytesMut) {
342 | 		dst.put_slice(self.command_tag.as_bytes());
343 | 		dst.put_u8(0);
344 | 	}
345 | }
346 | 
347 | #[derive(Debug)]
348 | pub struct ParameterStatus {
349 | 	name: String,
350 | 	value: String,
351 | }
352 | 
353 | impl BackendMessage for ParameterStatus {
354 | 	const TAG: u8 = b'S';
355 | 
356 | 	fn encode(&self, dst: &mut BytesMut) {
357 | 		dst.put_slice(self.name.as_bytes());
358 | 		dst.put_u8(0);
359 | 		dst.put_slice(self.value.as_bytes());
360 | 		dst.put_u8(0);
361 | 	}
362 | }
363 | 
364 | impl ParameterStatus {
365 | 	pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
366 | 		Self {
367 | 			name: name.into(),
368 | 			value: value.into(),
369 | 		}
370 | 	}
371 | }
372 | 
373 | #[derive(Default, Debug)]
374 | pub struct ConnectionCodec {
375 | 	// most state tracking is handled at a higher level
376 | 	// however, the actual wire format uses a different header for startup vs normal messages
377 | 	// so we need to be able to differentiate inside the decoder
378 | 	startup_received: bool,
379 | }
380 | 
381 | impl ConnectionCodec {
382 | 	pub fn new() -> Self {
383 | 		Self {
384 | 			startup_received: false,
385 | 		}
386 | 	}
387 | }
388 | 
389 | #[derive(thiserror::Error, Debug)]
390 | pub enum ProtocolError {
391 | 	#[error("io error: {0}")]
392 | 	Io(#[from] std::io::Error),
393 | 	#[error("utf8 error: {0}")]
394 | 	Utf8(#[from] std::string::FromUtf8Error),
395 | 	#[error("parsing error")]
396 | 	ParserError,
397 | 	#[error("invalid message type: {0}")]
398 | 	InvalidMessageType(u8),
399 | 	#[error("invalid format code: {0}")]
400 | 	InvalidFormatCode(i16),
401 | }
402 | 
403 | // length prefix, two version components
404 | const STARTUP_HEADER_SIZE: usize = size_of::<i32>() + (size_of::<i16>() * 2);
405 | // message tag, length prefix
406 | const MESSAGE_HEADER_SIZE: usize = size_of::<u8>() + size_of::<i32>();
407 | 
408 | impl Decoder for ConnectionCodec {
409 | 	type Item = ClientMessage;
410 | 	type Error = ProtocolError;
411 | 
412 | 	fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
413 | 		if !self.startup_received {
414 | 			if src.len() < STARTUP_HEADER_SIZE {
415 | 				return Ok(None);
416 | 			}
417 | 
418 | 			let mut header_buf = src.clone();
419 | 			let message_len = header_buf.get_i32() as usize;
420 | 			let protocol_version_major = header_buf.get_i16();
421 | 			let protocol_version_minor = header_buf.get_i16();
422 | 
423 | 			if protocol_version_major == 1234i16 && protocol_version_minor == 5679i16 {
424 | 				src.advance(STARTUP_HEADER_SIZE);
425 | 				return Ok(Some(ClientMessage::SSLRequest));
426 | 			}
427 | 
428 | 			if src.len() < message_len {
429 | 				src.reserve(message_len - src.len());
430 | 				return Ok(None);
431 | 			}
432 | 
433 | 			src.advance(STARTUP_HEADER_SIZE);
434 | 
435 | 			let mut parameters = HashMap::new();
436 | 
437 | 			let mut param_str_start_pos = 0;
438 | 			let mut current_key = None;
439 | 			for (i, &byte) in src.iter().enumerate() {
440 | 				if byte == 0 {
441 | 					let string_value = String::from_utf8(src[param_str_start_pos..i].to_owned())?;
442 | 					param_str_start_pos = i + 1;
443 | 
444 | 					current_key = match current_key {
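						// startup parameters are alternating null-terminated key/value strings;
						// a held key means this string completes a pair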
445 | 						Some(key) => {
446 | 							parameters.insert(key, string_value);
447 | 							None
448 | 						}
449 | 						None => Some(string_value),
450 | 					}
451 | 				}
452 | 			}
453 | 
454 | 			src.advance(message_len - STARTUP_HEADER_SIZE);
455 | 
456 | 			self.startup_received = true;
457 | 			return Ok(Some(ClientMessage::Startup(Startup {
458 | 				requested_protocol_version: (protocol_version_major, protocol_version_minor),
459 | 				parameters,
460 | 			})));
461 | 		}
462 | 
463 | 		if src.len() < MESSAGE_HEADER_SIZE {
464 | 			src.reserve(MESSAGE_HEADER_SIZE);
465 | 			return Ok(None);
466 | 		}
467 | 
468 | 		let mut header_buf = src.clone();
469 | 		let message_tag = header_buf.get_u8();
470 | 		let message_len = header_buf.get_i32() as usize;
471 | 
472 | 		if src.len() < message_len {
473 | 			src.reserve(message_len - src.len());
474 | 			return Ok(None);
475 | 		}
476 | 
477 | 		src.advance(MESSAGE_HEADER_SIZE);
478 | 
479 | 		let read_cstr = |src: &mut BytesMut| -> Result<String, ProtocolError> {
480 | 			let next_null = src.iter().position(|&b| b == 0).ok_or(ProtocolError::ParserError)?;
481 | 			let bytes = src[..next_null].to_owned();
482 | 			src.advance(bytes.len() + 1);
483 | 			Ok(String::from_utf8(bytes)?)
484 | 		};
485 | 
486 | 		let message = match message_tag {
487 | 			b'P' => {
488 | 				let prepared_statement_name = read_cstr(src)?;
489 | 				let query = read_cstr(src)?;
490 | 				let num_params = src.get_i16();
491 | 				let _params: Vec<_> = (0..num_params).into_iter().map(|_| src.get_u32()).collect();
492 | 
493 | 				ClientMessage::Parse(Parse {
494 | 					prepared_statement_name,
495 | 					query,
496 | 					parameter_types: Vec::new(),
497 | 				})
498 | 			}
499 | 			b'D' => {
500 | 				let target_type = src.get_u8();
501 | 				let name = read_cstr(src)?;
502 | 
503 | 				ClientMessage::Describe(match target_type {
504 | 					b'P' => Describe::Portal(name),
505 | 					b'S' => Describe::PreparedStatement(name),
506 | 					_ => return Err(ProtocolError::ParserError),
507 | 				})
508 | 			}
509 | 			b'S' => ClientMessage::Sync,
510 | 			b'B' => {
511 | 				let portal = read_cstr(src)?;
512 | 				let prepared_statement_name = read_cstr(src)?;
513 | 
514 | 				let num_param_format_codes = src.get_i16();
515 | 				for _ in 0..num_param_format_codes {
516 | 					let _format_code = src.get_i16();
517 | 				}
518 | 
519 | 				let num_params = src.get_i16();
520 | 				for _ in 0..num_params {
521 | 					let param_len = src.get_i32() as usize;
522 | 					let _bytes = &src[0..param_len];
523 | 					src.advance(param_len);
524 | 				}
525 | 
526 | 				let result_format = match src.get_i16() {
527 | 					0 => BindFormat::All(FormatCode::Text),
528 | 					1 => BindFormat::All(src.get_i16().try_into()?),
529 | 					n => {
530 | 						let mut result_format_codes = Vec::new();
531 | 						for _ in 0..n {
532 | 							result_format_codes.push(src.get_i16().try_into()?);
533 | 						}
534 | 						BindFormat::PerColumn(result_format_codes)
535 | 					}
536 | 				};
537 | 
538 | 				ClientMessage::Bind(Bind {
539 | 					portal,
540 | 					prepared_statement_name,
541 | 					result_format,
542 | 				})
543 | 			}
544 | 			b'E' => {
545 | 				let portal = read_cstr(src)?;
546 | 				let max_rows = match src.get_i32() {
547 | 					0 => None,
548 | 					other => Some(other),
549 | 				};
550 | 
551 | 				ClientMessage::Execute(Execute { portal, max_rows })
552 | 			}
553 | 			b'Q' => {
554 | 				let query = read_cstr(src)?;
555 | 				ClientMessage::Query(query)
556 | 			}
557 | 			b'X' => ClientMessage::Terminate,
558 | 			other => return Err(ProtocolError::InvalidMessageType(other)),
559 | 		};
560 | 
561 | 		Ok(Some(message))
562 | 	}
563 | }
564 | 
565 | impl<T: BackendMessage> Encoder<T> for ConnectionCodec {
566 | 	type Error = ProtocolError;
567 | 
568 | 	fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
569 | 		let mut body = BytesMut::new();
570 | 		item.encode(&mut body);
571 | 
572 | 		dst.put_u8(T::TAG);
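		// per the wire format, the length prefix counts itself (the + 4) but not the tag byte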
573 | 		dst.put_i32((body.len() + 4) as i32);
574 | 		dst.put_slice(&body);
575 | 		Ok(())
576 | 	}
577 | }
578 | 
579 | impl Encoder<char> for ConnectionCodec {
580 | 	type Error = ProtocolError;
581 | 
582 | 	fn encode(&mut self, item: char, dst: &mut BytesMut) -> Result<(), Self::Error> {
583 | 		dst.put_u8(item as u8);
584 | 		Ok(())
585 | 	}
586 | }
587 | 
--------------------------------------------------------------------------------