├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── examples ├── bench.rs ├── client.rs ├── copy.rs ├── duckdb.rs ├── gluesql.rs ├── scram.rs ├── secure_server.rs ├── server.rs ├── sqlite.rs ├── ssl │ ├── server.crt │ └── server.key └── transaction.rs ├── pgbench ├── bench.sh └── select.sql ├── release.toml ├── src ├── api │ ├── auth │ │ ├── cleartext.rs │ │ ├── md5pass.rs │ │ ├── mod.rs │ │ ├── noop.rs │ │ └── scram.rs │ ├── client │ │ ├── auth.rs │ │ ├── config.rs │ │ ├── mod.rs │ │ ├── query.rs │ │ └── result.rs │ ├── copy.rs │ ├── mod.rs │ ├── portal.rs │ ├── query.rs │ ├── results.rs │ ├── stmt.rs │ ├── store.rs │ └── transaction.rs ├── error.rs ├── lib.rs ├── messages │ ├── codec.rs │ ├── copy.rs │ ├── data.rs │ ├── extendedquery.rs │ ├── mod.rs │ ├── response.rs │ ├── simplequery.rs │ ├── startup.rs │ └── terminate.rs ├── tokio │ ├── client.rs │ ├── mod.rs │ └── server.rs └── types │ ├── from_sql_text.rs │ ├── mod.rs │ └── to_sql_text.rs └── tests-integration ├── go ├── client.go ├── go.mod └── go.sum ├── jdbc └── test.bb ├── nodejs ├── index.js ├── package-lock.json └── package.json ├── python ├── client2.py └── client3.py ├── rust-client ├── Cargo.toml └── src │ └── main.rs ├── test-server ├── Cargo.toml └── src │ └── main.rs └── test.sh /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: sunng87 4 | liberapay: Sunng 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: cargo 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "02:00" 8 | open-pull-requests-limit: 10 9 | 
-------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | schedule: [{cron: "30 13 * * *"}] 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | jobs: 10 | format: 11 | name: Rustfmt 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions-rs/toolchain@v1 16 | with: 17 | toolchain: nightly 18 | components: rustfmt 19 | override: true 20 | - run: cargo fmt -- --check 21 | 22 | lint: 23 | name: Clippy lint 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: actions-rs/toolchain@v1 28 | with: 29 | toolchain: stable 30 | components: clippy 31 | override: true 32 | - name: Lint default 33 | run: cargo clippy -- -D warnings 34 | - name: Lint minimal 35 | run: cargo clippy --no-default-features -- -D warnings 36 | - name: Lint server-api without tls 37 | run: cargo clippy --no-default-features --features server-api -- -D warnings 38 | - name: Lint ring 39 | run: cargo clippy --no-default-features --features server-api-ring -- -D warnings 40 | - name: Lint client api 41 | run: cargo clippy --features client-api-aws-lc-rs -- -D warnings 42 | - name: Lint scram 43 | run: cargo clippy --features scram -- -D warnings 44 | 45 | test: 46 | name: Test 47 | runs-on: ${{ matrix.os }} 48 | strategy: 49 | matrix: 50 | build: [stable, nightly] 51 | include: 52 | - build: stable 53 | os: ubuntu-latest 54 | rust: stable 55 | - build: nightly 56 | os: ubuntu-latest 57 | rust: nightly 58 | steps: 59 | - uses: actions/checkout@v4 60 | - uses: actions-rs/toolchain@v1 61 | with: 62 | toolchain: ${{ matrix.rust }} 63 | override: true 64 | - name: Build and run tests on default feature set 65 | run: cargo test 66 | - name: Run tests on minimal feature set 67 | run: cargo test --no-default-features 68 | - name: Run tests without tls 69 | run: cargo test 
--no-default-features --features server-api 70 | - name: Run tests on additional scram+ring feature set 71 | run: cargo test --no-default-features --features server-api-ring,scram 72 | - name: Run tests on additional scram+aws-lc-rs feature set 73 | run: cargo test --features scram 74 | - name: Run tests for client api 75 | run: cargo test --features client-api-aws-lc-rs 76 | - name: Run check on duckdb and sqlite example 77 | run: cargo check --all-targets --features _duckdb,_sqlite,_bundled 78 | 79 | integration: 80 | name: Integration tests 81 | runs-on: ubuntu-latest 82 | timeout-minutes: 15 83 | needs: [test] 84 | steps: 85 | - uses: actions/checkout@v4 86 | - uses: actions-rs/toolchain@v1 87 | with: 88 | toolchain: stable 89 | override: true 90 | - run: | 91 | pip install psycopg 92 | pip install psycopg2 93 | - uses: turtlequeue/setup-babashka@v1.5.0 94 | with: 95 | babashka-version: 1.1.173 96 | - run: ./tests-integration/test.sh 97 | 98 | msrv: 99 | name: MSRV 100 | runs-on: ubuntu-latest 101 | steps: 102 | - uses: actions/checkout@v4 103 | - uses: actions-rs/toolchain@v1 104 | with: 105 | toolchain: "1.75" 106 | override: true 107 | - run: cargo build 108 | - run: cargo build --no-default-features 109 | - run: cargo build --no-default-features --features server-api 110 | - run: cargo build --no-default-features --features server-api-ring 111 | - run: cargo build --features scram 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | tests-integration/test-server/target 4 | tests-integration/rust-client/target 5 | node_modules 6 | .DS_Store -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "pgwire" 3 | version = "0.30.1" 4 | edition = "2021" 5 | authors 
= ["Ning Sun "] 6 | license = "MIT OR Apache-2.0" 7 | description = "PostgreSQL wire protocol implemented as a library" 8 | keywords = ["database", "postgresql"] 9 | categories = ["database"] 10 | homepage = "https://github.com/sunng87/pgwire" 11 | repository = "https://github.com/sunng87/pgwire" 12 | documentation = "https://docs.rs/crate/pgwire/" 13 | readme = "README.md" 14 | rust-version = "1.75" 15 | 16 | [dependencies] 17 | derive-new = "0.7" 18 | bytes = "1.1.0" 19 | thiserror = "2" 20 | ## api 21 | tokio = { version = "1.19", features = [ 22 | "net", 23 | "rt", 24 | "io-util", 25 | "macros" 26 | ], optional = true } 27 | tokio-util = { version = "0.7.3", features = ["codec", "io"], optional = true } 28 | tokio-rustls = { version = "0.26.2", optional = true, default-features = false, features = ["logging", "tls12"]} 29 | rustls-pki-types = { version = "1.10", optional = true } 30 | futures = { version = "0.3", optional = true } 31 | async-trait = { version = "0.1", optional = true } 32 | pin-project = { version = "1.1", optional = true } 33 | rand = { version = "0.9", optional = true } 34 | md5 = { version = "0.7", optional = true } 35 | hex = { version = "0.4", optional = true } 36 | ## scram libraries 37 | base64 = { version = "0.22", optional = true } 38 | ring = { version = "0.17", optional = true } 39 | aws-lc-rs = { version = "1.7", optional = true } 40 | stringprep = { version = "0.1.2", optional = true } 41 | x509-certificate = { version = "0.24", optional = true } 42 | ## types 43 | postgres-types = { version = "0.2", features = [ 44 | "with-chrono-0_4", 45 | "array-impls", 46 | ], optional = true } 47 | chrono = { version = "0.4", features = ["std"], optional = true } 48 | rust_decimal = { version = "1.35", features = ["db-postgres"], optional = true } 49 | lazy-regex = {version = "3.3", default-features = false, features = ["lite"]} 50 | ## config 51 | percent-encoding = { version = "2.0", optional = true } 52 | 53 | [features] 54 | default =
["server-api-aws-lc-rs"] 55 | _ring = ["dep:ring", "tokio-rustls/ring", "dep:rustls-pki-types"] 56 | _aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs", "dep:rustls-pki-types"] 57 | server-api = [ 58 | "dep:tokio", 59 | "dep:tokio-util", 60 | "dep:futures", 61 | "dep:async-trait", 62 | "dep:rand", 63 | "dep:md5", 64 | "dep:hex", 65 | "dep:postgres-types", 66 | "dep:chrono", 67 | "dep:rust_decimal", 68 | ] 69 | server-api-ring = ["server-api", "_ring"] 70 | server-api-aws-lc-rs = ["server-api", "_aws-lc-rs"] 71 | client-api = [ 72 | "dep:percent-encoding", 73 | "dep:pin-project", 74 | "dep:tokio", 75 | "dep:tokio-util", 76 | "dep:futures", 77 | "dep:async-trait", 78 | "dep:md5", 79 | ] 80 | client-api-ring = ["client-api", "_ring", "dep:rustls-pki-types"] 81 | client-api-aws-lc-rs = ["client-api", "_aws-lc-rs", "dep:rustls-pki-types"] 82 | scram = ["dep:base64", "dep:stringprep", "dep:x509-certificate"] 83 | _duckdb = [] 84 | _sqlite = [] 85 | _bundled = ["duckdb/bundled", "rusqlite/bundled"] 86 | 87 | [dev-dependencies] 88 | tokio = { version = "1.19", features = ["rt-multi-thread", "net", "macros"]} 89 | rustls-pki-types = { version = "1.10" } 90 | rusqlite = { version = "0.36.0", features = ["column_decltype"] } 91 | ## for duckdb example 92 | duckdb = { version = "1.0.0" } 93 | 94 | ## for loading custom cert files 95 | rustls-pemfile = "2.0" 96 | ## webpki-roots has mozilla's set of roots 97 | ## rustls-native-certs loads roots from current system 98 | gluesql = { version = "0.16", default-features = false, features = ["gluesql_memory_storage"] } 99 | 100 | [workspace] 101 | members = [ 102 | ".", 103 | "tests-integration/rust-client", 104 | "tests-integration/test-server" 105 | ] 106 | 107 | [[example]] 108 | name = "server" 109 | required-features = ["server-api-aws-lc-rs"] 110 | 111 | [[example]] 112 | name = "secure_server" 113 | required-features = ["server-api-aws-lc-rs"] 114 | 115 | [[example]] 116 | name = "bench" 117 | required-features = 
["server-api-aws-lc-rs"] 118 | 119 | [[example]] 120 | name = "gluesql" 121 | required-features = ["server-api-aws-lc-rs"] 122 | 123 | [[example]] 124 | name = "sqlite" 125 | required-features = ["server-api-aws-lc-rs", "_sqlite"] 126 | 127 | [[example]] 128 | name = "duckdb" 129 | required-features = ["server-api-aws-lc-rs", "_duckdb"] 130 | 131 | [[example]] 132 | name = "copy" 133 | required-features = ["server-api-aws-lc-rs"] 134 | 135 | [[example]] 136 | name = "scram" 137 | required-features = ["server-api-aws-lc-rs", "scram"] 138 | 139 | [[example]] 140 | name = "transaction" 141 | required-features = ["server-api-aws-lc-rs"] 142 | 143 | [[example]] 144 | name = "client" 145 | required-features = ["client-api"] 146 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2018] [Ning Sun] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Ning Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pgwire 2 | 3 | [![CI](https://github.com/sunng87/pgwire/actions/workflows/ci.yml/badge.svg)](https://github.com/sunng87/pgwire/actions/workflows/ci.yml) 4 | [![](https://img.shields.io/crates/v/pgwire)](https://crates.io/crates/pgwire) 5 | [![Docs](https://docs.rs/pgwire/badge.svg)](https://docs.rs/pgwire/latest/pgwire/) 6 | 7 | Build a Postgres-compatible access layer for your data service. 8 | 9 | This library implements the PostgreSQL Wire Protocol and provides essential APIs to 10 | write PostgreSQL-compatible servers and clients. It's like 11 | [hyper](https://github.com/hyperium/hyper/), but for the postgres wire protocol. 12 | 13 | If you are interested in related topics, you can check the [project 14 | ideas](https://github.com/sunng87/pgwire/discussions/204) to build on top of 15 | this library.
16 | 17 | 18 | ## Status 19 | 20 | - Message format 21 | - [x] Frontend-Backend protocol messages 22 | - [ ] Streaming replication protocol 23 | - [ ] Logical streaming replication protocol message 24 | - [x] Backend TCP/TLS server on Tokio 25 | - [ ] Frontend TCP/TLS client on Tokio 26 | - Frontend-Backend interaction over TCP 27 | - [x] SSL Request and Response 28 | - [x] PostgreSQL 17 direct SSL negotiation 29 | - [x] Startup 30 | - [x] No authentication 31 | - [x] Clear-text password authentication 32 | - [x] Md5 Password authentication 33 | - [x] SASL SCRAM authentication (optional feature `scram`, used together with 34 | `server-api-aws-lc-rs` (default) or `server-api-ring`) 35 | - [x] SCRAM-SHA-256 36 | - [x] SCRAM-SHA-256-PLUS 37 | - [x] Simple Query and Response 38 | - [x] Extended Query and Response 39 | - [x] Parse 40 | - [x] Bind 41 | - [x] Execute 42 | - [x] Describe 43 | - [x] Sync 44 | - [x] Termination 45 | - [x] Cancel 46 | - [x] Error and Notice 47 | - [x] Copy 48 | - [x] Notification 49 | - [ ] Streaming replication over TCP 50 | - [ ] Logical streaming replication over TCP 51 | - [x] Data types 52 | - [x] Text format 53 | - [x] Binary format, implemented in `postgres-types` 54 | - APIs 55 | - Backend/Server 56 | - [x] Startup APIs 57 | - [x] AuthSource API, fetching and hashing passwords 58 | - [x] Server parameters API, ready but not very good 59 | - [x] Simple Query API 60 | - [x] Extended Query API 61 | - [x] QueryParser API, for transforming prepared statements 62 | - [x] ResultSet builder/encoder API 63 | - [ ] Query Cancellation API 64 | - [x] Error and Notice API 65 | - [x] Copy API 66 | - [x] Copy-in 67 | - [x] Copy-out 68 | - [x] Copy-both 69 | - [x] Transaction state 70 | - [ ] Streaming replication over TCP 71 | - [ ] Logical streaming replication server API 72 | - Frontend/Client 73 | - [x] Startup APIs 74 | - [x] Simple Query API 75 | - [ ] Extended Query API 76 | - [ ] ResultSet decoder API 77 | - [ ] Query Cancellation API 78 | - [ ] Error and Notice API 79
| - [ ] Copy API 80 | - [ ] Transaction state 81 | - [ ] Streaming replication over TCP 82 | - [ ] Logical streaming replication server API 83 | 84 | ## About Postgres Wire Protocol 85 | 86 | Postgres Wire Protocol is a relatively general-purpose Layer-7 protocol. There 87 | are 6 parts of the protocol: 88 | 89 | - Startup: client-server handshake and authentication. 90 | - Simple Query: The text-based query protocol of postgresql. Queries are provided 91 | as strings, and the server is allowed to stream data in response. 92 | - Extended Query: A newer sub-protocol for queries that can cache the 93 | query on the server side and reuse it with new parameters. The response part is 94 | identical to Simple Query. 95 | - Copy: the subprotocol to copy data from and to postgresql. 96 | - Replication 97 | - Logical Replication 98 | 99 | Also note that the Postgres Wire Protocol carries no SQL semantics, so literally 100 | you can use any query language, data format or even natural language to 101 | interact with the backend. 102 | 103 | Responses are always encoded in the data row format, preceded by a field 104 | description header that describes each field's name, type and format. 105 | 106 | [Jelte Fennema-Nio](https://github.com/JelteF)'s talk at PGConf.dev 2024 gives 107 | great coverage of how the wire protocol works: 108 | https://www.youtube.com/watch?v=nh62VgNj6hY 109 | 110 | ## Usage 111 | 112 | ### Server/Backend 113 | 114 | To use `pgwire` in your server application, you will need to implement two key 115 | components: **startup processor** and **query processor**. For query 116 | processing, there are two kinds of queries: simple and extended. By adding 117 | `SimpleQueryHandler` to your application, you will get `psql` command-line tool 118 | compatibility. For compatibility with more language drivers, and for prepared statement and 119 | binary encoding support, `ExtendedQueryHandler` is required.
120 | 121 | Examples are provided to demo the very basic usage of `pgwire` on the server side: 122 | 123 | - `examples/sqlite.rs`: uses an in-memory sqlite database at its core and serves 124 | it over the postgresql protocol. This is a full example with both simple and 125 | extended query implementations. `cargo run --features _sqlite --example 126 | sqlite` 127 | - `examples/duckdb.rs`: similar to the sqlite example but with a duckdb backend. Note 128 | that not all data types are implemented in this example. `cargo run --features 129 | _duckdb --example duckdb` 130 | - `examples/gluesql.rs`: uses an in-memory 131 | [gluesql](https://github.com/gluesql/gluesql) at its core and serves 132 | it over the postgresql protocol. 133 | - `examples/server.rs`: demos a server that always returns fixed results. 134 | - `examples/secure_server.rs`: demos a server with ssl support and always 135 | returns fixed results. 136 | - `examples/scram.rs`: demos how to configure a more secure authentication 137 | mechanism: 138 | [SCRAM](https://en.wikipedia.org/wiki/Salted_Challenge_Response_Authentication_Mechanism) 139 | - `examples/transaction.rs`: see how to control transaction state at the wire 140 | protocol level. 141 | - `examples/datafusion.rs`: Now moved to 142 | [datafusion-postgres](https://github.com/sunng87/datafusion-postgres) 143 | 144 | ### Client/Frontend 145 | 146 | The client/frontend API is still under development. This API will focus on 147 | providing full access to the postgres wire protocol. It's designed for building 148 | components like a postgres proxy. For a general-purpose postgres driver for 149 | application development, you can use 150 | [rust-postgres](https://github.com/sfackler/rust-postgres).
151 | 152 | ## Projects using pgwire 153 | 154 | * [GreptimeDB](https://github.com/GrepTimeTeam/greptimedb): Cloud-native 155 | time-series database 156 | * [risinglight](https://github.com/risinglightdb/risinglight): OLAP database 157 | system for educational purpose 158 | * [PeerDB](https://github.com/PeerDB-io/peerdb) Postgres first ETL/ELT, enabling 159 | 10x faster data movement in and out of Postgres 160 | * [CeresDB](https://github.com/CeresDB/ceresdb) CeresDB is a high-performance, 161 | distributed, cloud native time-series database from AntGroup. 162 | * [dozer](https://github.com/getdozer/dozer) a real-time data platform for 163 | building, deploying and maintaining data products. 164 | * [restate](https://github.com/restatedev/restate) Framework for building 165 | resilient workflow 166 | 167 | Submit a pull request if your project isn't listed here. 168 | 169 | ## License 170 | 171 | This library is released under MIT/Apache dual license. 172 | -------------------------------------------------------------------------------- /examples/bench.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use async_trait::async_trait; 4 | use futures::stream; 5 | use futures::StreamExt; 6 | use pgwire::api::NoopErrorHandler; 7 | use pgwire::api::PgWireServerHandlers; 8 | use tokio::net::TcpListener; 9 | 10 | use pgwire::api::auth::noop::NoopStartupHandler; 11 | use pgwire::api::copy::NoopCopyHandler; 12 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 13 | use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response}; 14 | use pgwire::api::{ClientInfo, Type}; 15 | use pgwire::error::PgWireResult; 16 | use pgwire::tokio::process_socket; 17 | 18 | pub struct DummyProcessor; 19 | 20 | impl NoopStartupHandler for DummyProcessor {} 21 | 22 | #[async_trait] 23 | impl SimpleQueryHandler for DummyProcessor { 24 | async fn do_query<'a, C>( 25 | &self, 
26 | _client: &mut C, 27 | _query: &str, 28 | ) -> PgWireResult>> 29 | where 30 | C: ClientInfo + Unpin + Send + Sync, 31 | { 32 | let f1 = FieldInfo::new("?column?".into(), None, None, Type::INT4, FieldFormat::Text); 33 | let f2 = FieldInfo::new("?column?".into(), None, None, Type::INT4, FieldFormat::Text); 34 | let f3 = FieldInfo::new("?column?".into(), None, None, Type::INT4, FieldFormat::Text); 35 | let f4 = FieldInfo::new( 36 | "?column?".into(), 37 | None, 38 | None, 39 | Type::TIMESTAMP, 40 | FieldFormat::Text, 41 | ); 42 | let f5 = FieldInfo::new( 43 | "?column?".into(), 44 | None, 45 | None, 46 | Type::FLOAT8, 47 | FieldFormat::Text, 48 | ); 49 | let f6 = FieldInfo::new("?column?".into(), None, None, Type::TEXT, FieldFormat::Text); 50 | let schema = Arc::new(vec![f1, f2, f3, f4, f5, f6]); 51 | 52 | let schema_ref = schema.clone(); 53 | 54 | let data_row_stream = stream::iter(0..5000).map(move |n| { 55 | let mut encoder = DataRowEncoder::new(schema_ref.clone()); 56 | encoder.encode_field(&n).unwrap(); 57 | encoder.encode_field(&n).unwrap(); 58 | encoder.encode_field(&n).unwrap(); 59 | encoder.encode_field(&"2004-10-19 10:23:54+02").unwrap(); 60 | encoder.encode_field(&42.0f64).unwrap(); 61 | encoder.encode_field(&"This method splits the slice into three distinct slices: prefix, correctly aligned middle slice of a new type, and the suffix slice. How exactly the slice is split up is not specified; the middle part may be smaller than necessary. However, if this fails to return a maximal middle part, that is because code is running in a context where performance does not matter, such as a sanitizer attempting to find alignment bugs. 
Regular code running in a default (debug or release) execution will return a maximal middle part.").unwrap(); 62 | 63 | encoder.finish() 64 | }); 65 | 66 | Ok(vec![Response::Query(QueryResponse::new( 67 | schema, 68 | data_row_stream, 69 | ))]) 70 | } 71 | } 72 | 73 | struct DummyProcessorFactory { 74 | handler: Arc, 75 | } 76 | 77 | impl PgWireServerHandlers for DummyProcessorFactory { 78 | type StartupHandler = DummyProcessor; 79 | type SimpleQueryHandler = DummyProcessor; 80 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 81 | type CopyHandler = NoopCopyHandler; 82 | type ErrorHandler = NoopErrorHandler; 83 | 84 | fn simple_query_handler(&self) -> Arc { 85 | self.handler.clone() 86 | } 87 | 88 | fn extended_query_handler(&self) -> Arc { 89 | Arc::new(PlaceholderExtendedQueryHandler) 90 | } 91 | 92 | fn startup_handler(&self) -> Arc { 93 | self.handler.clone() 94 | } 95 | 96 | fn copy_handler(&self) -> Arc { 97 | Arc::new(NoopCopyHandler) 98 | } 99 | 100 | fn error_handler(&self) -> Arc { 101 | Arc::new(NoopErrorHandler) 102 | } 103 | } 104 | 105 | #[tokio::main(flavor = "multi_thread", worker_threads = 10)] 106 | pub async fn main() { 107 | let factory = Arc::new(DummyProcessorFactory { 108 | handler: Arc::new(DummyProcessor), 109 | }); 110 | 111 | let server_addr = "127.0.0.1:5433"; 112 | let listener = TcpListener::bind(server_addr).await.unwrap(); 113 | println!("Listening to {}", server_addr); 114 | loop { 115 | let incoming_socket = listener.accept().await.unwrap(); 116 | let factory_ref = factory.clone(); 117 | 118 | tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /examples/client.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use pgwire::api::client::auth::DefaultStartupHandler; 4 | use 
pgwire::api::client::query::DefaultSimpleQueryHandler; 5 | use pgwire::api::client::ClientInfo; 6 | use pgwire::tokio::client::PgWireClient; 7 | 8 | #[tokio::main] 9 | pub async fn main() { 10 | let config = Arc::new( 11 | "host=127.0.0.1 port=5432 user=pgwire dbname=demo password=pencil" 12 | .parse() 13 | .unwrap(), 14 | ); 15 | let startup_handler = DefaultStartupHandler::new(); 16 | let mut client = PgWireClient::connect(config, startup_handler, None) 17 | .await 18 | .unwrap(); 19 | 20 | println!("{:?}", client.server_parameters()); 21 | 22 | let simple_query_handler = DefaultSimpleQueryHandler::new(); 23 | let result = client 24 | .simple_query(simple_query_handler, "SELECT 1") 25 | .await 26 | .unwrap() 27 | .remove(0); 28 | 29 | let mut reader = result.into_data_rows_reader(); 30 | loop { 31 | if let Some(mut row) = reader.next_row() { 32 | let value = row.next_value::(); 33 | println!("{:?}", value); 34 | } else { 35 | break; 36 | }; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/copy.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::sync::Arc; 3 | 4 | use async_trait::async_trait; 5 | use futures::{Sink, SinkExt}; 6 | use tokio::net::TcpListener; 7 | 8 | use pgwire::api::auth::noop::NoopStartupHandler; 9 | use pgwire::api::copy::CopyHandler; 10 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 11 | use pgwire::api::results::{CopyResponse, Response}; 12 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireConnectionState, PgWireServerHandlers}; 13 | use pgwire::error::ErrorInfo; 14 | use pgwire::error::{PgWireError, PgWireResult}; 15 | use pgwire::messages::copy::{CopyData, CopyDone, CopyFail}; 16 | use pgwire::messages::response::NoticeResponse; 17 | use pgwire::messages::PgWireBackendMessage; 18 | use pgwire::tokio::process_socket; 19 | 20 | pub struct DummyProcessor; 21 | 22 | impl 
NoopStartupHandler for DummyProcessor {} 23 | 24 | #[async_trait] 25 | impl SimpleQueryHandler for DummyProcessor { 26 | async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult>> 27 | where 28 | C: ClientInfo + Sink + Unpin + Send + Sync, 29 | C::Error: Debug, 30 | PgWireError: From<>::Error>, 31 | { 32 | client 33 | .send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( 34 | ErrorInfo::new( 35 | "NOTICE".to_owned(), 36 | "01000".to_owned(), 37 | format!("Query received {}", query), 38 | ), 39 | ))) 40 | .await?; 41 | 42 | Ok(vec![Response::CopyIn(CopyResponse::new(0, 1, vec![0]))]) 43 | } 44 | } 45 | 46 | #[async_trait] 47 | impl CopyHandler for DummyProcessor { 48 | async fn on_copy_data(&self, client: &mut C, copy_data: CopyData) -> PgWireResult<()> 49 | where 50 | C: ClientInfo + Sink + Unpin + Send + Sync, 51 | C::Error: Debug, 52 | PgWireError: From<>::Error>, 53 | { 54 | use PgWireConnectionState::*; 55 | // This is set by the `on_query` implementations while handling a 56 | // `CopyIn`/`CopyOut`/`CopyBoth` response. 57 | assert!(matches!(client.state(), CopyInProgress(_))); 58 | 59 | println!("receiving data: {:?}", copy_data); 60 | 61 | Ok(()) 62 | } 63 | 64 | async fn on_copy_done(&self, client: &mut C, _done: CopyDone) -> PgWireResult<()> 65 | where 66 | C: ClientInfo + Sink + Unpin + Send + Sync, 67 | C::Error: Debug, 68 | PgWireError: From<>::Error>, 69 | { 70 | use PgWireConnectionState::*; 71 | // This is set by the `on_query` implementations while handling a 72 | // `CopyIn`/`CopyOut`/`CopyBoth` response. 
73 | assert!(matches!(client.state(), CopyInProgress(_))); 74 | 75 | println!("copy done"); 76 | 77 | Ok(()) 78 | } 79 | 80 | async fn on_copy_fail(&self, client: &mut C, fail: CopyFail) -> PgWireError 81 | where 82 | C: ClientInfo + Sink + Unpin + Send + Sync, 83 | C::Error: Debug, 84 | PgWireError: From<>::Error>, 85 | { 86 | use PgWireConnectionState::*; 87 | // This is set by the `on_query` implementations while handling a 88 | // `CopyIn`/`CopyOut`/`CopyBoth` response. 89 | assert!(matches!(client.state(), CopyInProgress(_))); 90 | 91 | println!("copy failed: {:?}", fail); 92 | 93 | PgWireError::UserError(Box::new(ErrorInfo::new( 94 | "ERROR".to_owned(), 95 | "XX000".to_owned(), 96 | format!("COPY IN mode terminated by the user: {}", fail.message), 97 | ))) 98 | } 99 | } 100 | 101 | struct DummyProcessorFactory { 102 | handler: Arc, 103 | } 104 | 105 | impl PgWireServerHandlers for DummyProcessorFactory { 106 | type StartupHandler = DummyProcessor; 107 | type SimpleQueryHandler = DummyProcessor; 108 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 109 | type CopyHandler = DummyProcessor; 110 | type ErrorHandler = NoopErrorHandler; 111 | 112 | fn simple_query_handler(&self) -> Arc { 113 | self.handler.clone() 114 | } 115 | 116 | fn extended_query_handler(&self) -> Arc { 117 | Arc::new(PlaceholderExtendedQueryHandler) 118 | } 119 | 120 | fn startup_handler(&self) -> Arc { 121 | self.handler.clone() 122 | } 123 | 124 | fn copy_handler(&self) -> Arc { 125 | self.handler.clone() 126 | } 127 | 128 | fn error_handler(&self) -> Arc { 129 | Arc::new(NoopErrorHandler) 130 | } 131 | } 132 | 133 | #[tokio::main] 134 | pub async fn main() { 135 | let factory = Arc::new(DummyProcessorFactory { 136 | handler: Arc::new(DummyProcessor), 137 | }); 138 | 139 | let server_addr = "127.0.0.1:5432"; 140 | let listener = TcpListener::bind(server_addr).await.unwrap(); 141 | println!("Listening to {}", server_addr); 142 | loop { 143 | let incoming_socket = 
listener.accept().await.unwrap(); 144 | let factory_ref = factory.clone(); 145 | tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /examples/gluesql.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{Arc, Mutex}; 2 | 3 | use async_trait::async_trait; 4 | use futures::stream; 5 | use tokio::net::TcpListener; 6 | 7 | use gluesql::prelude::*; 8 | use pgwire::api::auth::noop::NoopStartupHandler; 9 | use pgwire::api::copy::NoopCopyHandler; 10 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 11 | use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; 12 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; 13 | use pgwire::error::{PgWireError, PgWireResult}; 14 | use pgwire::tokio::process_socket; 15 | 16 | pub struct GluesqlProcessor { 17 | glue: Arc>>, 18 | } 19 | 20 | impl NoopStartupHandler for GluesqlProcessor {} 21 | 22 | #[async_trait] 23 | impl SimpleQueryHandler for GluesqlProcessor { 24 | async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> 25 | where 26 | C: ClientInfo + Unpin + Send + Sync, 27 | { 28 | println!("{:?}", query); 29 | let mut glue = self.glue.lock().unwrap(); 30 | futures::executor::block_on(glue.execute(query)) 31 | .map_err(|err| PgWireError::ApiError(Box::new(err))) 32 | .and_then(|payloads| { 33 | payloads 34 | .iter() 35 | .map(|payload| match payload { 36 | Payload::Select { labels, rows } => { 37 | let fields = labels 38 | .iter() 39 | .map(|label| { 40 | FieldInfo::new( 41 | label.into(), 42 | None, 43 | None, 44 | Type::UNKNOWN, 45 | FieldFormat::Text, 46 | ) 47 | }) 48 | .collect::>(); 49 | let fields = Arc::new(fields); 50 | 51 | let mut results = Vec::with_capacity(rows.len()); 52 | for row in rows { 53 | let mut 
encoder = DataRowEncoder::new(fields.clone()); 54 | for field in row.iter() { 55 | match field { 56 | Value::Bool(v) => encoder 57 | .encode_field_with_type_and_format( 58 | v, 59 | &Type::BOOL, 60 | FieldFormat::Text, 61 | )?, 62 | Value::I8(v) => encoder.encode_field_with_type_and_format( 63 | v, 64 | &Type::CHAR, 65 | FieldFormat::Text, 66 | )?, 67 | Value::I16(v) => encoder 68 | .encode_field_with_type_and_format( 69 | v, 70 | &Type::INT2, 71 | FieldFormat::Text, 72 | )?, 73 | Value::I32(v) => encoder 74 | .encode_field_with_type_and_format( 75 | v, 76 | &Type::INT4, 77 | FieldFormat::Text, 78 | )?, 79 | Value::I64(v) => encoder 80 | .encode_field_with_type_and_format( 81 | v, 82 | &Type::INT8, 83 | FieldFormat::Text, 84 | )?, 85 | Value::U8(v) => encoder.encode_field_with_type_and_format( 86 | &(*v as i8), 87 | &Type::CHAR, 88 | FieldFormat::Text, 89 | )?, 90 | Value::F64(v) => encoder 91 | .encode_field_with_type_and_format( 92 | v, 93 | &Type::FLOAT8, 94 | FieldFormat::Text, 95 | )?, 96 | Value::Str(v) => encoder 97 | .encode_field_with_type_and_format( 98 | v, 99 | &Type::VARCHAR, 100 | FieldFormat::Text, 101 | )?, 102 | Value::Bytea(v) => encoder 103 | .encode_field_with_type_and_format( 104 | v, 105 | &Type::BYTEA, 106 | FieldFormat::Text, 107 | )?, 108 | Value::Date(v) => encoder 109 | .encode_field_with_type_and_format( 110 | v, 111 | &Type::DATE, 112 | FieldFormat::Text, 113 | )?, 114 | Value::Time(v) => encoder 115 | .encode_field_with_type_and_format( 116 | v, 117 | &Type::TIME, 118 | FieldFormat::Text, 119 | )?, 120 | Value::Timestamp(v) => encoder 121 | .encode_field_with_type_and_format( 122 | v, 123 | &Type::TIMESTAMP, 124 | FieldFormat::Text, 125 | )?, 126 | _ => unimplemented!(), 127 | } 128 | } 129 | results.push(encoder.finish()); 130 | } 131 | 132 | Ok(Response::Query(QueryResponse::new( 133 | fields, 134 | stream::iter(results.into_iter()), 135 | ))) 136 | } 137 | Payload::Insert(rows) => Ok(Response::Execution( 138 | 
Tag::new("INSERT").with_oid(0).with_rows(*rows), 139 | )), 140 | Payload::Delete(rows) => { 141 | Ok(Response::Execution(Tag::new("DELETE").with_rows(*rows))) 142 | } 143 | Payload::Update(rows) => { 144 | Ok(Response::Execution(Tag::new("UPDATE").with_rows(*rows))) 145 | } 146 | Payload::Create => Ok(Response::Execution(Tag::new("CREATE TABLE"))), 147 | Payload::AlterTable => Ok(Response::Execution(Tag::new("ALTER TABLE"))), 148 | Payload::DropTable(_) => Ok(Response::Execution(Tag::new("DROP TABLE"))), 149 | Payload::CreateIndex => Ok(Response::Execution(Tag::new("CREATE INDEX"))), 150 | Payload::DropIndex => Ok(Response::Execution(Tag::new("DROP INDEX"))), 151 | _ => { 152 | unimplemented!() 153 | } 154 | }) 155 | .collect::, PgWireError>>() 156 | }) 157 | } 158 | } 159 | 160 | struct GluesqlHandlerFactory { 161 | processor: Arc, 162 | } 163 | 164 | impl PgWireServerHandlers for GluesqlHandlerFactory { 165 | type StartupHandler = GluesqlProcessor; 166 | type SimpleQueryHandler = GluesqlProcessor; 167 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 168 | type CopyHandler = NoopCopyHandler; 169 | type ErrorHandler = NoopErrorHandler; 170 | 171 | fn simple_query_handler(&self) -> Arc { 172 | self.processor.clone() 173 | } 174 | 175 | fn extended_query_handler(&self) -> Arc { 176 | Arc::new(PlaceholderExtendedQueryHandler) 177 | } 178 | 179 | fn startup_handler(&self) -> Arc { 180 | self.processor.clone() 181 | } 182 | 183 | fn copy_handler(&self) -> Arc { 184 | Arc::new(NoopCopyHandler) 185 | } 186 | 187 | fn error_handler(&self) -> Arc { 188 | Arc::new(NoopErrorHandler) 189 | } 190 | } 191 | 192 | #[tokio::main] 193 | pub async fn main() { 194 | let gluesql = GluesqlProcessor { 195 | glue: Arc::new(Mutex::new(Glue::new(MemoryStorage::default()))), 196 | }; 197 | 198 | let factory = Arc::new(GluesqlHandlerFactory { 199 | processor: Arc::new(gluesql), 200 | }); 201 | 202 | let server_addr = "127.0.0.1:5432"; 203 | let listener = 
TcpListener::bind(server_addr).await.unwrap(); 204 | println!("Listening to {}", server_addr); 205 | loop { 206 | let incoming_socket = listener.accept().await.unwrap(); 207 | let factory_ref = factory.clone(); 208 | 209 | tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /examples/scram.rs: -------------------------------------------------------------------------------- 1 | use std::fs::{self, File}; 2 | use std::io::{BufReader, Error as IOError, ErrorKind}; 3 | use std::sync::Arc; 4 | 5 | use async_trait::async_trait; 6 | 7 | use rustls_pemfile::{certs, pkcs8_private_keys}; 8 | use rustls_pki_types::{CertificateDer, PrivateKeyDer}; 9 | use tokio::net::TcpListener; 10 | use tokio_rustls::rustls::ServerConfig; 11 | use tokio_rustls::TlsAcceptor; 12 | 13 | use pgwire::api::auth::scram::{gen_salted_password, SASLScramAuthStartupHandler}; 14 | use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; 15 | use pgwire::api::copy::NoopCopyHandler; 16 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 17 | use pgwire::api::results::{Response, Tag}; 18 | 19 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers}; 20 | use pgwire::error::PgWireResult; 21 | use pgwire::tokio::process_socket; 22 | 23 | pub struct DummyProcessor; 24 | 25 | #[async_trait] 26 | impl SimpleQueryHandler for DummyProcessor { 27 | async fn do_query<'a, C>( 28 | &self, 29 | _client: &mut C, 30 | _query: &str, 31 | ) -> PgWireResult>> 32 | where 33 | C: ClientInfo + Unpin + Send + Sync, 34 | { 35 | Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))]) 36 | } 37 | } 38 | 39 | pub fn random_salt() -> Vec { 40 | Vec::from(rand::random::<[u8; 10]>()) 41 | } 42 | 43 | const ITERATIONS: usize = 4096; 44 | 45 | struct DummyAuthDB; 46 | 47 | #[async_trait] 48 | impl AuthSource for 
DummyAuthDB { 49 | async fn get_password(&self, _login: &LoginInfo) -> PgWireResult { 50 | let password = "pencil"; 51 | let salt = random_salt(); 52 | 53 | let hash_password = gen_salted_password(password, salt.as_ref(), ITERATIONS); 54 | Ok(Password::new(Some(salt), hash_password)) 55 | } 56 | } 57 | 58 | /// configure TlsAcceptor and get server cert for SCRAM channel binding 59 | fn setup_tls() -> Result { 60 | let cert = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?)) 61 | .collect::, IOError>>()?; 62 | 63 | let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?)) 64 | .map(|key| key.map(PrivateKeyDer::from)) 65 | .collect::, IOError>>()? 66 | .remove(0); 67 | 68 | let config = ServerConfig::builder() 69 | .with_no_client_auth() 70 | .with_single_cert(cert, key) 71 | .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?; 72 | 73 | Ok(TlsAcceptor::from(Arc::new(config))) 74 | } 75 | 76 | struct DummyProcessorFactory { 77 | handler: Arc, 78 | cert: Vec, 79 | } 80 | 81 | impl PgWireServerHandlers for DummyProcessorFactory { 82 | type StartupHandler = SASLScramAuthStartupHandler; 83 | type SimpleQueryHandler = DummyProcessor; 84 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 85 | type CopyHandler = NoopCopyHandler; 86 | type ErrorHandler = NoopErrorHandler; 87 | 88 | fn simple_query_handler(&self) -> Arc { 89 | self.handler.clone() 90 | } 91 | 92 | fn extended_query_handler(&self) -> Arc { 93 | Arc::new(PlaceholderExtendedQueryHandler) 94 | } 95 | 96 | fn startup_handler(&self) -> Arc { 97 | let mut authenticator = SASLScramAuthStartupHandler::new( 98 | Arc::new(DummyAuthDB), 99 | Arc::new(DefaultServerParameterProvider::default()), 100 | ); 101 | authenticator.set_iterations(ITERATIONS); 102 | authenticator 103 | .configure_certificate(self.cert.as_ref()) 104 | .unwrap(); 105 | 106 | Arc::new(authenticator) 107 | } 108 | 109 | fn copy_handler(&self) -> Arc { 110 | Arc::new(NoopCopyHandler) 
111 | } 112 | 113 | fn error_handler(&self) -> Arc { 114 | Arc::new(NoopErrorHandler) 115 | } 116 | } 117 | 118 | #[tokio::main] 119 | pub async fn main() { 120 | let cert = fs::read("examples/ssl/server.crt").unwrap(); 121 | let factory = Arc::new(DummyProcessorFactory { 122 | handler: Arc::new(DummyProcessor), 123 | cert, 124 | }); 125 | 126 | let server_addr = "127.0.0.1:5432"; 127 | let tls_acceptor = setup_tls().unwrap(); 128 | let listener = TcpListener::bind(server_addr).await.unwrap(); 129 | println!("Listening to {}", server_addr); 130 | loop { 131 | let incoming_socket = listener.accept().await.unwrap(); 132 | let tls_acceptor_ref = tls_acceptor.clone(); 133 | 134 | let factory_ref = factory.clone(); 135 | 136 | tokio::spawn(async move { 137 | process_socket(incoming_socket.0, Some(tls_acceptor_ref), factory_ref).await 138 | }); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /examples/secure_server.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufReader, Error as IOError, ErrorKind}; 3 | use std::sync::Arc; 4 | 5 | use async_trait::async_trait; 6 | use futures::{stream, StreamExt}; 7 | use rustls_pemfile::{certs, pkcs8_private_keys}; 8 | use rustls_pki_types::{CertificateDer, PrivateKeyDer}; 9 | use tokio::net::TcpListener; 10 | use tokio_rustls::rustls::ServerConfig; 11 | use tokio_rustls::TlsAcceptor; 12 | 13 | use pgwire::api::auth::noop::NoopStartupHandler; 14 | use pgwire::api::copy::NoopCopyHandler; 15 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 16 | use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; 17 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; 18 | use pgwire::error::PgWireResult; 19 | use pgwire::tokio::process_socket; 20 | 21 | pub struct DummyProcessor; 22 | 23 | impl NoopStartupHandler for 
DummyProcessor {} 24 | 25 | #[async_trait] 26 | impl SimpleQueryHandler for DummyProcessor { 27 | async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> 28 | where 29 | C: ClientInfo + Unpin + Send + Sync, 30 | { 31 | println!("{:?}", query); 32 | if query.starts_with("SELECT") { 33 | let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text); 34 | let f2 = FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text); 35 | let schema = Arc::new(vec![f1, f2]); 36 | 37 | let data = vec![ 38 | (Some(0), Some("Tom")), 39 | (Some(1), Some("Jerry")), 40 | (Some(2), None), 41 | ]; 42 | let schema_ref = schema.clone(); 43 | let data_row_stream = stream::iter(data.into_iter()).map(move |r| { 44 | let mut encoder = DataRowEncoder::new(schema_ref.clone()); 45 | encoder.encode_field(&r.0)?; 46 | encoder.encode_field(&r.1)?; 47 | 48 | encoder.finish() 49 | }); 50 | 51 | Ok(vec![Response::Query(QueryResponse::new( 52 | schema, 53 | data_row_stream, 54 | ))]) 55 | } else { 56 | Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))]) 57 | } 58 | } 59 | } 60 | 61 | fn setup_tls() -> Result { 62 | let cert = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?)) 63 | .collect::, IOError>>()?; 64 | 65 | let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?)) 66 | .map(|key| key.map(PrivateKeyDer::from)) 67 | .collect::, IOError>>()? 
68 | .remove(0); 69 | 70 | let mut config = ServerConfig::builder() 71 | .with_no_client_auth() 72 | .with_single_cert(cert, key) 73 | .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?; 74 | 75 | config.alpn_protocols = vec![b"postgresql".to_vec()]; 76 | 77 | Ok(TlsAcceptor::from(Arc::new(config))) 78 | } 79 | 80 | struct DummyProcessorFactory { 81 | handler: Arc, 82 | } 83 | 84 | impl PgWireServerHandlers for DummyProcessorFactory { 85 | type StartupHandler = DummyProcessor; 86 | type SimpleQueryHandler = DummyProcessor; 87 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 88 | type CopyHandler = NoopCopyHandler; 89 | type ErrorHandler = NoopErrorHandler; 90 | 91 | fn simple_query_handler(&self) -> Arc { 92 | self.handler.clone() 93 | } 94 | 95 | fn extended_query_handler(&self) -> Arc { 96 | Arc::new(PlaceholderExtendedQueryHandler) 97 | } 98 | 99 | fn startup_handler(&self) -> Arc { 100 | self.handler.clone() 101 | } 102 | 103 | fn copy_handler(&self) -> Arc { 104 | Arc::new(NoopCopyHandler) 105 | } 106 | 107 | fn error_handler(&self) -> Arc { 108 | Arc::new(NoopErrorHandler) 109 | } 110 | } 111 | 112 | #[tokio::main] 113 | pub async fn main() { 114 | let factory = Arc::new(DummyProcessorFactory { 115 | handler: Arc::new(DummyProcessor), 116 | }); 117 | 118 | let server_addr = "127.0.0.1:5433"; 119 | let tls_acceptor = setup_tls().unwrap(); 120 | let listener = TcpListener::bind(server_addr).await.unwrap(); 121 | 122 | println!("Listening to {}", server_addr); 123 | loop { 124 | let incoming_socket = listener.accept().await.unwrap(); 125 | let tls_acceptor_ref = tls_acceptor.clone(); 126 | let factory_ref = factory.clone(); 127 | tokio::spawn(async move { 128 | process_socket(incoming_socket.0, Some(tls_acceptor_ref), factory_ref).await 129 | }); 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /examples/server.rs: 
-------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::sync::Arc; 3 | 4 | use async_trait::async_trait; 5 | use futures::{stream, Sink, SinkExt, StreamExt}; 6 | use tokio::net::TcpListener; 7 | 8 | use pgwire::api::auth::noop::NoopStartupHandler; 9 | use pgwire::api::copy::NoopCopyHandler; 10 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 11 | use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; 12 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; 13 | use pgwire::error::ErrorInfo; 14 | use pgwire::error::{PgWireError, PgWireResult}; 15 | use pgwire::messages::response::NoticeResponse; 16 | use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 17 | use pgwire::tokio::process_socket; 18 | 19 | pub struct DummyProcessor; 20 | 21 | #[async_trait] 22 | impl NoopStartupHandler for DummyProcessor { 23 | async fn post_startup( 24 | &self, 25 | client: &mut C, 26 | _message: PgWireFrontendMessage, 27 | ) -> PgWireResult<()> 28 | where 29 | C: ClientInfo + Sink + Unpin + Send, 30 | C::Error: Debug, 31 | PgWireError: From<>::Error>, 32 | { 33 | println!( 34 | "connected {:?}: {:?}", 35 | client.socket_addr(), 36 | client.metadata() 37 | ); 38 | Ok(()) 39 | } 40 | } 41 | 42 | #[async_trait] 43 | impl SimpleQueryHandler for DummyProcessor { 44 | async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult>> 45 | where 46 | C: ClientInfo + Sink + Unpin + Send + Sync, 47 | C::Error: Debug, 48 | PgWireError: From<>::Error>, 49 | { 50 | client 51 | .send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( 52 | ErrorInfo::new( 53 | "NOTICE".to_owned(), 54 | "01000".to_owned(), 55 | format!("Query received {}", query), 56 | ), 57 | ))) 58 | .await?; 59 | 60 | if query.starts_with("SELECT") { 61 | let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text); 
62 | let f2 = FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text); 63 | let schema = Arc::new(vec![f1, f2]); 64 | 65 | let data = vec![ 66 | (Some(0), Some("Tom")), 67 | (Some(1), Some("Jerry")), 68 | (Some(2), None), 69 | ]; 70 | let schema_ref = schema.clone(); 71 | let data_row_stream = stream::iter(data.into_iter()).map(move |r| { 72 | let mut encoder = DataRowEncoder::new(schema_ref.clone()); 73 | encoder.encode_field(&r.0)?; 74 | encoder.encode_field(&r.1)?; 75 | 76 | encoder.finish() 77 | }); 78 | 79 | Ok(vec![Response::Query(QueryResponse::new( 80 | schema, 81 | data_row_stream, 82 | ))]) 83 | } else { 84 | Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))]) 85 | } 86 | } 87 | } 88 | 89 | struct DummyProcessorFactory { 90 | handler: Arc, 91 | } 92 | 93 | impl PgWireServerHandlers for DummyProcessorFactory { 94 | type StartupHandler = DummyProcessor; 95 | type SimpleQueryHandler = DummyProcessor; 96 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 97 | type CopyHandler = NoopCopyHandler; 98 | type ErrorHandler = NoopErrorHandler; 99 | 100 | fn simple_query_handler(&self) -> Arc { 101 | self.handler.clone() 102 | } 103 | 104 | fn extended_query_handler(&self) -> Arc { 105 | Arc::new(PlaceholderExtendedQueryHandler) 106 | } 107 | 108 | fn startup_handler(&self) -> Arc { 109 | self.handler.clone() 110 | } 111 | 112 | fn copy_handler(&self) -> Arc { 113 | Arc::new(NoopCopyHandler) 114 | } 115 | 116 | fn error_handler(&self) -> Arc { 117 | Arc::new(NoopErrorHandler) 118 | } 119 | } 120 | 121 | #[tokio::main] 122 | pub async fn main() { 123 | let factory = Arc::new(DummyProcessorFactory { 124 | handler: Arc::new(DummyProcessor), 125 | }); 126 | 127 | let server_addr = "127.0.0.1:5432"; 128 | let listener = TcpListener::bind(server_addr).await.unwrap(); 129 | println!("Listening to {}", server_addr); 130 | loop { 131 | let incoming_socket = listener.accept().await.unwrap(); 132 | let factory_ref = factory.clone(); 133 | 
tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /examples/ssl/server.crt: -------------------------------------------------------------------------------- 1 | Certificate: 2 | Data: 3 | Version: 3 (0x2) 4 | Serial Number: 5 | 1e:a1:44:88:27:3d:5c:c8:ff:ef:06:2e:da:21:05:29:30:a5:ce:2c 6 | Signature Algorithm: sha256WithRSAEncryption 7 | Issuer: CN = localhost 8 | Validity 9 | Not Before: Oct 11 07:36:01 2022 GMT 10 | Not After : Oct 8 07:36:01 2032 GMT 11 | Subject: CN = localhost 12 | Subject Public Key Info: 13 | Public Key Algorithm: rsaEncryption 14 | RSA Public-Key: (2048 bit) 15 | Modulus: 16 | 00:d5:b0:29:38:63:13:5e:1e:1d:ae:1f:47:88:b4: 17 | 44:96:21:d8:d7:03:a3:d8:f9:03:2f:4e:79:66:e6: 18 | db:19:55:1d:85:9b:f1:78:2d:87:f3:72:91:13:dc: 19 | ff:00:cb:ab:fd:a1:c8:3a:56:26:e3:88:1d:ec:98: 20 | 4a:af:eb:f9:60:80:27:e1:06:ba:c0:0d:c3:09:0e: 21 | fe:d8:86:1e:25:b4:04:62:a5:75:46:8e:11:e8:61: 22 | 59:aa:97:17:ea:c7:4c:c6:13:8c:6d:54:2a:b9:78: 23 | 86:54:a9:6f:d6:31:96:c6:41:76:a3:c7:67:40:6f: 24 | f2:1a:4c:0d:77:05:bb:3d:0b:16:f8:c7:de:6c:de: 25 | 7b:2e:b6:29:85:4b:a8:36:d3:f2:84:75:e0:85:17: 26 | ce:22:84:4b:94:02:17:8a:36:2b:13:ee:2f:aa:55: 27 | 6b:ff:8b:df:d3:e0:23:8d:fd:c3:f8:e2:c8:a7:d5: 28 | 76:a6:73:7d:a8:5f:6a:49:02:78:a2:c5:66:14:ee: 29 | 86:50:3b:d1:67:7f:1b:0c:27:0d:84:ec:44:0d:39: 30 | 08:ba:69:65:e0:35:a4:67:aa:19:e7:fe:0e:4b:9f: 31 | 23:1e:4e:38:ed:d7:93:57:6e:94:31:05:d3:ae:f7: 32 | 6c:01:3c:30:69:19:f4:7b:b5:48:95:71:c9:9c:30: 33 | 43:9d 34 | Exponent: 65537 (0x10001) 35 | X509v3 extensions: 36 | X509v3 Subject Key Identifier: 37 | 8E:81:0B:60:B1:F9:7D:D8:64:91:BB:30:86:E5:3D:CD:B7:82:D8:31 38 | X509v3 Authority Key Identifier: 39 | keyid:8E:81:0B:60:B1:F9:7D:D8:64:91:BB:30:86:E5:3D:CD:B7:82:D8:31 40 | 41 | X509v3 Basic Constraints: critical 42 | CA:TRUE 43 | Signature Algorithm: sha256WithRSAEncryption 
44 | 6c:ae:ee:3e:e3:d4:5d:29:37:62:b0:32:ce:a4:36:c7:25:b4: 45 | 6a:9f:ba:b4:f0:2f:0a:96:2f:dc:6d:df:7d:92:e7:f0:ee:f7: 46 | de:44:9d:52:36:ff:0c:98:ef:8b:7f:27:df:6e:fe:64:11:7c: 47 | 01:5d:7f:c8:73:a3:24:24:ba:81:fd:a8:ae:28:4f:93:bb:92: 48 | ff:86:d6:48:a2:ca:a5:1f:ea:1c:0d:02:22:e8:71:23:27:22: 49 | 4f:0f:37:58:9a:d9:fd:70:c5:4c:93:7d:47:1c:b6:ea:1b:4f: 50 | 4e:7c:eb:9d:9a:d3:28:78:67:27:e9:b1:ea:f6:93:68:76:e5: 51 | 2e:52:c6:29:91:ba:0a:96:2e:14:33:69:35:d7:b5:e0:c0:ef: 52 | 05:77:09:9b:a1:cc:7b:b2:f0:6a:cb:5c:5f:a1:27:69:b0:2c: 53 | 6e:93:eb:37:98:cd:97:8d:9e:78:a8:f5:99:12:66:86:48:cf: 54 | b2:e0:68:6f:77:98:06:13:24:55:d1:c3:80:1d:59:53:1f:44: 55 | 85:bc:5d:29:aa:2a:a1:06:17:6b:e7:2b:11:0b:fd:e3:f8:88: 56 | 89:32:57:a3:70:f7:1b:6c:c1:66:c7:3c:a4:2d:e8:5f:00:1c: 57 | 55:2f:72:ed:d4:3a:3f:d0:95:de:6c:a4:96:6e:b4:63:0e:80: 58 | 08:b2:25:d5 59 | -----BEGIN CERTIFICATE----- 60 | MIIDCTCCAfGgAwIBAgIUHqFEiCc9XMj/7wYu2iEFKTClziwwDQYJKoZIhvcNAQEL 61 | BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMTAxMTA3MzYwMVoXDTMyMTAw 62 | ODA3MzYwMVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF 63 | AAOCAQ8AMIIBCgKCAQEA1bApOGMTXh4drh9HiLREliHY1wOj2PkDL055ZubbGVUd 64 | hZvxeC2H83KRE9z/AMur/aHIOlYm44gd7JhKr+v5YIAn4Qa6wA3DCQ7+2IYeJbQE 65 | YqV1Ro4R6GFZqpcX6sdMxhOMbVQquXiGVKlv1jGWxkF2o8dnQG/yGkwNdwW7PQsW 66 | +MfebN57LrYphUuoNtPyhHXghRfOIoRLlAIXijYrE+4vqlVr/4vf0+Ajjf3D+OLI 67 | p9V2pnN9qF9qSQJ4osVmFO6GUDvRZ38bDCcNhOxEDTkIumll4DWkZ6oZ5/4OS58j 68 | Hk447deTV26UMQXTrvdsATwwaRn0e7VIlXHJnDBDnQIDAQABo1MwUTAdBgNVHQ4E 69 | FgQUjoELYLH5fdhkkbswhuU9zbeC2DEwHwYDVR0jBBgwFoAUjoELYLH5fdhkkbsw 70 | huU9zbeC2DEwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbK7u 71 | PuPUXSk3YrAyzqQ2xyW0ap+6tPAvCpYv3G3ffZLn8O733kSdUjb/DJjvi38n327+ 72 | ZBF8AV1/yHOjJCS6gf2orihPk7uS/4bWSKLKpR/qHA0CIuhxIyciTw83WJrZ/XDF 73 | TJN9Rxy26htPTnzrnZrTKHhnJ+mx6vaTaHblLlLGKZG6CpYuFDNpNde14MDvBXcJ 74 | m6HMe7LwastcX6EnabAsbpPrN5jNl42eeKj1mRJmhkjPsuBob3eYBhMkVdHDgB1Z 75 | 
Ux9EhbxdKaoqoQYXa+crEQv94/iIiTJXo3D3G2zBZsc8pC3oXwAcVS9y7dQ6P9CV 76 | 3myklm60Yw6ACLIl1Q== 77 | -----END CERTIFICATE----- 78 | -------------------------------------------------------------------------------- /examples/ssl/server.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDVsCk4YxNeHh2u 3 | H0eItESWIdjXA6PY+QMvTnlm5tsZVR2Fm/F4LYfzcpET3P8Ay6v9ocg6VibjiB3s 4 | mEqv6/lggCfhBrrADcMJDv7Yhh4ltARipXVGjhHoYVmqlxfqx0zGE4xtVCq5eIZU 5 | qW/WMZbGQXajx2dAb/IaTA13Bbs9Cxb4x95s3nsutimFS6g20/KEdeCFF84ihEuU 6 | AheKNisT7i+qVWv/i9/T4CON/cP44sin1Xamc32oX2pJAniixWYU7oZQO9FnfxsM 7 | Jw2E7EQNOQi6aWXgNaRnqhnn/g5LnyMeTjjt15NXbpQxBdOu92wBPDBpGfR7tUiV 8 | ccmcMEOdAgMBAAECggEBAMMCIJv0zpf1o+Bja0S2PmFEQj72c3Buzxk85E2kIA7e 9 | PjLQPW0PICJrSzp1U8HGHQ85tSCHvrWmYqin0oD5OHt4eOxC1+qspHB/3tJ6ksiV 10 | n+rmVEAvJuiK7ulfOdRoTQf2jxC23saj1vMsLYOrfY0v8LVGJFQJ1UdqYF9eO6FX 11 | 8i6eQekV0n8u+DMUysYXfePDXEwpunKrlZwZtThgBY31gAIOdNo/FOAFe1yBJdPl 12 | rUFZes1IrE0c4CNxodajuRNCjtNWoX8TK1cXQVUpPprdFLBcYG2P9mPZ7SkZWJc7 13 | rkyPX6Wkb7q3laUCBxuKL1iOJIwaVBYaKfv4HS7VuYECgYEA9H7VB8+whWx2cTFb 14 | 9oYbcaU3HtbKRh6KQP8eB4IWeKV/c/ceWVAxtU9Hx2QU1zZ2fLl+KkaOGeECNNqD 15 | BP1O5qk2qmkjJcP4kzh1K+p7zkqAkrhHqB36y/gwptB8v7JbCchQq9cnBeYsXNIa 16 | j13KvteprRSnanKu18d2aC43cNMCgYEA3746ITtqy1g6AQ0Q/MXN/axsXixKfVjf 17 | kgN/lpjy6oeoEIWKqiNrOQpwy4NeBo6ZN+cwjUUr9SY/BKsZqMGErO8Xuu+QtJYD 18 | ioW/My9rTrTElbpsLpSvZDLc9IRepV4k+5PpXTIRBqp7Q3BZnTjbRMc8x/owG23G 19 | eXnfVKlWM88CgYEA5HBQuMCrzK3/qFkW9Kpun+tfKfhD++nzATGcrCU2u7jd8cr1 20 | 1zsfhqkxhrIS6tYfNP/XSsarZLCgcCOuAQ5wFwIJaoVbaqDE80Dv8X1f+eoQYYW+ 21 | peyE9OjLBEGOHUoW13gLL9ORyWg7EOraGBPpKBC2n1nJ5qKKjF/4WPS9pjMCgYEA 22 | 3UuUyxGtivn0RN3bk2dBWkmT1YERG/EvD4gORbF5caZDADRU9fqaLoy5C1EfSnT3 23 | 7mbnipKD67CsW72vX04oH7NLUUVpZnOJhRTMC6A3Dl2UolMEdP3yi7QS/nV99ymq 24 | gnnFMrw2QtWTnRweRnbZyKkW4OP/eOGWkMeNsHrcG9kCgYEAz/09cKumk349AIXV 25 | 
g6Jw64gCTjWh157wnD3ZSPPEcr/09/fZwf1W0gkY/tbCVrVPJHWb3K5t2nRXjLlz 26 | HMnQXmcMxMlY3Ufvm2H3ov1ODPKwpcBWUZqnpFTZX7rC58lO/wvgiKpgtHA3pDdw 27 | oYDaaozVP4EnnByxhmHaM7ce07U= 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /examples/transaction.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::sync::Arc; 3 | 4 | use async_trait::async_trait; 5 | use futures::{stream, Sink, SinkExt}; 6 | use tokio::net::TcpListener; 7 | 8 | use pgwire::api::auth::noop::NoopStartupHandler; 9 | use pgwire::api::copy::NoopCopyHandler; 10 | use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; 11 | use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; 12 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; 13 | use pgwire::error::ErrorInfo; 14 | use pgwire::error::{PgWireError, PgWireResult}; 15 | use pgwire::messages::response::NoticeResponse; 16 | use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 17 | use pgwire::tokio::process_socket; 18 | 19 | pub struct DummyProcessor; 20 | 21 | #[async_trait] 22 | impl NoopStartupHandler for DummyProcessor { 23 | async fn post_startup( 24 | &self, 25 | client: &mut C, 26 | _message: PgWireFrontendMessage, 27 | ) -> PgWireResult<()> 28 | where 29 | C: ClientInfo + Sink + Unpin + Send, 30 | C::Error: Debug, 31 | PgWireError: From<>::Error>, 32 | { 33 | println!("Connected: {}", client.socket_addr()); 34 | client 35 | .send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( 36 | ErrorInfo::new( 37 | "NOTICE".to_owned(), 38 | "01000".to_owned(), 39 | "Supported queries in this example:\n- BEGIN;\n- ROLLBACK;\n- COMMIT;\n- SELECT 1;" 40 | .to_string(), 41 | ), 42 | ))) 43 | .await?; 44 | Ok(()) 45 | } 46 | } 47 | 48 | #[async_trait] 49 | impl SimpleQueryHandler for DummyProcessor { 50 | async fn 
do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> 51 | where 52 | C: ClientInfo + Sink + Unpin + Send + Sync, 53 | C::Error: Debug, 54 | PgWireError: From<>::Error>, 55 | { 56 | let resp = match query { 57 | "BEGIN;" => Response::TransactionStart(Tag::new("BEGIN")), 58 | "ROLLBACK;" => Response::TransactionEnd(Tag::new("ROLLBACK")), 59 | "COMMIT;" => Response::TransactionEnd(Tag::new("COMMIT")), 60 | "SELECT 1;" => { 61 | let f1 = 62 | FieldInfo::new("SELECT 1".into(), None, None, Type::INT4, FieldFormat::Text); 63 | let schema = Arc::new(vec![f1]); 64 | let schema_ref = schema.clone(); 65 | 66 | let row = { 67 | let mut encoder = DataRowEncoder::new(schema_ref.clone()); 68 | encoder.encode_field(&Some(1))?; 69 | 70 | encoder.finish() 71 | }; 72 | let data_row_stream = stream::iter(vec![row]); 73 | Response::Query(QueryResponse::new(schema, data_row_stream)) 74 | } 75 | _ => Response::Error(Box::new(ErrorInfo::new( 76 | "FATAL".to_string(), 77 | "38003".to_string(), 78 | "Unsupported statement.".to_string(), 79 | ))), 80 | }; 81 | 82 | Ok(vec![resp]) 83 | } 84 | } 85 | 86 | struct DummyProcessorFactory { 87 | handler: Arc, 88 | } 89 | 90 | impl PgWireServerHandlers for DummyProcessorFactory { 91 | type StartupHandler = DummyProcessor; 92 | type SimpleQueryHandler = DummyProcessor; 93 | type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; 94 | type CopyHandler = NoopCopyHandler; 95 | type ErrorHandler = NoopErrorHandler; 96 | 97 | fn simple_query_handler(&self) -> Arc { 98 | self.handler.clone() 99 | } 100 | 101 | fn extended_query_handler(&self) -> Arc { 102 | Arc::new(PlaceholderExtendedQueryHandler) 103 | } 104 | 105 | fn startup_handler(&self) -> Arc { 106 | self.handler.clone() 107 | } 108 | 109 | fn copy_handler(&self) -> Arc { 110 | Arc::new(NoopCopyHandler) 111 | } 112 | 113 | fn error_handler(&self) -> Arc { 114 | Arc::new(NoopErrorHandler) 115 | } 116 | } 117 | 118 | #[tokio::main] 119 | pub async fn main() { 120 | let factory 
= Arc::new(DummyProcessorFactory { 121 | handler: Arc::new(DummyProcessor), 122 | }); 123 | 124 | let server_addr = "127.0.0.1:5432"; 125 | let listener = TcpListener::bind(server_addr).await.unwrap(); 126 | println!("Listening to {}", server_addr); 127 | loop { 128 | let incoming_socket = listener.accept().await.unwrap(); 129 | let factory_ref = factory.clone(); 130 | tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /pgbench/bench.sh: -------------------------------------------------------------------------------- 1 | pgbench -f pgbench/select.sql -c 20 -T 30 -h 127.0.0.1 -p 5433 -U postgres -d postgres 2 | -------------------------------------------------------------------------------- /pgbench/select.sql: -------------------------------------------------------------------------------- 1 | SELECT 1; 2 | -------------------------------------------------------------------------------- /release.toml: -------------------------------------------------------------------------------- 1 | pre-release-replacements = [ 2 | {file="CHANGELOG.md", search="Unreleased", replace="{{version}}", prerelease=false}, 3 | {file="CHANGELOG.md", search="ReleaseDate", replace="{{date}}", prerelease=false} 4 | ] 5 | -------------------------------------------------------------------------------- /src/api/auth/cleartext.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use async_trait::async_trait; 4 | use futures::sink::{Sink, SinkExt}; 5 | 6 | use super::{ 7 | AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider, 8 | StartupHandler, 9 | }; 10 | use crate::error::{PgWireError, PgWireResult}; 11 | use crate::messages::startup::Authentication; 12 | use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 13 | 14 | #[derive(new)] 15 | pub struct 
CleartextPasswordAuthStartupHandler { 16 | auth_source: A, 17 | parameter_provider: P, 18 | } 19 | 20 | #[async_trait] 21 | impl StartupHandler 22 | for CleartextPasswordAuthStartupHandler 23 | { 24 | async fn on_startup( 25 | &self, 26 | client: &mut C, 27 | message: PgWireFrontendMessage, 28 | ) -> PgWireResult<()> 29 | where 30 | C: ClientInfo + Sink + Unpin + Send, 31 | C::Error: Debug, 32 | PgWireError: From<>::Error>, 33 | { 34 | match message { 35 | PgWireFrontendMessage::Startup(ref startup) => { 36 | super::save_startup_parameters_to_metadata(client, startup); 37 | client.set_state(PgWireConnectionState::AuthenticationInProgress); 38 | client 39 | .send(PgWireBackendMessage::Authentication( 40 | Authentication::CleartextPassword, 41 | )) 42 | .await?; 43 | } 44 | PgWireFrontendMessage::PasswordMessageFamily(pwd) => { 45 | let pwd = pwd.into_password()?; 46 | let login_info = LoginInfo::from_client_info(client); 47 | let pass = self.auth_source.get_password(&login_info).await?; 48 | if pass.password == pwd.password.as_bytes() { 49 | super::finish_authentication(client, &self.parameter_provider).await?; 50 | } else { 51 | return Err(PgWireError::InvalidPassword( 52 | login_info.user().map(|x| x.to_owned()).unwrap_or_default(), 53 | )); 54 | } 55 | } 56 | _ => {} 57 | } 58 | Ok(()) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/api/auth/md5pass.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::sync::Arc; 3 | 4 | use async_trait::async_trait; 5 | use futures::sink::{Sink, SinkExt}; 6 | use tokio::sync::Mutex; 7 | 8 | use super::{ 9 | AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider, 10 | StartupHandler, 11 | }; 12 | use crate::error::{PgWireError, PgWireResult}; 13 | use crate::messages::startup::Authentication; 14 | use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 15 | 16 | pub 
struct Md5PasswordAuthStartupHandler { 17 | auth_source: Arc, 18 | parameter_provider: Arc

, 19 | cached_password: Mutex>, 20 | } 21 | 22 | impl Md5PasswordAuthStartupHandler { 23 | pub fn new(auth_source: Arc, parameter_provider: Arc

) -> Self { 24 | Md5PasswordAuthStartupHandler { 25 | auth_source, 26 | parameter_provider, 27 | cached_password: Mutex::new(vec![]), 28 | } 29 | } 30 | } 31 | 32 | #[async_trait] 33 | impl StartupHandler 34 | for Md5PasswordAuthStartupHandler 35 | { 36 | async fn on_startup( 37 | &self, 38 | client: &mut C, 39 | message: PgWireFrontendMessage, 40 | ) -> PgWireResult<()> 41 | where 42 | C: ClientInfo + Sink + Unpin + Send, 43 | C::Error: Debug, 44 | PgWireError: From<>::Error>, 45 | { 46 | match message { 47 | PgWireFrontendMessage::Startup(ref startup) => { 48 | super::save_startup_parameters_to_metadata(client, startup); 49 | client.set_state(PgWireConnectionState::AuthenticationInProgress); 50 | 51 | let login_info = LoginInfo::from_client_info(client); 52 | let salt_and_pass = self.auth_source.get_password(&login_info).await?; 53 | 54 | let salt = salt_and_pass 55 | .salt 56 | .as_ref() 57 | .expect("Salt is required for Md5Password authentication"); 58 | 59 | self.cached_password 60 | .lock() 61 | .await 62 | .clone_from(&salt_and_pass.password); 63 | 64 | client 65 | .send(PgWireBackendMessage::Authentication( 66 | Authentication::MD5Password(salt.clone()), 67 | )) 68 | .await?; 69 | } 70 | PgWireFrontendMessage::PasswordMessageFamily(pwd) => { 71 | let pwd = pwd.into_password()?; 72 | let cached_pass = self.cached_password.lock().await; 73 | 74 | if pwd.password.as_bytes() == *cached_pass { 75 | super::finish_authentication(client, self.parameter_provider.as_ref()).await?; 76 | } else { 77 | let login_info = LoginInfo::from_client_info(client); 78 | return Err(PgWireError::InvalidPassword( 79 | login_info.user().map(|x| x.to_owned()).unwrap_or_default(), 80 | )); 81 | } 82 | } 83 | _ => {} 84 | } 85 | Ok(()) 86 | } 87 | } 88 | 89 | /// This function is to compute postgres standard md5 hashed password 90 | /// 91 | /// concat('md5', md5(concat(md5(concat(password, username)), random-salt))) 92 | /// 93 | /// the input parameter `md5hashed_username_password` 
represents 94 | /// `md5(concat(password, username))`. NOTE(review): this function takes no such 95 | /// parameter; it receives the plaintext `password` and computes that hash itself. 96 | pub fn hash_md5_password(username: &str, password: &str, salt: &[u8]) -> String { 97 | let hashed_bytes = format!("{:x}", md5::compute(format!("{password}{username}"))); 98 | let mut bytes = Vec::with_capacity(hashed_bytes.len() + 4); 99 | bytes.extend_from_slice(hashed_bytes.as_ref()); 100 | bytes.extend_from_slice(salt); 101 | 102 | format!("md5{:x}", md5::compute(bytes)) 103 | } 104 | 105 | #[cfg(test)] 106 | mod tests { 107 | 108 | #[test] 109 | fn test_hash_md5_passwd() { 110 | let salt = vec![20, 247, 107, 249]; 111 | let username = "zmjiang"; 112 | let password = "themanwhochangedchina"; 113 | 114 | let result = "md521fe459d77d3e3ea9c9fcd5c11030d30"; 115 | 116 | assert_eq!(result, super::hash_md5_password(username, password, &salt)); 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/api/auth/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::fmt::Debug; 3 | 4 | use async_trait::async_trait; 5 | use futures::sink::{Sink, SinkExt}; 6 | 7 | use super::{ClientInfo, PgWireConnectionState, METADATA_DATABASE, METADATA_USER}; 8 | use crate::error::{PgWireError, PgWireResult}; 9 | use crate::messages::response::{ReadyForQuery, TransactionStatus}; 10 | use crate::messages::startup::{Authentication, BackendKeyData, ParameterStatus, Startup}; 11 | use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 12 | 13 | /// Handles startup process and frontend messages 14 | #[async_trait] 15 | pub trait StartupHandler: Send + Sync { 16 | /// A generic frontend message callback during startup phase.
17 | async fn on_startup( 18 | &self, 19 | client: &mut C, 20 | message: PgWireFrontendMessage, 21 | ) -> PgWireResult<()> 22 | where 23 | C: ClientInfo + Sink + Unpin + Send, 24 | C::Error: Debug, 25 | PgWireError: From<>::Error>; 26 | } 27 | /// Supplies the parameters reported to the client via `ParameterStatus` messages. 28 | pub trait ServerParameterProvider: Send + Sync { 29 | fn server_parameters(&self, _client: &C) -> Option> 30 | where 31 | C: ClientInfo; 32 | } 33 | 34 | /// Default parameter provider. 35 | /// 36 | /// This provider responds to the frontend with these default parameters: 37 | /// 38 | /// - `DateStyle: ISO YMD`: the default text serialization in this library 39 | /// uses `YMD`-style dates. If you override this, or use your own serialization 40 | /// for date types, remember to update this as well. 41 | /// - `server_encoding: UTF8` 42 | /// - `client_encoding: UTF8` 43 | /// - `integer_datetimes: on` 44 | /// - `server_version: 16.6-pgwire-{CARGO_PKG_VERSION}` 45 | #[non_exhaustive] 46 | #[derive(Debug)] 47 | pub struct DefaultServerParameterProvider { 48 | pub server_version: String, 49 | pub server_encoding: String, 50 | pub client_encoding: String, 51 | pub date_style: String, 52 | pub integer_datetimes: String, 53 | } 54 | 55 | impl Default for DefaultServerParameterProvider { 56 | fn default() -> Self { 57 | Self { 58 | server_version: format!("16.6-pgwire-{}", env!("CARGO_PKG_VERSION").to_owned()), 59 | server_encoding: "UTF8".to_owned(), 60 | client_encoding: "UTF8".to_owned(), 61 | date_style: "ISO YMD".to_owned(), 62 | integer_datetimes: "on".to_owned(), 63 | } 64 | } 65 | } 66 | 67 | impl ServerParameterProvider for DefaultServerParameterProvider { 68 | fn server_parameters(&self, _client: &C) -> Option> 69 | where 70 | C: ClientInfo, 71 | { 72 | let mut params = HashMap::with_capacity(5); 73 | params.insert("server_version".to_owned(), self.server_version.clone()); 74 | params.insert("server_encoding".to_owned(), self.server_encoding.clone()); 75 | params.insert("client_encoding".to_owned(), self.client_encoding.clone()); 76 | 
params.insert("DateStyle".to_owned(), self.date_style.clone()); 77 | params.insert( 78 | "integer_datetimes".to_owned(), 79 | self.integer_datetimes.clone(), 80 | ); 81 | 82 | Some(params) 83 | } 84 | } 85 | /// Password (cleartext or salted hash) returned by an `AuthSource`. 86 | #[derive(Debug, new, Clone)] 87 | pub struct Password { 88 | salt: Option>, 89 | password: Vec, 90 | } 91 | 92 | impl Password { 93 | pub fn salt(&self) -> Option<&[u8]> { 94 | self.salt.as_deref() 95 | } 96 | 97 | pub fn password(&self) -> &[u8] { 98 | &self.password 99 | } 100 | } 101 | /// Login details (user, database, client host) passed to an `AuthSource`. 102 | #[derive(Debug, new)] 103 | pub struct LoginInfo<'a> { 104 | user: Option<&'a str>, 105 | database: Option<&'a str>, 106 | host: String, 107 | } 108 | 109 | impl LoginInfo<'_> { 110 | pub fn user(&self) -> Option<&str> { 111 | self.user 112 | } 113 | 114 | pub fn database(&self) -> Option<&str> { 115 | self.database 116 | } 117 | 118 | pub fn host(&self) -> &str { 119 | &self.host 120 | } 121 | 122 | pub fn from_client_info(client: &C) -> LoginInfo 123 | where 124 | C: ClientInfo, 125 | { 126 | LoginInfo { 127 | user: client.metadata().get(METADATA_USER).map(|s| s.as_str()), 128 | database: client.metadata().get(METADATA_DATABASE).map(|s| s.as_str()), 129 | host: client.socket_addr().ip().to_string(), 130 | } 131 | } 132 | } 133 | 134 | /// Represents an auth source, which returns the password either in cleartext or 135 | /// hashed with a salt. 136 | /// 137 | /// When used with different authentication mechanisms, the developer can choose a 138 | /// specific implementation of `AuthSource`. For example, with cleartext 139 | /// authentication, salt is not required, while in md5pass, a 4-byte salt is 140 | /// needed. 141 | #[async_trait] 142 | pub trait AuthSource: Send + Sync { 143 | /// Get password from the `AuthSource`. 144 | /// 145 | /// `Password` has an optional salt field when it's hashed.
146 | async fn get_password(&self, login: &LoginInfo) -> PgWireResult; 147 | } 148 | 149 | pub fn save_startup_parameters_to_metadata(client: &mut C, startup_message: &Startup) 150 | where 151 | C: ClientInfo + Sink + Unpin + Send, 152 | C::Error: Debug, 153 | { 154 | client.metadata_mut().extend( 155 | startup_message 156 | .parameters 157 | .iter() 158 | .map(|(k, v)| (k.to_owned(), v.to_owned())), 159 | ); 160 | } 161 | 162 | pub(crate) async fn finish_authentication0( 163 | client: &mut C, 164 | server_parameter_provider: &P, 165 | ) -> PgWireResult<()> 166 | where 167 | C: ClientInfo + Sink + Unpin + Send, 168 | C::Error: Debug, 169 | PgWireError: From<>::Error>, 170 | P: ServerParameterProvider, 171 | { 172 | client 173 | .feed(PgWireBackendMessage::Authentication(Authentication::Ok)) 174 | .await?; 175 | 176 | if let Some(parameters) = server_parameter_provider.server_parameters(client) { 177 | for (k, v) in parameters { 178 | client 179 | .feed(PgWireBackendMessage::ParameterStatus(ParameterStatus::new( 180 | k, v, 181 | ))) 182 | .await?; 183 | } 184 | } 185 | 186 | // TODO: store this backend key 187 | client 188 | .feed(PgWireBackendMessage::BackendKeyData(BackendKeyData::new( 189 | std::process::id() as i32, 190 | rand::random::(), 191 | ))) 192 | .await?; 193 | 194 | Ok(()) 195 | } 196 | 197 | pub async fn finish_authentication( 198 | client: &mut C, 199 | server_parameter_provider: &P, 200 | ) -> PgWireResult<()> 201 | where 202 | C: ClientInfo + Sink + Unpin + Send, 203 | C::Error: Debug, 204 | PgWireError: From<>::Error>, 205 | P: ServerParameterProvider, 206 | { 207 | finish_authentication0(client, server_parameter_provider).await?; 208 | 209 | client 210 | .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( 211 | TransactionStatus::Idle, 212 | ))) 213 | .await?; 214 | 215 | client.set_state(PgWireConnectionState::ReadyForQuery); 216 | Ok(()) 217 | } 218 | 219 | pub mod cleartext; 220 | pub mod md5pass; 221 | pub mod noop; 222 | 
#[cfg(feature = "scram")] 223 | pub mod scram; 224 | -------------------------------------------------------------------------------- /src/api/auth/noop.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use async_trait::async_trait; 4 | use futures::sink::{Sink, SinkExt}; 5 | 6 | use super::{ClientInfo, DefaultServerParameterProvider, StartupHandler}; 7 | use crate::api::PgWireConnectionState; 8 | use crate::error::{PgWireError, PgWireResult}; 9 | use crate::messages::response::{ReadyForQuery, TransactionStatus}; 10 | use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 11 | 12 | #[async_trait] 13 | pub trait NoopStartupHandler: StartupHandler { 14 | async fn post_startup( 15 | &self, 16 | _client: &mut C, 17 | _message: PgWireFrontendMessage, 18 | ) -> PgWireResult<()> 19 | where 20 | C: ClientInfo + Sink + Unpin + Send, 21 | C::Error: Debug, 22 | PgWireError: From<>::Error>, 23 | { 24 | Ok(()) 25 | } 26 | } 27 | 28 | #[async_trait] 29 | impl StartupHandler for H 30 | where 31 | H: NoopStartupHandler, 32 | { 33 | async fn on_startup( 34 | &self, 35 | client: &mut C, 36 | message: PgWireFrontendMessage, 37 | ) -> PgWireResult<()> 38 | where 39 | C: ClientInfo + Sink + Unpin + Send, 40 | C::Error: Debug, 41 | PgWireError: From<>::Error>, 42 | { 43 | if let PgWireFrontendMessage::Startup(ref startup) = message { 44 | super::save_startup_parameters_to_metadata(client, startup); 45 | super::finish_authentication0(client, &DefaultServerParameterProvider::default()) 46 | .await?; 47 | 48 | self.post_startup(client, message).await?; 49 | 50 | client 51 | .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( 52 | TransactionStatus::Idle, 53 | ))) 54 | .await?; 55 | client.set_state(PgWireConnectionState::ReadyForQuery); 56 | } 57 | 58 | Ok(()) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/api/client/auth.rs: 
-------------------------------------------------------------------------------- 1 | use std::collections::BTreeMap; 2 | 3 | use async_trait::async_trait; 4 | use futures::{Sink, SinkExt}; 5 | 6 | use crate::api::auth::md5pass::hash_md5_password; 7 | use crate::error::{ErrorInfo, PgWireClientError, PgWireClientResult}; 8 | use crate::messages::response::ReadyForQuery; 9 | use crate::messages::startup::{ 10 | Authentication, BackendKeyData, ParameterStatus, Password, PasswordMessageFamily, Startup, 11 | }; 12 | use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 13 | 14 | use super::{ClientInfo, ReadyState, ServerInformation}; 15 | 16 | #[async_trait] 17 | pub trait StartupHandler: Send { 18 | async fn startup(&mut self, client: &mut C) -> PgWireClientResult<()> 19 | where 20 | C: ClientInfo + Sink + Unpin + Send, 21 | PgWireClientError: From<>::Error>; 22 | 23 | async fn on_message( 24 | &mut self, 25 | client: &mut C, 26 | message: PgWireBackendMessage, 27 | ) -> PgWireClientResult> 28 | where 29 | C: ClientInfo + Sink + Unpin + Send, 30 | PgWireClientError: From<>::Error>, 31 | { 32 | match message { 33 | PgWireBackendMessage::Authentication(authentication) => { 34 | self.on_authentication(client, authentication).await?; 35 | } 36 | PgWireBackendMessage::ParameterStatus(parameter_status) => { 37 | self.on_parameter_status(client, parameter_status).await?; 38 | } 39 | PgWireBackendMessage::BackendKeyData(backend_key_data) => { 40 | self.on_backend_key(client, backend_key_data).await?; 41 | } 42 | PgWireBackendMessage::ReadyForQuery(ready) => { 43 | let server_information = self.on_ready_for_query(client, ready).await?; 44 | return Ok(ReadyState::Ready(server_information)); 45 | } 46 | PgWireBackendMessage::ErrorResponse(error) => { 47 | let error_info = ErrorInfo::from(error); 48 | return Err(error_info.into()); 49 | } 50 | PgWireBackendMessage::NoticeResponse(_) => {} 51 | _ => return Err(PgWireClientError::UnexpectedMessage(Box::new(message))), 
52 | } 53 | 54 | Ok(ReadyState::Pending) 55 | } 56 | 57 | async fn on_authentication( 58 | &mut self, 59 | client: &mut C, 60 | message: Authentication, 61 | ) -> PgWireClientResult<()> 62 | where 63 | C: ClientInfo + Sink + Unpin + Send, 64 | PgWireClientError: From<>::Error>; 65 | 66 | async fn on_parameter_status( 67 | &mut self, 68 | client: &mut C, 69 | message: ParameterStatus, 70 | ) -> PgWireClientResult<()> 71 | where 72 | C: ClientInfo + Sink + Unpin + Send, 73 | PgWireClientError: From<>::Error>; 74 | 75 | async fn on_backend_key( 76 | &mut self, 77 | client: &mut C, 78 | message: BackendKeyData, 79 | ) -> PgWireClientResult<()> 80 | where 81 | C: ClientInfo + Sink + Unpin + Send, 82 | PgWireClientError: From<>::Error>; 83 | 84 | async fn on_ready_for_query( 85 | &mut self, 86 | client: &mut C, 87 | message: ReadyForQuery, 88 | ) -> PgWireClientResult 89 | where 90 | C: ClientInfo + Sink + Unpin + Send, 91 | PgWireClientError: From<>::Error>; 92 | } 93 | 94 | #[derive(new, Debug)] 95 | pub struct DefaultStartupHandler { 96 | #[new(default)] 97 | server_parameters: BTreeMap, 98 | #[new(default)] 99 | process_id: Option, 100 | } 101 | 102 | #[async_trait] 103 | impl StartupHandler for DefaultStartupHandler { 104 | async fn startup(&mut self, client: &mut C) -> PgWireClientResult<()> 105 | where 106 | C: ClientInfo + Sink + Unpin + Send, 107 | PgWireClientError: From<>::Error>, 108 | { 109 | let mut startup = Startup::new(); 110 | 111 | let config = client.config(); 112 | 113 | if let Some(application_name) = &config.application_name { 114 | startup 115 | .parameters 116 | .insert("application_name".to_string(), application_name.clone()); 117 | } 118 | if let Some(user) = &config.user { 119 | startup.parameters.insert("user".to_string(), user.clone()); 120 | } 121 | if let Some(dbname) = &config.dbname { 122 | startup 123 | .parameters 124 | .insert("database".to_string(), dbname.clone()); 125 | } 126 | 127 | 
client.send(PgWireFrontendMessage::Startup(startup)).await?; 128 | Ok(()) 129 | } 130 | 131 | async fn on_authentication( 132 | &mut self, 133 | client: &mut C, 134 | message: Authentication, 135 | ) -> PgWireClientResult<()> 136 | where 137 | C: ClientInfo + Sink + Unpin + Send, 138 | PgWireClientError: From<>::Error>, 139 | { 140 | match message { 141 | Authentication::Ok => {} 142 | Authentication::CleartextPassword => { 143 | let pass = client 144 | .config() 145 | .password 146 | .as_ref() 147 | .map(|bs| String::from_utf8_lossy(bs).into_owned()) 148 | .unwrap_or_default(); 149 | 150 | client 151 | .send(PgWireFrontendMessage::PasswordMessageFamily( 152 | PasswordMessageFamily::Password(Password::new(pass)), 153 | )) 154 | .await?; 155 | } 156 | Authentication::MD5Password(salt) => { 157 | let username = client.config().user.as_ref().map_or("", |s| s.as_str()); 158 | 159 | let password = client 160 | .config() 161 | .password 162 | .as_ref() 163 | .map(|bs| String::from_utf8_lossy(bs).into_owned()) 164 | .unwrap_or_default(); 165 | 166 | let hashed_password = hash_md5_password(username, &password, &salt); 167 | client 168 | .send(PgWireFrontendMessage::PasswordMessageFamily( 169 | PasswordMessageFamily::Password(Password::new(hashed_password)), 170 | )) 171 | .await?; 172 | } 173 | // TODO: scram 174 | _ => {} 175 | } 176 | 177 | Ok(()) 178 | } 179 | 180 | async fn on_parameter_status( 181 | &mut self, 182 | _client: &mut C, 183 | message: ParameterStatus, 184 | ) -> PgWireClientResult<()> 185 | where 186 | C: ClientInfo + Sink + Unpin + Send, 187 | { 188 | self.server_parameters.insert(message.name, message.value); 189 | Ok(()) 190 | } 191 | 192 | async fn on_backend_key( 193 | &mut self, 194 | _client: &mut C, 195 | message: BackendKeyData, 196 | ) -> PgWireClientResult<()> 197 | where 198 | C: ClientInfo + Sink + Unpin + Send, 199 | { 200 | self.process_id = Some(message.pid); 201 | Ok(()) 202 | } 203 | 204 | async fn on_ready_for_query( 205 | &mut self, 
206 | _client: &mut C, 207 | _message: ReadyForQuery, 208 | ) -> PgWireClientResult 209 | where 210 | C: ClientInfo + Sink + Unpin + Send, 211 | { 212 | Ok(ServerInformation { 213 | parameters: self.server_parameters.clone(), 214 | process_id: self.process_id.unwrap_or(-1), 215 | }) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /src/api/client/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod auth; 2 | pub(crate) mod config; 3 | pub mod query; 4 | pub mod result; 5 | 6 | use std::collections::BTreeMap; 7 | 8 | pub use config::Config; 9 | 10 | /// A trait for fetching necessary information from Client 11 | pub trait ClientInfo { 12 | /// Returns configuration of this client 13 | fn config(&self) -> &Config; 14 | 15 | /// Returns server parameters received from server 16 | fn server_parameters(&self) -> &BTreeMap; 17 | 18 | /// Returns process id received from server 19 | fn process_id(&self) -> i32; 20 | 21 | // TODO: transaction state 22 | } 23 | 24 | /// Carries server provided information for current connection 25 | #[derive(Debug, Default)] 26 | pub struct ServerInformation { 27 | pub parameters: BTreeMap, 28 | pub process_id: i32, 29 | } 30 | 31 | /// Indicate the result of current request 32 | pub enum ReadyState { 33 | Pending, 34 | Ready(D), 35 | } 36 | -------------------------------------------------------------------------------- /src/api/client/query.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use async_trait::async_trait; 4 | use futures::{Sink, SinkExt}; 5 | use postgres_types::Oid; 6 | 7 | use crate::api::results::{FieldInfo, Tag}; 8 | use crate::error::{ErrorInfo, PgWireClientError, PgWireClientResult}; 9 | use crate::messages::data::{DataRow, RowDescription}; 10 | use crate::messages::response::{CommandComplete, EmptyQueryResponse, ReadyForQuery}; 11 | use 
crate::messages::simplequery::Query; 12 | use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; 13 | 14 | use super::result::DataRowsReader; 15 | use super::{ClientInfo, ReadyState}; 16 | 17 | #[async_trait] 18 | pub trait SimpleQueryHandler: Send { 19 | type QueryResponse; 20 | 21 | async fn simple_query(&mut self, client: &mut C, query: &str) -> PgWireClientResult<()> 22 | where 23 | C: ClientInfo + Sink + Unpin + Send, 24 | PgWireClientError: From<>::Error>; 25 | 26 | async fn on_message( 27 | &mut self, 28 | client: &mut C, 29 | message: PgWireBackendMessage, 30 | ) -> PgWireClientResult>> 31 | where 32 | C: ClientInfo + Sink + Unpin + Send, 33 | PgWireClientError: From<>::Error>, 34 | { 35 | match message { 36 | PgWireBackendMessage::RowDescription(row_description) => { 37 | self.on_row_description(client, row_description).await?; 38 | } 39 | PgWireBackendMessage::DataRow(data_row) => { 40 | self.on_data_row(client, data_row).await?; 41 | } 42 | PgWireBackendMessage::CommandComplete(command_complete) => { 43 | self.on_command_complete(client, command_complete).await?; 44 | } 45 | PgWireBackendMessage::EmptyQueryResponse(empty_query) => { 46 | self.on_empty_query(client, empty_query).await?; 47 | } 48 | PgWireBackendMessage::ReadyForQuery(ready_for_query) => { 49 | let response = self.on_ready_for_query(client, ready_for_query).await?; 50 | return Ok(ReadyState::Ready(response)); 51 | } 52 | PgWireBackendMessage::ErrorResponse(error) => { 53 | let error_info = ErrorInfo::from(error); 54 | return Err(error_info.into()); 55 | } 56 | PgWireBackendMessage::NoticeResponse(_) => {} 57 | _ => return Err(PgWireClientError::UnexpectedMessage(Box::new(message))), 58 | } 59 | 60 | Ok(ReadyState::Pending) 61 | } 62 | 63 | async fn on_row_description( 64 | &mut self, 65 | client: &mut C, 66 | message: RowDescription, 67 | ) -> PgWireClientResult<()> 68 | where 69 | C: ClientInfo + Sink + Unpin + Send, 70 | PgWireClientError: From<>::Error>; 71 | 72 | async fn 
on_data_row<C>(&mut self, client: &mut C, message: DataRow) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>;

    async fn on_command_complete<C>(
        &mut self,
        client: &mut C,
        message: CommandComplete,
    ) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>;

    async fn on_empty_query<C>(
        &mut self,
        client: &mut C,
        message: EmptyQueryResponse,
    ) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>;

    async fn on_ready_for_query<C>(
        &mut self,
        client: &mut C,
        message: ReadyForQuery,
    ) -> PgWireClientResult<Vec<Self::QueryResponse>>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>;
}

/// One statement result collected while processing a simple query.
#[derive(Debug)]
pub enum Response {
    /// The backend answered with `EmptyQueryResponse`.
    EmptyQuery,
    /// A statement that sent a `RowDescription`: its command tag, the row
    /// schema and every `DataRow` received before `CommandComplete`.
    Query((Tag, Vec<FieldInfo>, Vec<DataRow>)),
    /// A statement that completed without a `RowDescription`.
    Execution(Tag),
}

impl Response {
    /// Convert into a row reader; non-query responses yield an empty reader.
    pub fn into_data_rows_reader(self) -> DataRowsReader {
        if let Response::Query((_, fields, rows)) = self {
            DataRowsReader::new(fields, rows)
        } else {
            DataRowsReader::empty()
        }
    }
}

impl FromStr for Tag {
    type Err = PgWireClientError;

    /// Parse the command tag carried by a `CommandComplete` message.
    ///
    /// PostgreSQL tags look like `COMMAND`, `COMMAND rows` or, for inserts,
    /// `INSERT oid rows` (the oid comes *before* the row count). Trailing
    /// segments are only treated as counts when they are all ASCII digits, so
    /// multi-word tags such as `CREATE TABLE` are kept verbatim instead of
    /// being rejected.
    ///
    /// Returns `PgWireClientError::InvalidTag` when a numeric segment does not
    /// fit its integer type.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        fn is_count(seg: &str) -> bool {
            !seg.is_empty() && seg.bytes().all(|b| b.is_ascii_digit())
        }

        let segs = s.split_whitespace().collect::<Vec<&str>>();
        match segs.as_slice() {
            &[command, oid, rows] if is_count(oid) && is_count(rows) => {
                let oid = oid
                    .parse::<postgres_types::Oid>()
                    .map_err(|e| PgWireClientError::InvalidTag(Box::new(e)))?;
                let rows = rows
                    .parse::<usize>()
                    .map_err(|e| PgWireClientError::InvalidTag(Box::new(e)))?;
                Ok(Tag::new(command).with_rows(rows).with_oid(oid))
            }
            &[command, rows] if is_count(rows) => {
                let rows = rows
                    .parse::<usize>()
                    .map_err(|e| PgWireClientError::InvalidTag(Box::new(e)))?;
                Ok(Tag::new(command).with_rows(rows))
            }
            _ => Ok(Tag::new(s)),
        }
    }
}

/// Schema and rows of the row-returning statement currently being received.
struct QueryResponseBuffer {
    row_schema: Vec<FieldInfo>,
    data_rows: Vec<DataRow>,
}

/// A `SimpleQueryHandler` that buffers every statement result and returns
/// them all when `ReadyForQuery` arrives.
#[derive(Default, new)]
pub struct DefaultSimpleQueryHandler {
    #[new(default)]
    current_buffer: Option<QueryResponseBuffer>,
    #[new(default)]
    responses: Vec<Response>,
}

#[async_trait]
impl SimpleQueryHandler for DefaultSimpleQueryHandler {
    type QueryResponse = Response;

    /// Send the query string to the backend as a `Query` message.
    async fn simple_query<C>(&mut self, client: &mut C, query: &str) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>,
    {
        let query = Query::new(query.to_owned());
        client.send(PgWireFrontendMessage::Query(query)).await?;
        Ok(())
    }

    /// Start buffering a new row-returning statement.
    async fn on_row_description<C>(
        &mut self,
        _client: &mut C,
        message: RowDescription,
    ) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>,
    {
        let fields = message.fields.into_iter().map(|f| f.into()).collect();
        self.current_buffer = Some(QueryResponseBuffer {
            row_schema: fields,
            data_rows: Vec::new(),
        });
        Ok(())
    }

    /// Append a row to the current buffer; a row without a preceding
    /// `RowDescription` is reported as an unexpected message.
    async fn on_data_row<C>(&mut self, _client: &mut C, message: DataRow) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>,
    {
        if let Some(ref mut current_buffer) = self.current_buffer {
            current_buffer.data_rows.push(message);
            Ok(())
        } else {
            Err(PgWireClientError::UnexpectedMessage(Box::new(
                PgWireBackendMessage::DataRow(message),
            )))
        }
    }

    /// Finish the current statement, producing a `Query` response when rows
    /// were buffered and an `Execution` response otherwise.
    async fn on_command_complete<C>(
        &mut self,
        _client: &mut C,
        message: CommandComplete,
    ) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>,
    {
        // Take the buffer before parsing so it is cleared even if the tag is
        // invalid (same ordering as before).
        let buffer = self.current_buffer.take();
        let tag = message.tag.parse::<Tag>()?;
        let response = match buffer {
            Some(buffer) => Response::Query((tag, buffer.row_schema, buffer.data_rows)),
            None => Response::Execution(tag),
        };
        self.responses.push(response);

        Ok(())
    }

    async fn on_empty_query<C>(
        &mut self,
        _client: &mut C,
        _message: EmptyQueryResponse,
    ) -> PgWireClientResult<()>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>,
    {
        self.responses.push(Response::EmptyQuery);
        Ok(())
    }

    /// Hand back (and reset) all responses collected for this query.
    async fn on_ready_for_query<C>(
        &mut self,
        _client: &mut C,
        _message: ReadyForQuery,
    ) -> PgWireClientResult<Vec<Self::QueryResponse>>
    where
        C: ClientInfo + Sink<PgWireFrontendMessage> + Unpin + Send,
        PgWireClientError: From<<C as Sink<PgWireFrontendMessage>>::Error>,
    {
        Ok(std::mem::take(&mut self.responses))
    }
}

#[cfg(test)]
mod tag_tests {
    use super::*;

    #[test]
    fn insert_tag_has_oid_before_rows() {
        let tag = "INSERT 0 5".parse::<Tag>().unwrap();
        assert_eq!(tag, Tag::new("INSERT").with_rows(5).with_oid(0));
    }

    #[test]
    fn row_count_tag() {
        let tag = "SELECT 3".parse::<Tag>().unwrap();
        assert_eq!(tag, Tag::new("SELECT").with_rows(3));
    }

    #[test]
    fn multi_word_tag_without_count() {
        let tag = "CREATE TABLE".parse::<Tag>().unwrap();
        assert_eq!(tag, Tag::new("CREATE TABLE"));
    }
}
--------------------------------------------------------------------------------
/src/api/client/result.rs:
--------------------------------------------------------------------------------
use bytes::Buf;
use postgres_types::FromSqlOwned;

use crate::api::results::{FieldFormat, FieldInfo};
use crate::error::{PgWireClientError, PgWireClientResult};
use crate::messages::data::DataRow;
use crate::types::FromSqlText;

/// Sequential reader over the rows of a query response.
#[derive(new, Debug)]
pub struct DataRowsReader {
    fields: Vec<FieldInfo>,
    rows: Vec<DataRow>,
    /// Set once `rows` has been reversed, so that the next row can be taken
    /// with an O(1) `pop` instead of an O(n) `remove(0)`.
    #[new(default)]
    reversed: bool,
}

impl DataRowsReader {
    /// A reader that yields no rows.
    pub fn empty() -> DataRowsReader {
        Self {
            fields: vec![],
            rows: vec![],
            reversed: false,
        }
    }

    /// Generate row decoder for next row, or `None` once every row was read.
    pub fn next_row(&mut self) -> Option<DataRowDecoder<'_>> {
        if !self.reversed {
            // one-time O(n) reversal; rows are then popped in original order
            self.rows.reverse();
            self.reversed = true;
        }
        let row = self.rows.pop()?;
        Some(DataRowDecoder::new(self.fields.as_slice(), row))
    }
}

/// Decodes the values of one `DataRow` in column order.
#[derive(new, Debug)]
pub struct DataRowDecoder<'a> {
    fields: &'a [FieldInfo],
    row: DataRow,
    #[new(default)]
    read_index: usize,
}

impl DataRowDecoder<'_> {
    /// Get value from data row
    ///
    /// Every value in a `DataRow` is prefixed with an `Int32` byte length,
    /// where `-1` denotes SQL `NULL` (returned as `Ok(None)`). The bytes are
    /// decoded with `FromSqlText` or `FromSql` according to the column format.
    ///
    /// Returns `DataRowIndexOutOfBounds` when all columns were already read or
    /// when the row body is shorter than its length prefixes claim.
    pub fn next_value<T>(&mut self) -> PgWireClientResult<Option<T>>
    where
        T: FromSqlOwned + FromSqlText,
    {
        let field_info = self
            .fields
            .get(self.read_index)
            .ok_or(PgWireClientError::DataRowIndexOutOfBounds)?;
        // advance read index
        self.read_index += 1;

        // the per-column length prefix is an Int32 on the wire
        if self.row.data.remaining() < 4 {
            return Err(PgWireClientError::DataRowIndexOutOfBounds);
        }
        let byte_len = self.row.data.get_i32();
        if byte_len < 0 {
            return Ok(None);
        }
        let byte_len = byte_len as usize;
        if self.row.data.remaining() < byte_len {
            return Err(PgWireClientError::DataRowIndexOutOfBounds);
        }
        let bytes = self.row.data.split_to(byte_len);

        if field_info.format() == FieldFormat::Text {
            T::from_sql_text(field_info.datatype(), bytes.as_ref())
                .map_err(PgWireClientError::FromSqlError)
                .map(Some)
        } else {
            // binary
            T::from_sql(field_info.datatype(), bytes.as_ref())
                .map_err(PgWireClientError::FromSqlError)
                .map(Some)
        }
    }

    /// Length of fields
    pub fn len(&self) -> usize {
        self.fields.len()
    }

    pub fn is_empty(&self) -> bool {
        self.fields.is_empty()
    }
}

#[cfg(test)]
mod tests {}
--------------------------------------------------------------------------------
/src/api/copy.rs:
--------------------------------------------------------------------------------
use async_trait::async_trait;
use futures::sink::{Sink, SinkExt};
use std::fmt::Debug;

use crate::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::messages::copy::{
    CopyBothResponse, CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse,
};
use crate::messages::PgWireBackendMessage;

use super::results::CopyResponse;
use super::ClientInfo;

/// handler for copy messages
#[async_trait]
pub trait CopyHandler: Send + Sync {
    async fn on_copy_data<C>(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        Ok(())
    }

    async fn on_copy_done<C>(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        Ok(())
    }

    async fn on_copy_fail<C>(&self, _client: &mut C, fail: CopyFail) -> PgWireError
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        PgWireError::UserError(Box::new(ErrorInfo::new(
            "ERROR".to_owned(),
            "XX000".to_owned(),
            format!("COPY IN mode terminated by the user: {}", fail.message),
        )))
    }
}

pub async fn send_copy_in_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
where
    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    let resp = CopyInResponse::new(resp.format, resp.columns as i16, resp.column_formats);
    client
        .send(PgWireBackendMessage::CopyInResponse(resp))
        .await?;
    Ok(())
}

pub async fn send_copy_out_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
where
    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    let resp = CopyOutResponse::new(resp.format, resp.columns as i16, resp.column_formats);
    client
        .send(PgWireBackendMessage::CopyOutResponse(resp))
        .await?;
    Ok(())
}

pub async fn send_copy_both_response<C>(client: &mut C, resp: CopyResponse) -> PgWireResult<()>
where
    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    let resp = CopyBothResponse::new(resp.format, resp.columns
as i16, resp.column_formats);
    client
        .send(PgWireBackendMessage::CopyBothResponse(resp))
        .await?;
    Ok(())
}

/// A `CopyHandler` that keeps every default method of the trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoopCopyHandler;

impl CopyHandler for NoopCopyHandler {}
--------------------------------------------------------------------------------
/src/api/mod.rs:
--------------------------------------------------------------------------------
//! APIs for building postgresql compatible servers.

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

pub use postgres_types::Type;
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
use rustls_pki_types::CertificateDer;

use crate::error::PgWireError;
use crate::messages::response::TransactionStatus;

pub mod auth;
#[cfg(feature = "client-api")]
pub mod client;
pub mod copy;
pub mod portal;
pub mod query;
pub mod results;
pub mod stmt;
pub mod store;
pub mod transaction;

/// Name substituted for an unnamed statement or portal.
pub const DEFAULT_NAME: &str = "POSTGRESQL_DEFAULT_NAME";

/// Where a connection currently is in the protocol flow.
#[derive(Clone, Copy, Debug, Default)]
pub enum PgWireConnectionState {
    #[default]
    AwaitingSslRequest,
    AwaitingStartup,
    AuthenticationInProgress,
    ReadyForQuery,
    QueryInProgress,
    CopyInProgress(bool),
    AwaitingSync,
}

/// Access to per-connection facts (peer address, TLS), protocol and
/// transaction state, and free-form string metadata.
pub trait ClientInfo {
    fn socket_addr(&self) -> SocketAddr;

    fn is_secure(&self) -> bool;

    fn state(&self) -> PgWireConnectionState;

    fn set_state(&mut self, new_state: PgWireConnectionState);

    fn transaction_status(&self) -> TransactionStatus;

    fn set_transaction_status(&mut self, new_status: TransactionStatus);

    fn metadata(&self) -> &HashMap<String, String>;

    fn metadata_mut(&mut self) -> &mut HashMap<String, String>;

    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]>;
}

/// Gives access to the statement/portal store attached to a client.
pub trait ClientPortalStore {
    type PortalStore;

    fn portal_store(&self) -> &Self::PortalStore;
}

/// Well-known metadata key `"user"`.
pub const METADATA_USER: &str = "user";
/// Well-known metadata key `"database"`.
pub const METADATA_DATABASE: &str = "database";

/// Plain-struct `ClientInfo` implementation with an in-memory portal store.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultClient<S> {
    pub socket_addr: SocketAddr,
    pub is_secure: bool,
    pub state: PgWireConnectionState,
    pub transaction_status: TransactionStatus,
    pub metadata: HashMap<String, String>,
    pub portal_store: store::MemPortalStore<S>,
}

impl<S> ClientInfo for DefaultClient<S> {
    fn socket_addr(&self) -> SocketAddr {
        self.socket_addr
    }

    fn is_secure(&self) -> bool {
        self.is_secure
    }

    fn state(&self) -> PgWireConnectionState {
        self.state
    }

    fn set_state(&mut self, new_state: PgWireConnectionState) {
        self.state = new_state;
    }

    fn transaction_status(&self) -> TransactionStatus {
        self.transaction_status
    }

    fn set_transaction_status(&mut self, new_status: TransactionStatus) {
        self.transaction_status = new_status;
    }

    fn metadata(&self) -> &HashMap<String, String> {
        &self.metadata
    }

    fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
        &mut self.metadata
    }

    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
        None
    }
}

impl<S> DefaultClient<S> {
    /// Fresh client state: default connection state, idle transaction, no
    /// metadata and an empty portal store.
    pub fn new(socket_addr: SocketAddr, is_secure: bool) -> DefaultClient<S> {
        let state = PgWireConnectionState::default();
        let transaction_status = TransactionStatus::Idle;
        let metadata = HashMap::new();
        let portal_store = store::MemPortalStore::new();

        DefaultClient {
            socket_addr,
            is_secure,
            state,
            transaction_status,
            metadata,
            portal_store,
        }
    }
}

impl<S> ClientPortalStore for DefaultClient<S>
{ 135 | type PortalStore = store::MemPortalStore; 136 | 137 | fn portal_store(&self) -> &Self::PortalStore { 138 | &self.portal_store 139 | } 140 | } 141 | 142 | /// A centralized handler for all errors 143 | /// 144 | /// This handler captures all errors produces by authentication, query and 145 | /// copy. You can do logging, filtering or masking the error before it sent to 146 | /// client. 147 | pub trait ErrorHandler: Send + Sync { 148 | fn on_error(&self, _client: &C, _error: &mut PgWireError) 149 | where 150 | C: ClientInfo, 151 | { 152 | } 153 | } 154 | 155 | /// A noop implementation for `ErrorHandler`. 156 | pub struct NoopErrorHandler; 157 | 158 | impl ErrorHandler for NoopErrorHandler {} 159 | 160 | pub trait PgWireServerHandlers { 161 | type StartupHandler: auth::StartupHandler; 162 | type SimpleQueryHandler: query::SimpleQueryHandler; 163 | type ExtendedQueryHandler: query::ExtendedQueryHandler; 164 | type CopyHandler: copy::CopyHandler; 165 | type ErrorHandler: ErrorHandler; 166 | 167 | fn simple_query_handler(&self) -> Arc; 168 | 169 | fn extended_query_handler(&self) -> Arc; 170 | 171 | fn startup_handler(&self) -> Arc; 172 | 173 | fn copy_handler(&self) -> Arc; 174 | 175 | fn error_handler(&self) -> Arc; 176 | } 177 | 178 | impl PgWireServerHandlers for Arc 179 | where 180 | T: PgWireServerHandlers, 181 | { 182 | type StartupHandler = T::StartupHandler; 183 | type SimpleQueryHandler = T::SimpleQueryHandler; 184 | type ExtendedQueryHandler = T::ExtendedQueryHandler; 185 | type CopyHandler = T::CopyHandler; 186 | type ErrorHandler = T::ErrorHandler; 187 | 188 | fn simple_query_handler(&self) -> Arc { 189 | (**self).simple_query_handler() 190 | } 191 | 192 | fn extended_query_handler(&self) -> Arc { 193 | (**self).extended_query_handler() 194 | } 195 | 196 | fn startup_handler(&self) -> Arc { 197 | (**self).startup_handler() 198 | } 199 | 200 | fn copy_handler(&self) -> Arc { 201 | (**self).copy_handler() 202 | } 203 | 204 | fn error_handler(&self) 
-> Arc { 205 | (**self).error_handler() 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /src/api/portal.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use bytes::Bytes; 4 | use postgres_types::FromSqlOwned; 5 | 6 | use crate::{ 7 | api::Type, 8 | error::{PgWireError, PgWireResult}, 9 | messages::{data::FORMAT_CODE_BINARY, extendedquery::Bind}, 10 | }; 11 | 12 | use super::{results::FieldFormat, stmt::StoredStatement, DEFAULT_NAME}; 13 | 14 | /// Represent a prepared sql statement and its parameters bound by a `Bind` 15 | /// request. 16 | #[non_exhaustive] 17 | #[derive(Debug, Default, Clone)] 18 | pub struct Portal { 19 | pub name: String, 20 | pub statement: Arc>, 21 | pub parameter_format: Format, 22 | pub parameters: Vec>, 23 | pub result_column_format: Format, 24 | } 25 | 26 | #[derive(Debug, Clone, Default)] 27 | pub enum Format { 28 | #[default] 29 | UnifiedText, 30 | UnifiedBinary, 31 | Individual(Vec), 32 | } 33 | 34 | impl From for Format { 35 | fn from(v: i16) -> Format { 36 | if v == FORMAT_CODE_BINARY { 37 | Format::UnifiedBinary 38 | } else { 39 | Format::UnifiedText 40 | } 41 | } 42 | } 43 | 44 | impl Format { 45 | /// Get format code for given index 46 | pub fn format_for(&self, idx: usize) -> FieldFormat { 47 | match self { 48 | Format::UnifiedText => FieldFormat::Text, 49 | Format::UnifiedBinary => FieldFormat::Binary, 50 | Format::Individual(ref fv) => FieldFormat::from(fv[idx]), 51 | } 52 | } 53 | 54 | /// Test if `idx` field is text format 55 | pub fn is_text(&self, idx: usize) -> bool { 56 | self.format_for(idx) == FieldFormat::Text 57 | } 58 | 59 | /// Test if `idx` field is binary format 60 | pub fn is_binary(&self, idx: usize) -> bool { 61 | self.format_for(idx) == FieldFormat::Binary 62 | } 63 | 64 | fn from_codes(codes: &[i16]) -> Self { 65 | if codes.is_empty() { 66 | Format::UnifiedText 67 | } else if codes.len() 
== 1 { 68 | Format::from(codes[0]) 69 | } else { 70 | Format::Individual(codes.to_vec()) 71 | } 72 | } 73 | } 74 | 75 | impl Portal { 76 | /// Try to create portal from bind command and current client state 77 | pub fn try_new(bind: &Bind, statement: Arc>) -> PgWireResult { 78 | let portal_name = bind 79 | .portal_name 80 | .clone() 81 | .unwrap_or_else(|| DEFAULT_NAME.to_owned()); 82 | 83 | // param format 84 | let param_format = Format::from_codes(&bind.parameter_format_codes); 85 | 86 | // format 87 | let result_format = Format::from_codes(&bind.result_column_format_codes); 88 | 89 | Ok(Portal { 90 | name: portal_name, 91 | statement, 92 | parameter_format: param_format, 93 | parameters: bind.parameters.clone(), 94 | result_column_format: result_format, 95 | }) 96 | } 97 | 98 | /// Get number of parameters 99 | pub fn parameter_len(&self) -> usize { 100 | self.parameters.len() 101 | } 102 | 103 | /// Attempt to get parameter at given index as type `T`. 104 | /// 105 | pub fn parameter(&self, idx: usize, pg_type: &Type) -> PgWireResult> 106 | where 107 | T: FromSqlOwned, 108 | { 109 | if !T::accepts(pg_type) { 110 | return Err(PgWireError::InvalidRustTypeForParameter( 111 | pg_type.name().to_owned(), 112 | )); 113 | } 114 | 115 | let param = self 116 | .parameters 117 | .get(idx) 118 | .ok_or_else(|| PgWireError::ParameterIndexOutOfBound(idx))?; 119 | 120 | let _format = self.parameter_format.format_for(idx); 121 | 122 | if let Some(ref param) = param { 123 | // TODO: from_sql only works with binary format 124 | // here we need to check format code first and seek to support text 125 | T::from_sql(pg_type, param) 126 | .map(|v| Some(v)) 127 | .map_err(PgWireError::FailedToParseParameter) 128 | } else { 129 | // Null 130 | Ok(None) 131 | } 132 | } 133 | } 134 | 135 | #[cfg(test)] 136 | mod tests { 137 | use postgres_types::FromSql; 138 | 139 | use super::*; 140 | 141 | #[test] 142 | fn test_from_sql() { 143 | assert_eq!( 144 | "helloworld", 145 | 
String::from_sql(&Type::UNKNOWN, "helloworld".as_bytes()).unwrap()
        )
    }
}
--------------------------------------------------------------------------------
/src/api/stmt.rs:
--------------------------------------------------------------------------------
use std::sync::Arc;

use async_trait::async_trait;
use futures::Sink;
use postgres_types::Type;

use crate::messages::extendedquery::Parse;
use crate::{error::PgWireResult, messages::PgWireBackendMessage};

use super::{ClientInfo, DEFAULT_NAME};

/// A statement created by a `Parse` message and kept for later `Bind`s.
#[non_exhaustive]
#[derive(Debug, Default, new)]
pub struct StoredStatement<S> {
    /// statement name (`DEFAULT_NAME` for the unnamed statement)
    pub id: String,
    /// statement produced by the `QueryParser`
    pub statement: S,
    /// declared parameter types; may be empty when the frontend leaves type
    /// inference to the backend
    pub parameter_types: Vec<Type>,
}

impl<S> StoredStatement<S> {
    /// Build a stored statement from a `Parse` message.
    ///
    /// Unknown type oids map to `Type::UNKNOWN`; parser errors are propagated.
    pub(crate) async fn parse<C, Q>(
        client: &C,
        parse: &Parse,
        parser: Q,
    ) -> PgWireResult<StoredStatement<S>>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        Q: QueryParser<Statement = S>,
    {
        let mut parameter_types = Vec::with_capacity(parse.type_oids.len());
        for oid in &parse.type_oids {
            parameter_types.push(Type::from_oid(*oid).unwrap_or(Type::UNKNOWN));
        }

        let statement = parser
            .parse_sql(client, &parse.query, &parameter_types)
            .await?;

        let id = match &parse.name {
            Some(name) => name.clone(),
            None => DEFAULT_NAME.to_owned(),
        };

        Ok(StoredStatement {
            id,
            statement,
            parameter_types,
        })
    }
}

/// Converts the SQL text of a `Parse` message into the handler's own
/// statement representation, `Self::Statement`.
#[async_trait]
pub trait QueryParser {
    type Statement;

    async fn parse_sql<C>(
        &self,
        client: &C,
        sql: &str,
        types: &[Type],
    ) -> PgWireResult<Self::Statement>
    where
        C: ClientInfo + Unpin + Send + Sync;
}

/// Lets a shared parser be used wherever a parser is expected.
#[async_trait]
impl<QP> QueryParser for Arc<QP>
where
    QP: QueryParser + Send + Sync,
{
    type Statement = QP::Statement;

    async fn parse_sql<C>(
        &self,
        client: &C,
        sql: &str,
        types: &[Type],
    ) -> PgWireResult<Self::Statement>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let inner: &QP = self;
        <QP as QueryParser>::parse_sql(inner, client, sql, types).await
    }
}

/// Demo parser that keeps the raw SQL string as the statement. Never use it
/// in a serious application.
#[derive(new, Debug, Default)]
pub struct NoopQueryParser;

#[async_trait]
impl QueryParser for NoopQueryParser {
    type Statement = String;

    async fn parse_sql<C>(
        &self,
        _client: &C,
        sql: &str,
        _types: &[Type],
    ) -> PgWireResult<Self::Statement>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        Ok(String::from(sql))
    }
}
--------------------------------------------------------------------------------
/src/api/store.rs:
--------------------------------------------------------------------------------
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};

use super::portal::Portal;
use super::stmt::StoredStatement;

/// Name-keyed storage for prepared statements and portals.
pub trait PortalStore: Send + Sync {
    type Statement;

    fn put_statement(&self, statement: Arc<StoredStatement<Self::Statement>>);

    fn rm_statement(&self, name: &str);

    fn get_statement(&self, name: &str) -> Option<Arc<StoredStatement<Self::Statement>>>;

    fn put_portal(&self, portal: Arc<Portal<Self::Statement>>);

    fn rm_portal(&self, name: &str);

    fn get_portal(&self, name: &str) -> Option<Arc<Portal<Self::Statement>>>;
}

/// `PortalStore` backed by two `RwLock`-protected ordered maps.
#[derive(Debug, Default, new)]
pub struct MemPortalStore<S> {
    #[new(default)]
    statements: RwLock<BTreeMap<String, Arc<StoredStatement<S>>>>,
    #[new(default)]
    portals: RwLock<BTreeMap<String, Arc<Portal<S>>>>,
}
| impl PortalStore for MemPortalStore { 32 | type Statement = S; 33 | 34 | fn put_statement(&self, statement: Arc>) { 35 | let mut guard = self.statements.write().unwrap(); 36 | guard.insert(statement.id.to_owned(), statement); 37 | } 38 | 39 | fn rm_statement(&self, name: &str) { 40 | let mut guard = self.statements.write().unwrap(); 41 | guard.remove(name); 42 | } 43 | 44 | fn get_statement(&self, name: &str) -> Option>> { 45 | let guard = self.statements.read().unwrap(); 46 | guard.get(name).cloned() 47 | } 48 | 49 | fn put_portal(&self, portal: Arc>) { 50 | let mut guard = self.portals.write().unwrap(); 51 | guard.insert(portal.name.to_owned(), portal); 52 | } 53 | 54 | fn rm_portal(&self, name: &str) { 55 | let mut guard = self.portals.write().unwrap(); 56 | guard.remove(name); 57 | } 58 | 59 | fn get_portal(&self, name: &str) -> Option>> { 60 | let guard = self.portals.read().unwrap(); 61 | guard.get(name).cloned() 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/api/transaction.rs: -------------------------------------------------------------------------------- 1 | use crate::messages::response::TransactionStatus; 2 | 3 | impl TransactionStatus { 4 | pub fn to_idle_state(self) -> TransactionStatus { 5 | TransactionStatus::Idle 6 | } 7 | 8 | pub fn to_error_state(self) -> TransactionStatus { 9 | match self { 10 | TransactionStatus::Idle => TransactionStatus::Idle, 11 | _ => TransactionStatus::Error, 12 | } 13 | } 14 | 15 | pub fn to_in_transaction_state(self) -> TransactionStatus { 16 | match self { 17 | TransactionStatus::Idle => TransactionStatus::Transaction, 18 | TransactionStatus::Transaction => TransactionStatus::Transaction, 19 | TransactionStatus::Error => TransactionStatus::Error, 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 
# pgwire 2 | //! 3 | //! `pgwire` provides the PostgreSQL wire protocol as a library for 4 | //! implementing PostgreSQL-compatible servers and clients. 5 | //! [`rust-postgres`](https://crates.io/crates/postgres) will be sufficient 6 | //! for most Postgres client use-cases, so this library focuses on 7 | //! server development. 8 | //! 9 | //! ## About Postgres Wire Protocol 10 | //! 11 | //! Postgres Wire Protocol is a relatively general-purpose Layer-7 12 | //! protocol. There are 3 parts of the protocol: 13 | //! 14 | //! - Startup: client-server handshake and authentication. 15 | //! - Simple Query: The legacy query protocol of postgresql. Queries are provided 16 | //! as strings, and the server is allowed to stream data in response. 17 | //! - Extended Query: A newer sub-protocol for queries that can cache 18 | //! the query on the server side and reuse it with new parameters. The response part 19 | //! is identical to Simple Query. 20 | //! 21 | //! Also note that Postgres Wire Protocol carries no SQL semantics, so 22 | //! literally you can use any query language, data format or even natural 23 | //! language to interact with the backend. 24 | //! 25 | //! Responses are always encoded in the data row format, preceded by a field 26 | //! description header that describes each column's name, type and format. 27 | //! 28 | //! ## Components 29 | //! 30 | //! There are two main components in the postgresql wire protocol: **startup** and 31 | //! **query**. In **query**, there are two subprotocols: the legacy text-based 32 | //! **simple query** and the binary **extended query**. 33 | //! 34 | //! ## Layered API 35 | //! 36 | //! pgwire provides three layers of abstraction that allow you to compose your 37 | //! application at any level of abstraction. They are: 38 | //! 39 | //! - Protocol layer: Just use message definitions and codecs in the `messages` 40 | //! module. 41 | //! - Message handler layer: Implement `on_` prefixed methods in traits: 42 | //!
- `StartupHandler` 43 | //! - `SimpleQueryHandler` 44 | //! - `ExtendedQueryHandler` 45 | //! - High-level API layer 46 | //! - `AuthSource` and various authentication mechanisms 47 | //! - `do_` prefixed methods in handler traits 48 | //! - `QueryParser`/`PortalStore` for extended query support 49 | //! 50 | //! ## Features 51 | //! 52 | //! - `server-api-aws-lc-rs` is enabled by default, it includes all three layers 53 | //! of our API and uses `aws-lc-rs` as crypto backend. 54 | //! - `server-api-ring` is almost same to `server-api-aws-lc-rs` except for it's 55 | //! using `ring` as crypto backend. 56 | //! - `scram` for the SASL/SCRAM authenticator. 57 | //! - Turn off default features if you just use our Protocol layer. 58 | //! 59 | //! ## Examples 60 | //! 61 | //! [Examples](https://github.com/sunng87/pgwire) are provided to demo API 62 | //! usages. 63 | //! 64 | 65 | #[macro_use] 66 | extern crate derive_new; 67 | 68 | /// handler layer and high-level API layer. 69 | #[cfg(any(feature = "server-api", feature = "client-api"))] 70 | pub mod api; 71 | /// error types. 72 | pub mod error; 73 | /// the protocol layer. 74 | pub mod messages; 75 | /// server entry-point for tokio based application. 76 | #[cfg(any(feature = "server-api", feature = "client-api"))] 77 | pub mod tokio; 78 | /// types and encoding related helper 79 | #[cfg(feature = "server-api")] 80 | pub mod types; 81 | -------------------------------------------------------------------------------- /src/messages/codec.rs: -------------------------------------------------------------------------------- 1 | use std::str; 2 | 3 | use bytes::{Buf, BufMut, BytesMut}; 4 | 5 | use crate::error::PgWireResult; 6 | 7 | /// Get null-terminated string, returns None when empty cstring read. 8 | /// 9 | /// Note that this implementation will also advance cursor by 1 after reading 10 | /// empty cstring. 
This behaviour works for how postgres wire protocol handling 11 | /// key-value pairs, which is ended by a single `\0` 12 | pub(crate) fn get_cstring(buf: &mut BytesMut) -> Option { 13 | let mut i = 0; 14 | 15 | if buf.remaining() == 0 { 16 | return None; 17 | } 18 | 19 | // with bound check to prevent invalid format 20 | while i < buf.remaining() && buf[i] != b'\0' { 21 | i += 1; 22 | } 23 | 24 | // i+1: include the '\0' 25 | // move cursor to the end of cstring 26 | let string_buf = buf.split_to(i + 1); 27 | 28 | if i == 0 { 29 | None 30 | } else { 31 | Some(String::from_utf8_lossy(&string_buf[..i]).into_owned()) 32 | } 33 | } 34 | 35 | /// Put null-termianted string 36 | /// 37 | /// You can put empty string by giving `""` as input. 38 | pub(crate) fn put_cstring(buf: &mut BytesMut, input: &str) { 39 | buf.put_slice(input.as_bytes()); 40 | buf.put_u8(b'\0'); 41 | } 42 | 43 | pub(crate) fn put_option_cstring(buf: &mut BytesMut, input: &Option) { 44 | if let Some(input) = input { 45 | put_cstring(buf, input); 46 | } else { 47 | buf.put_u8(b'\0'); 48 | } 49 | } 50 | 51 | /// Try to read message length from buf, without actually move the cursor 52 | pub(crate) fn get_length(buf: &BytesMut, offset: usize) -> Option { 53 | if buf.remaining() >= 4 + offset { 54 | Some((&buf[offset..4 + offset]).get_i32() as usize) 55 | } else { 56 | None 57 | } 58 | } 59 | 60 | /// Check if message_length matches and move the cursor to right position then 61 | /// call the `decode_fn` for the body 62 | pub(crate) fn decode_packet( 63 | buf: &mut BytesMut, 64 | offset: usize, 65 | decode_fn: F, 66 | ) -> PgWireResult> 67 | where 68 | F: Fn(&mut BytesMut, usize) -> PgWireResult, 69 | { 70 | if let Some(msg_len) = get_length(buf, offset) { 71 | if buf.remaining() >= msg_len + offset { 72 | buf.advance(offset + 4); 73 | return decode_fn(buf, msg_len).map(|r| Some(r)); 74 | } 75 | } 76 | 77 | Ok(None) 78 | } 79 | 80 | // pub(crate) fn get_and_ensure_message_type(buf: &mut BytesMut, t: u8) 
-> PgWireResult<()> { 81 | // let msg_type = buf[0]; 82 | // // ensure the type is corrent 83 | // if msg_type != t { 84 | // return Err(PgWireError::InvalidMessageType(t, msg_type)); 85 | // } 86 | 87 | // Ok(()) 88 | // } 89 | 90 | pub(crate) fn option_string_len(s: &Option) -> usize { 91 | 1 + s.as_ref().map(|s| s.len()).unwrap_or(0) 92 | } 93 | 94 | #[cfg(test)] 95 | mod test { 96 | use super::get_cstring; 97 | use bytes::{BufMut, BytesMut}; 98 | 99 | #[test] 100 | fn get_cstring_valid() { 101 | let mut buf = BytesMut::new(); 102 | buf.put(&b"a cstring\0"[..]); 103 | buf.put(&b"\0"[..]); 104 | 105 | assert_eq!(Some("a cstring".into()), get_cstring(&mut buf)); 106 | assert_eq!(None, get_cstring(&mut buf)); 107 | } 108 | 109 | #[test] 110 | fn get_cstring_empty() { 111 | let mut buf = BytesMut::new(); 112 | 113 | assert_eq!(None, get_cstring(&mut buf)); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/messages/copy.rs: -------------------------------------------------------------------------------- 1 | use bytes::{Buf, BufMut, Bytes, BytesMut}; 2 | 3 | use super::codec; 4 | use super::Message; 5 | use crate::error::PgWireResult; 6 | 7 | pub const MESSAGE_TYPE_BYTE_COPY_DATA: u8 = b'd'; 8 | 9 | #[non_exhaustive] 10 | #[derive(PartialEq, Eq, Debug, Default, new)] 11 | pub struct CopyData { 12 | pub data: Bytes, 13 | } 14 | 15 | impl Message for CopyData { 16 | #[inline] 17 | fn message_type() -> Option { 18 | Some(MESSAGE_TYPE_BYTE_COPY_DATA) 19 | } 20 | 21 | fn message_length(&self) -> usize { 22 | 4 + self.data.len() 23 | } 24 | 25 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 26 | buf.put(self.data.as_ref()); 27 | Ok(()) 28 | } 29 | 30 | fn decode_body(buf: &mut BytesMut, len: usize) -> PgWireResult { 31 | let data = buf.split_to(len - 4).freeze(); 32 | Ok(Self::new(data)) 33 | } 34 | } 35 | 36 | pub const MESSAGE_TYPE_BYTE_COPY_DONE: u8 = b'c'; 37 | 38 | #[non_exhaustive] 39 | 
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct CopyDone;

impl Message for CopyDone {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_COPY_DONE)
    }

    fn message_length(&self) -> usize {
        4
    }

    fn encode_body(&self, _buf: &mut BytesMut) -> PgWireResult<()> {
        Ok(())
    }

    fn decode_body(_buf: &mut BytesMut, _len: usize) -> PgWireResult<Self> {
        Ok(Self::new())
    }
}

pub const MESSAGE_TYPE_BYTE_COPY_FAIL: u8 = b'f';

#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct CopyFail {
    pub message: String,
}

impl Message for CopyFail {
    #[inline]
    fn message_type() -> Option<u8> {
        // CopyFail is identified by 'f' (this used to return the CopyDone byte,
        // so an encoded CopyFail went out on the wire as CopyDone)
        Some(MESSAGE_TYPE_BYTE_COPY_FAIL)
    }

    fn message_length(&self) -> usize {
        4 + self.message.len() + 1
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        codec::put_cstring(buf, &self.message);
        Ok(())
    }

    fn decode_body(buf: &mut BytesMut, _len: usize) -> PgWireResult<Self> {
        let msg = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
        Ok(Self::new(msg))
    }
}

pub const MESSAGE_TYPE_BYTE_COPY_IN_RESPONSE: u8 = b'G';

#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct CopyInResponse {
    pub format: i8,
    pub columns: i16,
    pub column_formats: Vec<i16>,
}

impl Message for CopyInResponse {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_COPY_IN_RESPONSE)
    }

    fn message_length(&self) -> usize {
        copy_response_length(&self.column_formats)
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        encode_copy_response(buf, self.format, self.columns, &self.column_formats);
        Ok(())
    }

    fn decode_body(buf: &mut BytesMut, _len: usize) -> PgWireResult<Self> {
        let (format, columns, column_formats) = decode_copy_response(buf);
        Ok(Self::new(format, columns, column_formats))
    }
}

pub const MESSAGE_TYPE_BYTE_COPY_OUT_RESPONSE: u8 = b'H';

#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct CopyOutResponse {
    pub format: i8,
    pub columns: i16,
    pub column_formats: Vec<i16>,
}

impl Message for CopyOutResponse {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_COPY_OUT_RESPONSE)
    }

    fn message_length(&self) -> usize {
        copy_response_length(&self.column_formats)
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        encode_copy_response(buf, self.format, self.columns, &self.column_formats);
        Ok(())
    }

    fn decode_body(buf: &mut BytesMut, _len: usize) -> PgWireResult<Self> {
        let (format, columns, column_formats) = decode_copy_response(buf);
        Ok(Self::new(format, columns, column_formats))
    }
}

pub const MESSAGE_TYPE_BYTE_COPY_BOTH_RESPONSE: u8 = b'W';

#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct CopyBothResponse {
    pub format: i8,
    pub columns: i16,
    pub column_formats: Vec<i16>,
}

impl Message for CopyBothResponse {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_COPY_BOTH_RESPONSE)
    }

    fn message_length(&self) -> usize {
        copy_response_length(&self.column_formats)
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        encode_copy_response(buf, self.format, self.columns, &self.column_formats);
        Ok(())
    }

    fn decode_body(buf: &mut BytesMut, _len: usize) -> PgWireResult<Self> {
        let (format, columns, column_formats) = decode_copy_response(buf);
        Ok(Self::new(format, columns, column_formats))
    }
}

/// Message length of a Copy{In,Out,Both}Response: length field, format byte,
/// column count and one `i16` per column format.
fn copy_response_length(column_formats: &[i16]) -> usize {
    4 + 1 + 2 + column_formats.len() * 2
}

/// Shared body encoder for the Copy{In,Out,Both}Response messages.
fn encode_copy_response(buf: &mut BytesMut, format: i8, columns: i16, column_formats: &[i16]) {
    buf.put_i8(format);
    buf.put_i16(columns);
    for cf in column_formats {
        buf.put_i16(*cf);
    }
}

/// Shared body decoder for the Copy{In,Out,Both}Response messages.
fn decode_copy_response(buf: &mut BytesMut) -> (i8, i16, Vec<i16>) {
    let format = buf.get_i8();
    let columns = buf.get_i16();
    // clamp: a negative count cast to usize would make with_capacity panic
    // (the loop below reads nothing in that case anyway)
    let mut column_formats = Vec::with_capacity(columns.max(0) as usize);
    for _ in 0..columns {
        column_formats.push(buf.get_i16());
    }
    (format, columns, column_formats)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn copy_fail_and_copy_done_use_distinct_type_bytes() {
        assert_eq!(Some(b'f'), CopyFail::message_type());
        assert_eq!(Some(b'c'), CopyDone::message_type());
    }

    #[test]
    fn negative_column_count_does_not_panic() {
        let mut buf = BytesMut::new();
        buf.put_i8(0);
        buf.put_i16(-1);

        let resp = CopyInResponse::decode_body(&mut buf, 7).unwrap();
        assert_eq!(-1, resp.columns);
        assert!(resp.column_formats.is_empty());
    }
}
--------------------------------------------------------------------------------
/src/messages/data.rs:
--------------------------------------------------------------------------------
use bytes::{Buf, BufMut, BytesMut};

use super::codec;
use super::Message;
use crate::error::PgWireResult;

pub const FORMAT_CODE_TEXT: i16 = 0;
pub const FORMAT_CODE_BINARY: i16 = 1;

#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct FieldDescription {
    /// the field name
    pub name: String,
    /// the object ID of table, default to 0 if not a table
    pub table_id: i32,
    /// the attribute number of the column, default to 0 if not a column from table
    pub column_id: i16,
    /// the object ID of the data type
    pub type_id: u32,
    /// the size of data type, negative values denote variable-width types
    pub type_size: i16,
    /// the type modifier
    pub type_modifier: i32,
    /// the format code being used for the field, will be 0 or 1 for now
    pub format_code: i16,
}

#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, Default, new)]
pub struct RowDescription {
    pub fields: Vec<FieldDescription>,
}

pub const MESSAGE_TYPE_BYTE_ROW_DESCRITION: u8 = b'T';

impl Message for RowDescription {
    fn message_type() -> Option<u8> {
Some(MESSAGE_TYPE_BYTE_ROW_DESCRITION) 40 | } 41 | 42 | fn message_length(&self) -> usize { 43 | 4 + 2 44 | + self 45 | .fields 46 | .iter() 47 | .map(|f| f.name.len() + 1 + 4 + 2 + 4 + 2 + 4 + 2) 48 | .sum::() 49 | } 50 | 51 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 52 | buf.put_i16(self.fields.len() as i16); 53 | 54 | for field in &self.fields { 55 | codec::put_cstring(buf, &field.name); 56 | buf.put_i32(field.table_id); 57 | buf.put_i16(field.column_id); 58 | buf.put_u32(field.type_id); 59 | buf.put_i16(field.type_size); 60 | buf.put_i32(field.type_modifier); 61 | buf.put_i16(field.format_code); 62 | } 63 | 64 | Ok(()) 65 | } 66 | 67 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 68 | let fields_len = buf.get_i16(); 69 | let mut fields = Vec::with_capacity(fields_len as usize); 70 | 71 | for _ in 0..fields_len { 72 | let field = FieldDescription { 73 | name: codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()), 74 | table_id: buf.get_i32(), 75 | column_id: buf.get_i16(), 76 | type_id: buf.get_u32(), 77 | type_size: buf.get_i16(), 78 | type_modifier: buf.get_i32(), 79 | format_code: buf.get_i16(), 80 | }; 81 | 82 | fields.push(field); 83 | } 84 | 85 | Ok(RowDescription { fields }) 86 | } 87 | } 88 | 89 | /// Data structure returned when frontend describes a statement 90 | #[non_exhaustive] 91 | #[derive(PartialEq, Eq, Debug, Default, new, Clone)] 92 | pub struct ParameterDescription { 93 | /// parameter types 94 | pub types: Vec, 95 | } 96 | 97 | pub const MESSAGE_TYPE_BYTE_PARAMETER_DESCRITION: u8 = b't'; 98 | 99 | impl Message for ParameterDescription { 100 | fn message_type() -> Option { 101 | Some(MESSAGE_TYPE_BYTE_PARAMETER_DESCRITION) 102 | } 103 | 104 | fn message_length(&self) -> usize { 105 | 4 + 2 + self.types.len() * 4 106 | } 107 | 108 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 109 | buf.put_u16(self.types.len() as u16); 110 | 111 | for t in &self.types { 112 | buf.put_i32(*t as 
i32); 113 | } 114 | 115 | Ok(()) 116 | } 117 | 118 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 119 | let types_len = buf.get_u16(); 120 | let mut types = Vec::with_capacity(types_len as usize); 121 | 122 | for _ in 0..types_len { 123 | types.push(buf.get_i32() as u32); 124 | } 125 | 126 | Ok(ParameterDescription { types }) 127 | } 128 | } 129 | 130 | /// Data structure for postgresql wire protocol `DataRow` message. 131 | /// 132 | /// Data can be represented as text or binary format as specified by format 133 | /// codes from previous `RowDescription` message. 134 | #[non_exhaustive] 135 | #[derive(PartialEq, Eq, Debug, Default, new, Clone)] 136 | pub struct DataRow { 137 | pub data: BytesMut, 138 | pub field_count: i16, 139 | } 140 | 141 | impl DataRow {} 142 | 143 | pub const MESSAGE_TYPE_BYTE_DATA_ROW: u8 = b'D'; 144 | 145 | impl Message for DataRow { 146 | #[inline] 147 | fn message_type() -> Option { 148 | Some(MESSAGE_TYPE_BYTE_DATA_ROW) 149 | } 150 | 151 | fn message_length(&self) -> usize { 152 | 4 + 2 + self.data.len() 153 | } 154 | 155 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 156 | buf.put_i16(self.field_count); 157 | buf.reserve(self.data.len()); 158 | buf.put_slice(&self.data); 159 | 160 | Ok(()) 161 | } 162 | 163 | fn decode_body(buf: &mut BytesMut, msg_len: usize) -> PgWireResult { 164 | let field_count = buf.get_i16(); 165 | // get body size from packet 166 | let data = buf.split_to(msg_len - 4 - 2); 167 | 168 | Ok(DataRow { data, field_count }) 169 | } 170 | } 171 | 172 | /// postgres response when query returns no data, sent from backend to frontend 173 | /// in extended query 174 | #[non_exhaustive] 175 | #[derive(PartialEq, Eq, Debug, Default, new)] 176 | pub struct NoData; 177 | 178 | pub const MESSAGE_TYPE_BYTE_NO_DATA: u8 = b'n'; 179 | 180 | impl Message for NoData { 181 | #[inline] 182 | fn message_type() -> Option { 183 | Some(MESSAGE_TYPE_BYTE_NO_DATA) 184 | } 185 | 186 | fn 
message_length(&self) -> usize { 187 | 4 188 | } 189 | 190 | fn encode_body(&self, _buf: &mut BytesMut) -> PgWireResult<()> { 191 | Ok(()) 192 | } 193 | 194 | fn decode_body(_buf: &mut BytesMut, _: usize) -> PgWireResult { 195 | Ok(NoData::new()) 196 | } 197 | } 198 | -------------------------------------------------------------------------------- /src/messages/response.rs: -------------------------------------------------------------------------------- 1 | use bytes::{Buf, BufMut, BytesMut}; 2 | 3 | use super::codec; 4 | use super::Message; 5 | use crate::error::{PgWireError, PgWireResult}; 6 | 7 | #[non_exhaustive] 8 | #[derive(PartialEq, Eq, Debug, new)] 9 | pub struct CommandComplete { 10 | pub tag: String, 11 | } 12 | 13 | pub const MESSAGE_TYPE_BYTE_COMMAND_COMPLETE: u8 = b'C'; 14 | 15 | impl Message for CommandComplete { 16 | #[inline] 17 | fn message_type() -> Option { 18 | Some(MESSAGE_TYPE_BYTE_COMMAND_COMPLETE) 19 | } 20 | 21 | fn message_length(&self) -> usize { 22 | 5 + self.tag.len() 23 | } 24 | 25 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 26 | codec::put_cstring(buf, &self.tag); 27 | 28 | Ok(()) 29 | } 30 | 31 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 32 | let tag = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); 33 | 34 | Ok(CommandComplete::new(tag)) 35 | } 36 | } 37 | 38 | #[non_exhaustive] 39 | #[derive(PartialEq, Eq, Debug, new)] 40 | pub struct EmptyQueryResponse; 41 | 42 | pub const MESSAGE_TYPE_BYTE_EMPTY_QUERY_RESPONSE: u8 = b'I'; 43 | 44 | impl Message for EmptyQueryResponse { 45 | fn message_type() -> Option { 46 | Some(MESSAGE_TYPE_BYTE_EMPTY_QUERY_RESPONSE) 47 | } 48 | 49 | fn message_length(&self) -> usize { 50 | 4 51 | } 52 | 53 | fn encode_body(&self, _buf: &mut BytesMut) -> PgWireResult<()> { 54 | Ok(()) 55 | } 56 | 57 | fn decode_body(_buf: &mut BytesMut, _full_len: usize) -> PgWireResult { 58 | Ok(EmptyQueryResponse) 59 | } 60 | } 61 | 62 | #[non_exhaustive] 63 | 
#[derive(PartialEq, Eq, Debug, new)] 64 | pub struct ReadyForQuery { 65 | pub status: TransactionStatus, 66 | } 67 | 68 | #[derive(PartialEq, Eq, Debug, Clone, Copy)] 69 | #[repr(u8)] 70 | pub enum TransactionStatus { 71 | Idle = READY_STATUS_IDLE, 72 | Transaction = READY_STATUS_TRANSACTION_BLOCK, 73 | Error = READY_STATUS_FAILED_TRANSACTION_BLOCK, 74 | } 75 | 76 | pub const READY_STATUS_IDLE: u8 = b'I'; 77 | pub const READY_STATUS_TRANSACTION_BLOCK: u8 = b'T'; 78 | pub const READY_STATUS_FAILED_TRANSACTION_BLOCK: u8 = b'E'; 79 | 80 | pub const MESSAGE_TYPE_BYTE_READY_FOR_QUERY: u8 = b'Z'; 81 | 82 | impl Message for ReadyForQuery { 83 | #[inline] 84 | fn message_type() -> Option { 85 | Some(MESSAGE_TYPE_BYTE_READY_FOR_QUERY) 86 | } 87 | 88 | #[inline] 89 | fn message_length(&self) -> usize { 90 | 5 91 | } 92 | 93 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 94 | buf.put_u8(self.status as u8); 95 | 96 | Ok(()) 97 | } 98 | 99 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 100 | let status = TransactionStatus::try_from(buf.get_u8())?; 101 | Ok(ReadyForQuery::new(status)) 102 | } 103 | } 104 | 105 | impl TryFrom for TransactionStatus { 106 | type Error = PgWireError; 107 | fn try_from(value: u8) -> Result { 108 | match value { 109 | READY_STATUS_IDLE => Ok(Self::Idle), 110 | READY_STATUS_TRANSACTION_BLOCK => Ok(Self::Transaction), 111 | READY_STATUS_FAILED_TRANSACTION_BLOCK => Ok(Self::Error), 112 | _ => Err(PgWireError::InvalidTransactionStatus(value)), 113 | } 114 | } 115 | } 116 | 117 | /// postgres error response, sent from backend to frontend 118 | #[non_exhaustive] 119 | #[derive(PartialEq, Eq, Debug, Default, new)] 120 | pub struct ErrorResponse { 121 | pub fields: Vec<(u8, String)>, 122 | } 123 | 124 | pub const MESSAGE_TYPE_BYTE_ERROR_RESPONSE: u8 = b'E'; 125 | 126 | impl Message for ErrorResponse { 127 | #[inline] 128 | fn message_type() -> Option { 129 | Some(MESSAGE_TYPE_BYTE_ERROR_RESPONSE) 130 | } 131 | 132 | fn 
message_length(&self) -> usize { 133 | 4 + self.fields.iter().map(|f| 1 + f.1.len() + 1).sum::() + 1 134 | } 135 | 136 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 137 | for (code, value) in &self.fields { 138 | buf.put_u8(*code); 139 | codec::put_cstring(buf, value); 140 | } 141 | 142 | buf.put_u8(b'\0'); 143 | 144 | Ok(()) 145 | } 146 | 147 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 148 | let mut fields = Vec::new(); 149 | loop { 150 | let code = buf.get_u8(); 151 | 152 | if code == b'\0' { 153 | return Ok(ErrorResponse { fields }); 154 | } else { 155 | let value = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); 156 | fields.push((code, value)); 157 | } 158 | } 159 | } 160 | } 161 | 162 | /// postgres error response, sent from backend to frontend 163 | #[non_exhaustive] 164 | #[derive(PartialEq, Eq, Debug, Default, new)] 165 | pub struct NoticeResponse { 166 | pub fields: Vec<(u8, String)>, 167 | } 168 | 169 | pub const MESSAGE_TYPE_BYTE_NOTICE_RESPONSE: u8 = b'N'; 170 | 171 | impl Message for NoticeResponse { 172 | #[inline] 173 | fn message_type() -> Option { 174 | Some(MESSAGE_TYPE_BYTE_NOTICE_RESPONSE) 175 | } 176 | 177 | fn message_length(&self) -> usize { 178 | 4 + self.fields.iter().map(|f| 1 + f.1.len() + 1).sum::() + 1 179 | } 180 | 181 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 182 | for (code, value) in &self.fields { 183 | buf.put_u8(*code); 184 | codec::put_cstring(buf, value); 185 | } 186 | 187 | buf.put_u8(b'\0'); 188 | 189 | Ok(()) 190 | } 191 | 192 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 193 | let mut fields = Vec::new(); 194 | loop { 195 | let code = buf.get_u8(); 196 | 197 | if code == b'\0' { 198 | return Ok(NoticeResponse { fields }); 199 | } else { 200 | let value = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); 201 | fields.push((code, value)); 202 | } 203 | } 204 | } 205 | } 206 | 207 | /// Response to SSLRequest. 
208 | /// 209 | /// To initiate an SSL-encrypted connection, the frontend initially sends an 210 | /// SSLRequest message rather than a StartupMessage. The server then responds 211 | /// with a single byte containing 'S' or 'N', indicating that it is willing or 212 | /// unwilling to perform SSL, respectively. 213 | #[non_exhaustive] 214 | #[derive(Debug, PartialEq)] 215 | pub enum SslResponse { 216 | Accept, 217 | Refuse, 218 | } 219 | 220 | impl SslResponse { 221 | pub const BYTE_ACCEPT: u8 = b'S'; 222 | pub const BYTE_REFUSE: u8 = b'N'; 223 | // The whole message takes only one byte and has no size field. 224 | pub const MESSAGE_LENGTH: usize = 1; 225 | } 226 | 227 | impl Message for SslResponse { 228 | fn message_length(&self) -> usize { 229 | Self::MESSAGE_LENGTH 230 | } 231 | 232 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 233 | match self { 234 | Self::Accept => buf.put_u8(Self::BYTE_ACCEPT), 235 | Self::Refuse => buf.put_u8(Self::BYTE_REFUSE), 236 | } 237 | Ok(()) 238 | } 239 | 240 | fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> { 241 | self.encode_body(buf) 242 | } 243 | 244 | fn decode_body(_: &mut BytesMut, _: usize) -> PgWireResult { 245 | unreachable!() 246 | } 247 | 248 | fn decode(buf: &mut BytesMut) -> PgWireResult> { 249 | if buf.remaining() >= Self::MESSAGE_LENGTH { 250 | match buf[0] { 251 | Self::BYTE_ACCEPT => { 252 | buf.advance(Self::MESSAGE_LENGTH); 253 | Ok(Some(SslResponse::Accept)) 254 | } 255 | Self::BYTE_REFUSE => { 256 | buf.advance(Self::MESSAGE_LENGTH); 257 | Ok(Some(SslResponse::Refuse)) 258 | } 259 | _ => Ok(None), 260 | } 261 | } else { 262 | Ok(None) 263 | } 264 | } 265 | } 266 | 267 | /// NotificationResponse 268 | #[non_exhaustive] 269 | #[derive(PartialEq, Eq, Debug, Default, new)] 270 | pub struct NotificationResponse { 271 | pub pid: i32, 272 | pub channel: String, 273 | pub payload: String, 274 | } 275 | 276 | pub const MESSAGE_TYPE_BYTE_NOTIFICATION_RESPONSE: u8 = b'A'; 277 | 278 | impl 
Message for NotificationResponse { 279 | #[inline] 280 | fn message_type() -> Option { 281 | Some(MESSAGE_TYPE_BYTE_NOTIFICATION_RESPONSE) 282 | } 283 | 284 | fn message_length(&self) -> usize { 285 | 8 + self.channel.len() + 1 + self.payload.len() + 1 286 | } 287 | 288 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 289 | buf.put_i32(self.pid); 290 | codec::put_cstring(buf, &self.channel); 291 | codec::put_cstring(buf, &self.payload); 292 | 293 | Ok(()) 294 | } 295 | 296 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 297 | let pid = buf.get_i32(); 298 | let channel = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); 299 | let payload = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); 300 | 301 | Ok(NotificationResponse { 302 | pid, 303 | channel, 304 | payload, 305 | }) 306 | } 307 | } 308 | -------------------------------------------------------------------------------- /src/messages/simplequery.rs: -------------------------------------------------------------------------------- 1 | use bytes::BytesMut; 2 | 3 | use super::codec; 4 | use super::Message; 5 | use crate::error::PgWireResult; 6 | 7 | /// A sql query sent from frontend to backend. 
8 | #[non_exhaustive] 9 | #[derive(PartialEq, Eq, Debug, new)] 10 | pub struct Query { 11 | pub query: String, 12 | } 13 | 14 | pub const MESSAGE_TYPE_BYTE_QUERY: u8 = b'Q'; 15 | 16 | impl Message for Query { 17 | #[inline] 18 | fn message_type() -> Option { 19 | Some(MESSAGE_TYPE_BYTE_QUERY) 20 | } 21 | 22 | fn message_length(&self) -> usize { 23 | 5 + self.query.len() 24 | } 25 | 26 | fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { 27 | codec::put_cstring(buf, &self.query); 28 | 29 | Ok(()) 30 | } 31 | 32 | fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { 33 | let query = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); 34 | 35 | Ok(Query::new(query)) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/messages/terminate.rs: -------------------------------------------------------------------------------- 1 | use super::Message; 2 | use crate::error::PgWireResult; 3 | 4 | #[non_exhaustive] 5 | #[derive(Default, PartialEq, Eq, Debug, new)] 6 | pub struct Terminate; 7 | 8 | pub const MESSAGE_TYPE_BYTE_TERMINATE: u8 = b'X'; 9 | 10 | impl Message for Terminate { 11 | #[inline] 12 | fn message_type() -> Option { 13 | Some(MESSAGE_TYPE_BYTE_TERMINATE) 14 | } 15 | 16 | #[inline] 17 | fn message_length(&self) -> usize { 18 | 4 19 | } 20 | 21 | fn encode_body(&self, _: &mut bytes::BytesMut) -> PgWireResult<()> { 22 | Ok(()) 23 | } 24 | 25 | fn decode_body(_: &mut bytes::BytesMut, _: usize) -> PgWireResult { 26 | Ok(Terminate) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/tokio/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "client-api")] 2 | pub mod client; 3 | 4 | #[cfg(feature = "server-api")] 5 | mod server; 6 | 7 | #[cfg(feature = "server-api")] 8 | pub use server::process_socket; 9 | 10 | #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] 11 | pub use 
tokio_rustls; 12 | #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] 13 | pub type TlsAcceptor = tokio_rustls::TlsAcceptor; 14 | #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] 15 | pub type TlsConnector = tokio_rustls::TlsConnector; 16 | 17 | #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] 18 | pub(super) const POSTGRESQL_ALPN_NAME: &[u8] = b"postgresql"; 19 | 20 | #[cfg(not(any(feature = "_ring", feature = "_aws-lc-rs")))] 21 | pub enum TlsAcceptor {} 22 | #[cfg(not(any(feature = "_ring", feature = "_aws-lc-rs")))] 23 | pub enum TlsConnector {} 24 | -------------------------------------------------------------------------------- /src/types/from_sql_text.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3 | use std::time::{Duration, SystemTime, UNIX_EPOCH}; 4 | 5 | use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Offset, Utc}; 6 | use postgres_types::{Type, WrongType}; 7 | use rust_decimal::Decimal; 8 | 9 | pub trait FromSqlText: fmt::Debug { 10 | /// Converts value from postgres text format to rust. 11 | /// 12 | /// This trait is modelled after `FromSql` from postgres-types, which is 13 | /// for binary encoding. 14 | fn from_sql_text(ty: &Type, input: &[u8]) -> Result> 15 | where 16 | Self: Sized; 17 | } 18 | 19 | fn to_str(f: &[u8]) -> Result<&str, Box> { 20 | std::str::from_utf8(f).map_err(Into::into) 21 | } 22 | 23 | impl FromSqlText for bool { 24 | fn from_sql_text(_ty: &Type, input: &[u8]) -> Result> 25 | where 26 | Self: Sized, 27 | { 28 | match input { 29 | b"t" => Ok(true), 30 | b"f" => Ok(false), 31 | _ => Err("Invalid text value for bool".into()), 32 | } 33 | } 34 | } 35 | 36 | impl FromSqlText for String { 37 | fn from_sql_text(_ty: &Type, input: &[u8]) -> Result> 38 | where 39 | Self: Sized, 40 | { 41 | to_str(input).map(|s| s.to_owned()) 42 | } 43 | } 44 | 45 | macro_rules! 
impl_from_sql_text { 46 | ($t:ty) => { 47 | impl FromSqlText for $t { 48 | fn from_sql_text( 49 | _ty: &Type, 50 | input: &[u8], 51 | ) -> Result> { 52 | to_str(input).and_then(|s| s.parse::<$t>().map_err(Into::into)) 53 | } 54 | } 55 | }; 56 | } 57 | 58 | impl_from_sql_text!(i8); 59 | impl_from_sql_text!(i16); 60 | impl_from_sql_text!(i32); 61 | impl_from_sql_text!(i64); 62 | impl_from_sql_text!(u32); 63 | impl_from_sql_text!(f32); 64 | impl_from_sql_text!(f64); 65 | impl_from_sql_text!(char); 66 | 67 | impl FromSqlText for Decimal { 68 | fn from_sql_text(_ty: &Type, input: &[u8]) -> Result> 69 | where 70 | Self: Sized, 71 | { 72 | Decimal::from_str_exact(to_str(input)?).map_err(Into::into) 73 | } 74 | } 75 | 76 | impl FromSqlText for Vec { 77 | fn from_sql_text(_ty: &Type, input: &[u8]) -> Result> 78 | where 79 | Self: Sized, 80 | { 81 | let data = input 82 | .strip_prefix(b"\\x") 83 | .ok_or("\\x prefix expected for bytea")?; 84 | 85 | hex::decode(data).map_err(|e| e.to_string().into()) 86 | } 87 | } 88 | 89 | impl FromSqlText for SystemTime { 90 | fn from_sql_text(_ty: &Type, value: &[u8]) -> Result> 91 | where 92 | Self: Sized, 93 | { 94 | let datetime = NaiveDateTime::parse_from_str(to_str(value)?, "%Y-%m-%d %H:%M:%S.6f")?; 95 | let system_time = 96 | UNIX_EPOCH + Duration::from_millis(datetime.and_utc().timestamp_millis() as u64); 97 | 98 | Ok(system_time) 99 | } 100 | } 101 | 102 | impl FromSqlText for DateTime { 103 | fn from_sql_text(ty: &Type, value: &[u8]) -> Result> 104 | where 105 | Self: Sized, 106 | { 107 | match *ty { 108 | Type::TIMESTAMP | Type::TIMESTAMP_ARRAY => { 109 | let fmt = "%Y-%m-%d %H:%M:%S%.6f"; 110 | let datetime = NaiveDateTime::parse_from_str(to_str(value)?, fmt)?; 111 | 112 | Ok(DateTime::from_naive_utc_and_offset(datetime, Utc.fix())) 113 | } 114 | Type::TIMESTAMPTZ | Type::TIMESTAMPTZ_ARRAY => { 115 | let fmt = "%Y-%m-%d %H:%M:%S%.6f%:::z"; 116 | let datetime = DateTime::parse_from_str(to_str(value)?, fmt)?; 117 | Ok(datetime) 
118 | } 119 | Type::DATE | Type::DATE_ARRAY => { 120 | let fmt = "%Y-%m-%d"; 121 | let datetime = NaiveDateTime::parse_from_str(to_str(value)?, fmt)?; 122 | Ok(DateTime::from_naive_utc_and_offset(datetime, Utc.fix())) 123 | } 124 | _ => Err(Box::new(WrongType::new::>(ty.clone()))), 125 | } 126 | } 127 | } 128 | 129 | impl FromSqlText for NaiveDate { 130 | fn from_sql_text(_ty: &Type, value: &[u8]) -> Result> 131 | where 132 | Self: Sized, 133 | { 134 | let date = NaiveDate::parse_from_str(to_str(value)?, "%Y-%m-%d")?; 135 | Ok(date) 136 | } 137 | } 138 | 139 | impl FromSqlText for NaiveTime { 140 | fn from_sql_text(_ty: &Type, value: &[u8]) -> Result> 141 | where 142 | Self: Sized, 143 | { 144 | let time = NaiveTime::parse_from_str(to_str(value)?, "%H:%M:%S")?; 145 | Ok(time) 146 | } 147 | } 148 | 149 | impl FromSqlText for NaiveDateTime { 150 | fn from_sql_text(_ty: &Type, value: &[u8]) -> Result> 151 | where 152 | Self: Sized, 153 | { 154 | let datetime = NaiveDateTime::parse_from_str(to_str(value)?, "%Y-%m-%d %H:%M:%S")?; 155 | Ok(datetime) 156 | } 157 | } 158 | 159 | impl FromSqlText for Option 160 | where 161 | T: FromSqlText, 162 | { 163 | fn from_sql_text(ty: &Type, input: &[u8]) -> Result> 164 | where 165 | Self: Sized, 166 | { 167 | if input.is_empty() { 168 | Ok(None) 169 | } else { 170 | T::from_sql_text(ty, input).map(Some) 171 | } 172 | } 173 | } 174 | 175 | //TODO: array types 176 | 177 | #[cfg(test)] 178 | mod tests { 179 | use super::*; 180 | 181 | #[test] 182 | fn test_from_sql_text_for_string() { 183 | let sql_text = "Hello, World!".as_bytes(); 184 | let result = String::from_sql_text(&Type::VARCHAR, sql_text).unwrap(); 185 | assert_eq!(result, "Hello, World!"); 186 | } 187 | 188 | #[test] 189 | fn test_from_sql_text_for_i32() { 190 | let sql_text = "42".as_bytes(); 191 | let result = i32::from_sql_text(&Type::VARCHAR, sql_text).unwrap(); 192 | assert_eq!(result, 42); 193 | } 194 | 195 | #[test] 196 | fn test_from_sql_text_for_i32_invalid() { 197 
| let sql_text = "not_a_number".as_bytes(); 198 | let result = i32::from_sql_text(&Type::INT4, sql_text); 199 | assert!(result.is_err()); 200 | } 201 | 202 | #[test] 203 | fn test_from_sql_text_for_f64() { 204 | let sql_text = "3.14".as_bytes(); 205 | let result = f64::from_sql_text(&Type::FLOAT8, sql_text).unwrap(); 206 | assert_eq!(result, 3.14); 207 | } 208 | 209 | #[test] 210 | fn test_from_sql_text_for_f64_invalid() { 211 | let sql_text = "not_a_number".as_bytes(); 212 | let result = f64::from_sql_text(&Type::FLOAT8, sql_text); 213 | assert!(result.is_err()); 214 | } 215 | 216 | #[test] 217 | fn test_from_sql_text_for_bool() { 218 | let sql_text = "t".as_bytes(); 219 | let result = bool::from_sql_text(&Type::BOOL, sql_text).unwrap(); 220 | assert_eq!(result, true); 221 | 222 | let sql_text = "f".as_bytes(); 223 | let result = bool::from_sql_text(&Type::BOOL, sql_text).unwrap(); 224 | assert_eq!(result, false); 225 | } 226 | 227 | #[test] 228 | fn test_from_sql_text_for_bool_invalid() { 229 | let sql_text = "not_a_boolean".as_bytes(); 230 | let result = bool::from_sql_text(&Type::BOOL, sql_text); 231 | assert!(result.is_err()); 232 | } 233 | 234 | #[test] 235 | fn test_from_sql_text_for_option_string() { 236 | let sql_text = "Some text".as_bytes(); 237 | let result = Option::::from_sql_text(&Type::VARCHAR, sql_text).unwrap(); 238 | assert_eq!(result, Some("Some text".to_string())); 239 | 240 | let sql_text = "".as_bytes(); 241 | let result = Option::::from_sql_text(&Type::VARCHAR, sql_text).unwrap(); 242 | assert_eq!(result, None); 243 | } 244 | 245 | #[test] 246 | fn test_from_sql_text_for_option_i32() { 247 | let sql_text = "42".as_bytes(); 248 | let result = Option::::from_sql_text(&Type::INT4, sql_text).unwrap(); 249 | assert_eq!(result, Some(42)); 250 | 251 | let sql_text = "".as_bytes(); 252 | let result = Option::::from_sql_text(&Type::INT4, sql_text).unwrap(); 253 | assert_eq!(result, None); 254 | } 255 | 256 | #[test] 257 | fn 
test_from_sql_text_for_option_f64() { 258 | let sql_text = "3.14".as_bytes(); 259 | let result = Option::::from_sql_text(&Type::FLOAT8, sql_text).unwrap(); 260 | assert_eq!(result, Some(3.14)); 261 | 262 | let sql_text = "".as_bytes(); 263 | let result = Option::::from_sql_text(&Type::FLOAT8, sql_text).unwrap(); 264 | assert_eq!(result, None); 265 | } 266 | 267 | #[test] 268 | fn test_from_sql_text_for_option_bool() { 269 | let sql_text = "t".as_bytes(); 270 | let result = Option::::from_sql_text(&Type::BOOL, sql_text).unwrap(); 271 | assert_eq!(result, Some(true)); 272 | 273 | let sql_text = "".as_bytes(); 274 | let result = Option::::from_sql_text(&Type::BOOL, sql_text).unwrap(); 275 | assert_eq!(result, None); 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /src/types/mod.rs: -------------------------------------------------------------------------------- 1 | mod from_sql_text; 2 | mod to_sql_text; 3 | 4 | pub use from_sql_text::FromSqlText; 5 | pub use to_sql_text::ToSqlText; 6 | -------------------------------------------------------------------------------- /src/types/to_sql_text.rs: -------------------------------------------------------------------------------- 1 | use std::time::SystemTime; 2 | use std::{error::Error, fmt}; 3 | 4 | use bytes::{BufMut, BytesMut}; 5 | use chrono::offset::Utc; 6 | use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone}; 7 | use lazy_regex::{lazy_regex, Lazy, Regex}; 8 | use postgres_types::{IsNull, Kind, Type, WrongType}; 9 | use rust_decimal::Decimal; 10 | 11 | pub static QUOTE_CHECK: Lazy = lazy_regex!(r#"^$|["{},\\\s]|^null$"#i); 12 | pub static QUOTE_ESCAPE: Lazy = lazy_regex!(r#"(["\\])"#); 13 | 14 | pub trait ToSqlText: fmt::Debug { 15 | /// Converts value to text format of Postgres type. 16 | /// 17 | /// This trait is modelled after `ToSql` from postgres-types, which is 18 | /// for binary encoding. 
19 | fn to_sql_text( 20 | &self, 21 | ty: &Type, 22 | out: &mut BytesMut, 23 | ) -> Result> 24 | where 25 | Self: Sized; 26 | } 27 | 28 | impl ToSqlText for &T 29 | where 30 | T: ToSqlText, 31 | { 32 | fn to_sql_text( 33 | &self, 34 | ty: &Type, 35 | out: &mut BytesMut, 36 | ) -> Result> { 37 | (*self).to_sql_text(ty, out) 38 | } 39 | } 40 | 41 | impl ToSqlText for Option { 42 | fn to_sql_text( 43 | &self, 44 | ty: &Type, 45 | out: &mut BytesMut, 46 | ) -> Result> { 47 | match *self { 48 | Some(ref val) => val.to_sql_text(ty, out), 49 | None => Ok(IsNull::Yes), 50 | } 51 | } 52 | } 53 | 54 | impl ToSqlText for bool { 55 | fn to_sql_text( 56 | &self, 57 | _ty: &Type, 58 | out: &mut BytesMut, 59 | ) -> Result> { 60 | if *self { 61 | out.put_slice(b"t"); 62 | } else { 63 | out.put_slice(b"f"); 64 | } 65 | Ok(IsNull::No) 66 | } 67 | } 68 | 69 | impl ToSqlText for String { 70 | fn to_sql_text( 71 | &self, 72 | ty: &Type, 73 | w: &mut BytesMut, 74 | ) -> Result> { 75 | <&str as ToSqlText>::to_sql_text(&&**self, ty, w) 76 | } 77 | } 78 | 79 | impl ToSqlText for &str { 80 | fn to_sql_text( 81 | &self, 82 | ty: &Type, 83 | w: &mut BytesMut, 84 | ) -> Result> { 85 | let quote = matches!(ty.kind(), Kind::Array(_)) && QUOTE_CHECK.is_match(self); 86 | 87 | if quote { 88 | w.put_u8(b'"'); 89 | w.put_slice(QUOTE_ESCAPE.replace_all(self, r#"\$1"#).as_bytes()); 90 | w.put_u8(b'"'); 91 | } else { 92 | w.put_slice(self.as_bytes()); 93 | } 94 | 95 | Ok(IsNull::No) 96 | } 97 | } 98 | 99 | macro_rules! 
impl_to_sql_text { 100 | ($t:ty) => { 101 | impl ToSqlText for $t { 102 | fn to_sql_text( 103 | &self, 104 | _ty: &Type, 105 | w: &mut BytesMut, 106 | ) -> Result> { 107 | w.put_slice(self.to_string().as_bytes()); 108 | Ok(IsNull::No) 109 | } 110 | } 111 | }; 112 | } 113 | 114 | impl_to_sql_text!(i8); 115 | impl_to_sql_text!(i16); 116 | impl_to_sql_text!(i32); 117 | impl_to_sql_text!(i64); 118 | impl_to_sql_text!(u32); 119 | impl_to_sql_text!(f32); 120 | impl_to_sql_text!(f64); 121 | impl_to_sql_text!(char); 122 | 123 | impl ToSqlText for &[u8] { 124 | fn to_sql_text( 125 | &self, 126 | _ty: &Type, 127 | out: &mut BytesMut, 128 | ) -> Result> { 129 | out.put_slice(b"\\x"); 130 | out.put_slice(hex::encode(self).as_bytes()); 131 | Ok(IsNull::No) 132 | } 133 | } 134 | 135 | impl ToSqlText for Vec { 136 | fn to_sql_text( 137 | &self, 138 | ty: &Type, 139 | out: &mut BytesMut, 140 | ) -> Result> { 141 | <&[u8] as ToSqlText>::to_sql_text(&&**self, ty, out) 142 | } 143 | } 144 | 145 | impl ToSqlText for [u8; N] { 146 | fn to_sql_text( 147 | &self, 148 | ty: &Type, 149 | out: &mut BytesMut, 150 | ) -> Result> { 151 | <&[u8] as ToSqlText>::to_sql_text(&&self[..], ty, out) 152 | } 153 | } 154 | 155 | impl ToSqlText for SystemTime { 156 | fn to_sql_text( 157 | &self, 158 | _ty: &Type, 159 | out: &mut BytesMut, 160 | ) -> Result> { 161 | let datetime: DateTime = DateTime::::from(*self); 162 | let fmt = datetime.format("%Y-%m-%d %H:%M:%S%.6f").to_string(); 163 | out.put_slice(fmt.as_bytes()); 164 | Ok(IsNull::No) 165 | } 166 | } 167 | 168 | impl ToSqlText for DateTime 169 | where 170 | Tz::Offset: std::fmt::Display, 171 | { 172 | fn to_sql_text( 173 | &self, 174 | ty: &Type, 175 | out: &mut BytesMut, 176 | ) -> Result> { 177 | let fmt = match *ty { 178 | Type::TIMESTAMP | Type::TIMESTAMP_ARRAY => "%Y-%m-%d %H:%M:%S%.6f", 179 | Type::TIMESTAMPTZ | Type::TIMESTAMPTZ_ARRAY => "%Y-%m-%d %H:%M:%S%.6f%:::z", 180 | Type::DATE | Type::DATE_ARRAY => "%Y-%m-%d", 181 | Type::TIME | 
Type::TIME_ARRAY => "%H:%M:%S%.6f", 182 | Type::TIMETZ | Type::TIMETZ_ARRAY => "%H:%M:%S%.6f%:::z", 183 | _ => Err(Box::new(WrongType::new::>(ty.clone())))?, 184 | }; 185 | out.put_slice(self.format(fmt).to_string().as_bytes()); 186 | Ok(IsNull::No) 187 | } 188 | } 189 | 190 | impl ToSqlText for NaiveDateTime { 191 | fn to_sql_text( 192 | &self, 193 | ty: &Type, 194 | out: &mut BytesMut, 195 | ) -> Result> { 196 | let fmt = match *ty { 197 | Type::TIMESTAMP | Type::TIMESTAMP_ARRAY => "%Y-%m-%d %H:%M:%S%.6f", 198 | Type::DATE | Type::DATE_ARRAY => "%Y-%m-%d", 199 | Type::TIME | Type::TIME_ARRAY => "%H:%M:%S%.6f", 200 | _ => Err(Box::new(WrongType::new::(ty.clone())))?, 201 | }; 202 | out.put_slice(self.format(fmt).to_string().as_bytes()); 203 | Ok(IsNull::No) 204 | } 205 | } 206 | 207 | impl ToSqlText for NaiveDate { 208 | fn to_sql_text( 209 | &self, 210 | ty: &Type, 211 | out: &mut BytesMut, 212 | ) -> Result> { 213 | let fmt = match *ty { 214 | Type::DATE | Type::DATE_ARRAY => self.format("%Y-%m-%d").to_string(), 215 | _ => Err(Box::new(WrongType::new::(ty.clone())))?, 216 | }; 217 | 218 | out.put_slice(fmt.as_bytes()); 219 | Ok(IsNull::No) 220 | } 221 | } 222 | 223 | impl ToSqlText for NaiveTime { 224 | fn to_sql_text( 225 | &self, 226 | ty: &Type, 227 | out: &mut BytesMut, 228 | ) -> Result> { 229 | let fmt = match *ty { 230 | Type::TIME | Type::TIME_ARRAY => self.format("%H:%M:%S%.6f").to_string(), 231 | _ => Err(Box::new(WrongType::new::(ty.clone())))?, 232 | }; 233 | out.put_slice(fmt.as_bytes()); 234 | Ok(IsNull::No) 235 | } 236 | } 237 | 238 | impl ToSqlText for Decimal { 239 | fn to_sql_text( 240 | &self, 241 | ty: &Type, 242 | out: &mut BytesMut, 243 | ) -> Result> 244 | where 245 | Self: Sized, 246 | { 247 | let fmt = match *ty { 248 | Type::NUMERIC | Type::NUMERIC_ARRAY => self.to_string(), 249 | _ => Err(Box::new(WrongType::new::(ty.clone())))?, 250 | }; 251 | 252 | out.put_slice(fmt.as_bytes()); 253 | Ok(IsNull::No) 254 | } 255 | } 256 | 257 | impl 
ToSqlText for &[T] { 258 | fn to_sql_text( 259 | &self, 260 | ty: &Type, 261 | out: &mut BytesMut, 262 | ) -> Result> { 263 | out.put_slice(b"{"); 264 | for (i, val) in self.iter().enumerate() { 265 | if i > 0 { 266 | out.put_slice(b","); 267 | } 268 | // put NULL for null value in array 269 | if let IsNull::Yes = val.to_sql_text(ty, out)? { 270 | out.put_slice(b"NULL"); 271 | } 272 | } 273 | out.put_slice(b"}"); 274 | Ok(IsNull::No) 275 | } 276 | } 277 | 278 | impl ToSqlText for Vec { 279 | fn to_sql_text( 280 | &self, 281 | ty: &Type, 282 | out: &mut BytesMut, 283 | ) -> Result> { 284 | <&[T] as ToSqlText>::to_sql_text(&&**self, ty, out) 285 | } 286 | } 287 | 288 | impl ToSqlText for [T; N] { 289 | fn to_sql_text( 290 | &self, 291 | ty: &Type, 292 | out: &mut BytesMut, 293 | ) -> Result> { 294 | <&[T] as ToSqlText>::to_sql_text(&&self[..], ty, out) 295 | } 296 | } 297 | 298 | #[cfg(test)] 299 | mod test { 300 | use super::*; 301 | use chrono::offset::FixedOffset; 302 | 303 | #[test] 304 | fn test_date_time_format() { 305 | let date = NaiveDate::from_ymd_opt(2023, 3, 5).unwrap(); 306 | let mut buf = BytesMut::new(); 307 | date.to_sql_text(&Type::DATE, &mut buf).unwrap(); 308 | assert_eq!("2023-03-05", String::from_utf8_lossy(buf.freeze().as_ref())); 309 | 310 | let date = NaiveDate::from_ymd_opt(2023, 3, 5).unwrap(); 311 | let mut buf = BytesMut::new(); 312 | assert!(date.to_sql_text(&Type::INT8, &mut buf).is_err()); 313 | 314 | let date = NaiveDateTime::new( 315 | NaiveDate::from_ymd_opt(2023, 3, 5).unwrap(), 316 | NaiveTime::from_hms_opt(10, 20, 00).unwrap(), 317 | ) 318 | .and_local_timezone(FixedOffset::east_opt(8 * 3600).unwrap()) 319 | .unwrap(); 320 | 321 | let mut buf = BytesMut::new(); 322 | date.to_sql_text(&Type::TIMESTAMPTZ, &mut buf).unwrap(); 323 | // format: 2023-02-01 22:31:49.479895+08 324 | assert_eq!( 325 | "2023-03-05 10:20:00.000000+08", 326 | String::from_utf8_lossy(buf.freeze().as_ref()) 327 | ); 328 | } 329 | 330 | #[test] 331 | fn 
test_null() { 332 | let data = vec![None::, Some(8)]; 333 | let mut buf = BytesMut::new(); 334 | data.to_sql_text(&Type::INT2, &mut buf).unwrap(); 335 | assert_eq!("{NULL,8}", String::from_utf8_lossy(buf.freeze().as_ref())); 336 | } 337 | 338 | #[test] 339 | fn test_bool() { 340 | let yes = true; 341 | let no = false; 342 | 343 | let mut buf = BytesMut::new(); 344 | yes.to_sql_text(&Type::BOOL, &mut buf).unwrap(); 345 | assert_eq!("t", String::from_utf8_lossy(buf.freeze().as_ref())); 346 | 347 | let mut buf = BytesMut::new(); 348 | no.to_sql_text(&Type::BOOL, &mut buf).unwrap(); 349 | assert_eq!("f", String::from_utf8_lossy(buf.freeze().as_ref())); 350 | } 351 | 352 | #[test] 353 | fn test_array() { 354 | let date = &[ 355 | NaiveDate::from_ymd_opt(2023, 3, 5).unwrap(), 356 | NaiveDate::from_ymd_opt(2023, 3, 6).unwrap(), 357 | ]; 358 | let mut buf = BytesMut::new(); 359 | date.to_sql_text(&Type::DATE_ARRAY, &mut buf).unwrap(); 360 | assert_eq!( 361 | "{2023-03-05,2023-03-06}", 362 | String::from_utf8_lossy(buf.freeze().as_ref()) 363 | ); 364 | 365 | let chars = &[ 366 | "{", "abc", "}", "\"", "", "a,b", "null", "NULL", "NULL!", "\\", " ", "\"\"", 367 | ]; 368 | let mut buf = BytesMut::new(); 369 | chars.to_sql_text(&Type::VARCHAR_ARRAY, &mut buf).unwrap(); 370 | assert_eq!( 371 | r#"{"{",abc,"}","\"","","a,b","null","NULL",NULL!,"\\"," ","\"\""}"#, 372 | String::from_utf8_lossy(buf.freeze().as_ref()) 373 | ); 374 | } 375 | } 376 | -------------------------------------------------------------------------------- /tests-integration/go/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "database/sql" 6 | _ "github.com/lib/pq" 7 | ) 8 | 9 | type result struct { 10 | id int 11 | name string 12 | date string 13 | isOk bool 14 | } 15 | 16 | func main() { 17 | conninfo := "host=127.0.0.1 port=5432 user=tom password=pencil dbname=localdb" 18 | db, err := sql.Open("postgres", conninfo) 19 
| if err != nil { 20 | log.Fatal(err) 21 | } 22 | defer db.Close() 23 | 24 | _, err = db.Exec("INSERT INTO testtable VALUES (1)") 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | 29 | rows, err := db.Query("SELECT * FROM testtable") 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | 34 | for rows.Next() { 35 | var r result 36 | rows.Scan( & r.id, & r.name, & r.date, & r.isOk) 37 | log.Printf("%#v", r) 38 | } 39 | 40 | rows, err = db.Query("SELECT * FROM testtable where id = ?", 1) 41 | if err != nil { 42 | log.Fatal(err) 43 | } 44 | 45 | for rows.Next() { 46 | var r result 47 | rows.Scan( & r.id, & r.name, & r.date, & r.isOk) 48 | log.Printf("%#v", r) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /tests-integration/go/go.mod: -------------------------------------------------------------------------------- 1 | module pgwire.com/client 2 | 3 | go 1.21.6 4 | 5 | require github.com/lib/pq v1.10.9 6 | -------------------------------------------------------------------------------- /tests-integration/go/go.sum: -------------------------------------------------------------------------------- 1 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 2 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 3 | -------------------------------------------------------------------------------- /tests-integration/jdbc/test.bb: -------------------------------------------------------------------------------- 1 | (require '[babashka.pods :as pods]) 2 | (pods/load-pod 'org.babashka/postgresql "0.1.1") 3 | 4 | (require '[pod.babashka.postgresql :as pg]) 5 | 6 | (def db {:dbtype "postgresql" 7 | :host "127.0.0.1" 8 | :dbname "localdb" 9 | :user "postgres" 10 | :password "pencil" 11 | :port 5432}) 12 | 13 | (println (pg/execute! db ["INSERT INTO testtable VALUES (1)"])) 14 | 15 | (println (pg/execute! db ["SELECT * FROM testable"])) 16 | 17 | (println (pg/execute! 
db ["SELECT * FROM testable WHERE id = ?" 1])) 18 | -------------------------------------------------------------------------------- /tests-integration/nodejs/index.js: -------------------------------------------------------------------------------- 1 | const { strict } = require("node:assert"); 2 | 3 | const { Client } = require("pg"); 4 | const client = new Client({ 5 | host: "127.0.0.1", 6 | port: 5432, 7 | user: "tom", 8 | password: "pencil", 9 | database: "localdb", 10 | }); 11 | 12 | async function run() { 13 | await client.connect(); 14 | 15 | const res1 = await client.query("INSERT INTO testable VALUE (1)"); 16 | console.log(res1.rowCount); 17 | 18 | const res2 = await client.query("SELECT * FROM testtable"); 19 | console.log(res2.rows); 20 | 21 | const res3 = await client.query( 22 | "SELECT * FROM testtable WHERE id = $1::int", 23 | [1] 24 | ); 25 | console.log(res3.rows); 26 | strict.equal(res3.rows[0].id, 0); 27 | strict.equal(res3.rows[0].name, "Tom"); 28 | await client.end(); 29 | } 30 | 31 | run(); 32 | -------------------------------------------------------------------------------- /tests-integration/nodejs/package-lock.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "pgtest", 3 | "version": "1.0.0", 4 | "lockfileVersion": 2, 5 | "requires": true, 6 | "packages": { 7 | "": { 8 | "name": "pgtest", 9 | "version": "1.0.0", 10 | "license": "ISC", 11 | "dependencies": { 12 | "pg": "^8.9.0" 13 | } 14 | }, 15 | "node_modules/buffer-writer": { 16 | "version": "2.0.0", 17 | "resolved": "https://registry.npmjs.org/buffer-writer/-/buffer-writer-2.0.0.tgz", 18 | "integrity": "sha512-a7ZpuTZU1TRtnwyCNW3I5dc0wWNC3VR9S++Ewyk2HHZdrO3CQJqSpd+95Us590V6AL7JqUAH2IwZ/398PmNFgw==", 19 | "engines": { 20 | "node": ">=4" 21 | } 22 | }, 23 | "node_modules/packet-reader": { 24 | "version": "1.0.0", 25 | "resolved": "https://registry.npmjs.org/packet-reader/-/packet-reader-1.0.0.tgz", 26 | "integrity": 
"sha512-HAKu/fG3HpHFO0AA8WE8q2g+gBJaZ9MG7fcKk+IJPLTGAD6Psw4443l+9DGRbOIh3/aXr7Phy0TjilYivJo5XQ==" 27 | }, 28 | "node_modules/pg": { 29 | "version": "8.9.0", 30 | "resolved": "https://registry.npmjs.org/pg/-/pg-8.9.0.tgz", 31 | "integrity": "sha512-ZJM+qkEbtOHRuXjmvBtOgNOXOtLSbxiMiUVMgE4rV6Zwocy03RicCVvDXgx8l4Biwo8/qORUnEqn2fdQzV7KCg==", 32 | "dependencies": { 33 | "buffer-writer": "2.0.0", 34 | "packet-reader": "1.0.0", 35 | "pg-connection-string": "^2.5.0", 36 | "pg-pool": "^3.5.2", 37 | "pg-protocol": "^1.6.0", 38 | "pg-types": "^2.1.0", 39 | "pgpass": "1.x" 40 | }, 41 | "engines": { 42 | "node": ">= 8.0.0" 43 | }, 44 | "peerDependencies": { 45 | "pg-native": ">=3.0.1" 46 | }, 47 | "peerDependenciesMeta": { 48 | "pg-native": { 49 | "optional": true 50 | } 51 | } 52 | }, 53 | "node_modules/pg-connection-string": { 54 | "version": "2.5.0", 55 | "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.5.0.tgz", 56 | "integrity": "sha512-r5o/V/ORTA6TmUnyWZR9nCj1klXCO2CEKNRlVuJptZe85QuhFayC7WeMic7ndayT5IRIR0S0xFxFi2ousartlQ==" 57 | }, 58 | "node_modules/pg-int8": { 59 | "version": "1.0.1", 60 | "resolved": "https://registry.npmjs.org/pg-int8/-/pg-int8-1.0.1.tgz", 61 | "integrity": "sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==", 62 | "engines": { 63 | "node": ">=4.0.0" 64 | } 65 | }, 66 | "node_modules/pg-pool": { 67 | "version": "3.5.2", 68 | "resolved": "https://registry.npmjs.org/pg-pool/-/pg-pool-3.5.2.tgz", 69 | "integrity": "sha512-His3Fh17Z4eg7oANLob6ZvH8xIVen3phEZh2QuyrIl4dQSDVEabNducv6ysROKpDNPSD+12tONZVWfSgMvDD9w==", 70 | "peerDependencies": { 71 | "pg": ">=8.0" 72 | } 73 | }, 74 | "node_modules/pg-protocol": { 75 | "version": "1.6.0", 76 | "resolved": "https://registry.npmjs.org/pg-protocol/-/pg-protocol-1.6.0.tgz", 77 | "integrity": "sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==" 78 | }, 79 | "node_modules/pg-types": { 80 | 
"version": "2.2.0", 81 | "resolved": "https://registry.npmjs.org/pg-types/-/pg-types-2.2.0.tgz", 82 | "integrity": "sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==", 83 | "dependencies": { 84 | "pg-int8": "1.0.1", 85 | "postgres-array": "~2.0.0", 86 | "postgres-bytea": "~1.0.0", 87 | "postgres-date": "~1.0.4", 88 | "postgres-interval": "^1.1.0" 89 | }, 90 | "engines": { 91 | "node": ">=4" 92 | } 93 | }, 94 | "node_modules/pgpass": { 95 | "version": "1.0.5", 96 | "resolved": "https://registry.npmjs.org/pgpass/-/pgpass-1.0.5.tgz", 97 | "integrity": "sha512-FdW9r/jQZhSeohs1Z3sI1yxFQNFvMcnmfuj4WBMUTxOrAyLMaTcE1aAMBiTlbMNaXvBCQuVi0R7hd8udDSP7ug==", 98 | "dependencies": { 99 | "split2": "^4.1.0" 100 | } 101 | }, 102 | "node_modules/postgres-array": { 103 | "version": "2.0.0", 104 | "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-2.0.0.tgz", 105 | "integrity": "sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==", 106 | "engines": { 107 | "node": ">=4" 108 | } 109 | }, 110 | "node_modules/postgres-bytea": { 111 | "version": "1.0.0", 112 | "resolved": "https://registry.npmjs.org/postgres-bytea/-/postgres-bytea-1.0.0.tgz", 113 | "integrity": "sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==", 114 | "engines": { 115 | "node": ">=0.10.0" 116 | } 117 | }, 118 | "node_modules/postgres-date": { 119 | "version": "1.0.7", 120 | "resolved": "https://registry.npmjs.org/postgres-date/-/postgres-date-1.0.7.tgz", 121 | "integrity": "sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==", 122 | "engines": { 123 | "node": ">=0.10.0" 124 | } 125 | }, 126 | "node_modules/postgres-interval": { 127 | "version": "1.2.0", 128 | "resolved": "https://registry.npmjs.org/postgres-interval/-/postgres-interval-1.2.0.tgz", 129 | "integrity": 
"sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==", 130 | "dependencies": { 131 | "xtend": "^4.0.0" 132 | }, 133 | "engines": { 134 | "node": ">=0.10.0" 135 | } 136 | }, 137 | "node_modules/split2": { 138 | "version": "4.1.0", 139 | "resolved": "https://registry.npmjs.org/split2/-/split2-4.1.0.tgz", 140 | "integrity": "sha512-VBiJxFkxiXRlUIeyMQi8s4hgvKCSjtknJv/LVYbrgALPwf5zSKmEwV9Lst25AkvMDnvxODugjdl6KZgwKM1WYQ==", 141 | "engines": { 142 | "node": ">= 10.x" 143 | } 144 | }, 145 | "node_modules/xtend": { 146 | "version": "4.0.2", 147 | "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", 148 | "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", 149 | "engines": { 150 | "node": ">=0.4" 151 | } 152 | } 153 | }, 154 | "dependencies": { 155 | "buffer-writer": { 156 | "version": "2.0.0", 157 | "resolved": "https://registry.npmjs.org/buffer-writer/-/buffer-writer-2.0.0.tgz", 158 | "integrity": "sha512-a7ZpuTZU1TRtnwyCNW3I5dc0wWNC3VR9S++Ewyk2HHZdrO3CQJqSpd+95Us590V6AL7JqUAH2IwZ/398PmNFgw==" 159 | }, 160 | "packet-reader": { 161 | "version": "1.0.0", 162 | "resolved": "https://registry.npmjs.org/packet-reader/-/packet-reader-1.0.0.tgz", 163 | "integrity": "sha512-HAKu/fG3HpHFO0AA8WE8q2g+gBJaZ9MG7fcKk+IJPLTGAD6Psw4443l+9DGRbOIh3/aXr7Phy0TjilYivJo5XQ==" 164 | }, 165 | "pg": { 166 | "version": "8.9.0", 167 | "resolved": "https://registry.npmjs.org/pg/-/pg-8.9.0.tgz", 168 | "integrity": "sha512-ZJM+qkEbtOHRuXjmvBtOgNOXOtLSbxiMiUVMgE4rV6Zwocy03RicCVvDXgx8l4Biwo8/qORUnEqn2fdQzV7KCg==", 169 | "requires": { 170 | "buffer-writer": "2.0.0", 171 | "packet-reader": "1.0.0", 172 | "pg-connection-string": "^2.5.0", 173 | "pg-pool": "^3.5.2", 174 | "pg-protocol": "^1.6.0", 175 | "pg-types": "^2.1.0", 176 | "pgpass": "1.x" 177 | } 178 | }, 179 | "pg-connection-string": { 180 | "version": "2.5.0", 181 | "resolved": 
"https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.5.0.tgz", 182 | "integrity": "sha512-r5o/V/ORTA6TmUnyWZR9nCj1klXCO2CEKNRlVuJptZe85QuhFayC7WeMic7ndayT5IRIR0S0xFxFi2ousartlQ==" 183 | }, 184 | "pg-int8": { 185 | "version": "1.0.1", 186 | "resolved": "https://registry.npmjs.org/pg-int8/-/pg-int8-1.0.1.tgz", 187 | "integrity": "sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==" 188 | }, 189 | "pg-pool": { 190 | "version": "3.5.2", 191 | "resolved": "https://registry.npmjs.org/pg-pool/-/pg-pool-3.5.2.tgz", 192 | "integrity": "sha512-His3Fh17Z4eg7oANLob6ZvH8xIVen3phEZh2QuyrIl4dQSDVEabNducv6ysROKpDNPSD+12tONZVWfSgMvDD9w==", 193 | "requires": {} 194 | }, 195 | "pg-protocol": { 196 | "version": "1.6.0", 197 | "resolved": "https://registry.npmjs.org/pg-protocol/-/pg-protocol-1.6.0.tgz", 198 | "integrity": "sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==" 199 | }, 200 | "pg-types": { 201 | "version": "2.2.0", 202 | "resolved": "https://registry.npmjs.org/pg-types/-/pg-types-2.2.0.tgz", 203 | "integrity": "sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==", 204 | "requires": { 205 | "pg-int8": "1.0.1", 206 | "postgres-array": "~2.0.0", 207 | "postgres-bytea": "~1.0.0", 208 | "postgres-date": "~1.0.4", 209 | "postgres-interval": "^1.1.0" 210 | } 211 | }, 212 | "pgpass": { 213 | "version": "1.0.5", 214 | "resolved": "https://registry.npmjs.org/pgpass/-/pgpass-1.0.5.tgz", 215 | "integrity": "sha512-FdW9r/jQZhSeohs1Z3sI1yxFQNFvMcnmfuj4WBMUTxOrAyLMaTcE1aAMBiTlbMNaXvBCQuVi0R7hd8udDSP7ug==", 216 | "requires": { 217 | "split2": "^4.1.0" 218 | } 219 | }, 220 | "postgres-array": { 221 | "version": "2.0.0", 222 | "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-2.0.0.tgz", 223 | "integrity": "sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==" 224 | }, 225 | 
"postgres-bytea": { 226 | "version": "1.0.0", 227 | "resolved": "https://registry.npmjs.org/postgres-bytea/-/postgres-bytea-1.0.0.tgz", 228 | "integrity": "sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==" 229 | }, 230 | "postgres-date": { 231 | "version": "1.0.7", 232 | "resolved": "https://registry.npmjs.org/postgres-date/-/postgres-date-1.0.7.tgz", 233 | "integrity": "sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==" 234 | }, 235 | "postgres-interval": { 236 | "version": "1.2.0", 237 | "resolved": "https://registry.npmjs.org/postgres-interval/-/postgres-interval-1.2.0.tgz", 238 | "integrity": "sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==", 239 | "requires": { 240 | "xtend": "^4.0.0" 241 | } 242 | }, 243 | "split2": { 244 | "version": "4.1.0", 245 | "resolved": "https://registry.npmjs.org/split2/-/split2-4.1.0.tgz", 246 | "integrity": "sha512-VBiJxFkxiXRlUIeyMQi8s4hgvKCSjtknJv/LVYbrgALPwf5zSKmEwV9Lst25AkvMDnvxODugjdl6KZgwKM1WYQ==" 247 | }, 248 | "xtend": { 249 | "version": "4.0.2", 250 | "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", 251 | "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==" 252 | } 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /tests-integration/nodejs/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "pgtest", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "node index.js" 8 | }, 9 | "author": "", 10 | "license": "ISC", 11 | "dependencies": { 12 | "pg": "^8.9.0" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /tests-integration/python/client2.py: 
-------------------------------------------------------------------------------- 1 | import psycopg2 2 | 3 | conn = psycopg2.connect("host=127.0.0.1 port=5432 user=tom password=pencil dbname=localdb") 4 | conn.autocommit = True 5 | 6 | with conn.cursor() as cur: 7 | cur.execute("INSERT INTO testtable VALUES (1)") 8 | print(cur.statusmessage) 9 | 10 | with conn.cursor() as cur: 11 | cur.execute("SELECT * FROM testtable") 12 | print(cur.fetchall()) 13 | -------------------------------------------------------------------------------- /tests-integration/python/client3.py: -------------------------------------------------------------------------------- 1 | import psycopg 2 | 3 | conn = psycopg.connect("host=127.0.0.1 port=5432 user=tom password=pencil dbname=localdb") 4 | conn.autocommit = True 5 | 6 | with conn.cursor() as cur: 7 | cur.execute("INSERT INTO testtable VALUES (1)") 8 | print(cur.statusmessage) 9 | 10 | with conn.cursor() as cur: 11 | cur.execute("SELECT * FROM testtable") 12 | print(cur.fetchall()) 13 | 14 | with conn.cursor() as cur: 15 | cur.execute("SELECT * FROM testtable WHERE id = %s", [1]) 16 | print(cur.fetchall()) 17 | -------------------------------------------------------------------------------- /tests-integration/rust-client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-client" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | tokio = { version = "1", features = ["full"] } 11 | openssl = "0.10" 12 | postgres = { version = "0.19.10" } 13 | postgres-openssl = { version = "0.5.1" } 14 | -------------------------------------------------------------------------------- /tests-integration/rust-client/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::time::SystemTime; 2 | 3 
| use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; 4 | use postgres::{Client, SimpleQueryMessage}; 5 | use postgres_openssl::MakeTlsConnector; 6 | 7 | fn main() { 8 | let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); 9 | builder.set_verify(SslVerifyMode::NONE); 10 | postgres_openssl::set_postgresql_alpn(&mut builder).unwrap(); 11 | let connector = MakeTlsConnector::new(builder.build()); 12 | let mut client = Client::connect( 13 | "host=localhost port=5432 user=postgres password=pencil dbname=localdb keepalives=0 sslmode=require sslnegotiation=direct", 14 | connector, 15 | ) 16 | .unwrap(); 17 | 18 | let results = client.simple_query("SELECT * FROM testtable").unwrap(); 19 | for row in results { 20 | if let SimpleQueryMessage::Row(row) = row { 21 | println!("{:?}", row.get(0)); 22 | println!("{:?}", row.get(1)); 23 | println!("{:?}", row.get(2)); 24 | } 25 | } 26 | 27 | for row in client 28 | .query("SELECT * FROM testtable WHERE id = ?", &[&1]) 29 | .unwrap() 30 | { 31 | println!("{:?}", row.get::>(0)); 32 | println!("{:?}", row.get::>(1)); 33 | println!("{:?}", row.get::>(2)); 34 | } 35 | 36 | client 37 | .simple_query("INSERT INTO testtable VALUES (1)") 38 | .unwrap(); 39 | 40 | client.close().unwrap(); 41 | } 42 | -------------------------------------------------------------------------------- /tests-integration/test-server/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "test-server" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | pgwire = { path = "../../", features = ["scram"] } 11 | async-trait = "0.1" 12 | futures = "0.3" 13 | tokio = { version = "1", features = ["full"] } 14 | tokio-rustls = { version = "0.26.2", default-features = false, features = ["logging", "tls12"]} 15 | rustls-pemfile = "2.0" 16 | 
rustls-pki-types = "1.0" 17 | -------------------------------------------------------------------------------- /tests-integration/test-server/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufReader, Error as IOError, ErrorKind}; 3 | use std::sync::Arc; 4 | use std::time::{Duration, SystemTime}; 5 | 6 | use async_trait::async_trait; 7 | use futures::stream; 8 | use futures::StreamExt; 9 | use rustls_pemfile::{certs, pkcs8_private_keys}; 10 | use rustls_pki_types::{CertificateDer, PrivateKeyDer}; 11 | use tokio_rustls::rustls::ServerConfig; 12 | use tokio_rustls::TlsAcceptor; 13 | 14 | use pgwire::api::auth::scram::{gen_salted_password, SASLScramAuthStartupHandler}; 15 | use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; 16 | use pgwire::api::copy::NoopCopyHandler; 17 | use pgwire::api::portal::{Format, Portal}; 18 | use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; 19 | use pgwire::api::results::{ 20 | DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, 21 | Response, Tag, 22 | }; 23 | use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; 24 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; 25 | use pgwire::error::PgWireResult; 26 | use pgwire::tokio::process_socket; 27 | use tokio::net::TcpListener; 28 | 29 | const ITERATIONS: usize = 4096; 30 | struct DummyAuthSource; 31 | 32 | #[async_trait] 33 | impl AuthSource for DummyAuthSource { 34 | async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult { 35 | println!("login info: {:?}", login_info); 36 | 37 | let password = "pencil"; 38 | let salt = vec![0, 20, 40, 80]; 39 | 40 | let hash_password = gen_salted_password(password, salt.as_ref(), ITERATIONS); 41 | Ok(Password::new(Some(salt), hash_password)) 42 | } 43 | } 44 | 45 | #[derive(Default)] 46 | struct DummyDatabase { 47 | 
query_parser: Arc, 48 | } 49 | 50 | impl DummyDatabase { 51 | fn schema(&self, format: &Format) -> Vec { 52 | let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, format.format_for(0)); 53 | let f2 = FieldInfo::new( 54 | "name".into(), 55 | None, 56 | None, 57 | Type::VARCHAR, 58 | format.format_for(1), 59 | ); 60 | let f3 = FieldInfo::new( 61 | "ts".into(), 62 | None, 63 | None, 64 | Type::TIMESTAMP, 65 | format.format_for(2), 66 | ); 67 | let f4 = FieldInfo::new( 68 | "signed".into(), 69 | None, 70 | None, 71 | Type::BOOL, 72 | format.format_for(3), 73 | ); 74 | let f5 = FieldInfo::new("data".into(), None, None, Type::BYTEA, format.format_for(4)); 75 | vec![f1, f2, f3, f4, f5] 76 | } 77 | } 78 | 79 | #[async_trait] 80 | impl SimpleQueryHandler for DummyDatabase { 81 | async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> 82 | where 83 | C: ClientInfo + Unpin + Send + Sync, 84 | { 85 | println!("simple query: {:?}", query); 86 | if query.starts_with("SELECT") { 87 | let schema = Arc::new(self.schema(&Format::UnifiedText)); 88 | let schema_ref = schema.clone(); 89 | let data = vec![ 90 | ( 91 | Some(0), 92 | Some("Tom"), 93 | Some("2023-02-01 22:27:25.042674"), 94 | Some(true), 95 | Some("tomcat".as_bytes()), 96 | ), 97 | ( 98 | Some(1), 99 | Some("Jerry"), 100 | Some("2023-02-01 22:27:42.165585"), 101 | Some(false), 102 | Some("".as_bytes()), 103 | ), 104 | (Some(2), None, None, None, None), 105 | ]; 106 | let data_row_stream = stream::iter(data.into_iter()).map(move |r| { 107 | let mut encoder = DataRowEncoder::new(schema_ref.clone()); 108 | 109 | encoder.encode_field(&r.0)?; 110 | encoder.encode_field(&r.1)?; 111 | encoder.encode_field(&r.2)?; 112 | encoder.encode_field(&r.3)?; 113 | encoder.encode_field(&r.4)?; 114 | 115 | encoder.finish() 116 | }); 117 | 118 | Ok(vec![Response::Query(QueryResponse::new( 119 | schema, 120 | data_row_stream, 121 | ))]) 122 | } else { 123 | 
Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))]) 124 | } 125 | } 126 | } 127 | 128 | #[async_trait] 129 | impl ExtendedQueryHandler for DummyDatabase { 130 | type Statement = String; 131 | type QueryParser = NoopQueryParser; 132 | 133 | fn query_parser(&self) -> Arc { 134 | self.query_parser.clone() 135 | } 136 | 137 | async fn do_query<'a, C>( 138 | &self, 139 | _client: &mut C, 140 | portal: &Portal, 141 | _max_rows: usize, 142 | ) -> PgWireResult> 143 | where 144 | C: ClientInfo + Unpin + Send + Sync, 145 | { 146 | let query = &portal.statement.statement; 147 | println!("extended query: {:?}", query); 148 | if query.starts_with("SELECT") { 149 | let data = vec![ 150 | ( 151 | Some(0), 152 | Some("Tom"), 153 | Some(SystemTime::now()), 154 | Some(true), 155 | Some("tomcat".as_bytes()), 156 | ), 157 | ( 158 | Some(1), 159 | Some("Jerry"), 160 | Some(SystemTime::UNIX_EPOCH + Duration::from_secs(86400 * 5000)), 161 | Some(false), 162 | Some("".as_bytes()), 163 | ), 164 | (Some(2), None, None, None, None), 165 | ]; 166 | let schema = Arc::new(self.schema(&portal.result_column_format)); 167 | let schema_ref = schema.clone(); 168 | let data_row_stream = stream::iter(data.into_iter()).map(move |r| { 169 | let mut encoder = DataRowEncoder::new(schema_ref.clone()); 170 | 171 | encoder.encode_field(&r.0)?; 172 | encoder.encode_field(&r.1)?; 173 | encoder.encode_field(&r.2)?; 174 | encoder.encode_field(&r.3)?; 175 | encoder.encode_field(&r.4)?; 176 | 177 | encoder.finish() 178 | }); 179 | 180 | Ok(Response::Query(QueryResponse::new(schema, data_row_stream))) 181 | } else { 182 | Ok(Response::Execution(Tag::new("OK").with_rows(1))) 183 | } 184 | } 185 | 186 | async fn do_describe_statement( 187 | &self, 188 | _client: &mut C, 189 | stmt: &StoredStatement, 190 | ) -> PgWireResult 191 | where 192 | C: ClientInfo + Unpin + Send + Sync, 193 | { 194 | println!("describe: {:?}", stmt); 195 | let param_types = vec![Type::INT4]; 196 | let schema = 
self.schema(&Format::UnifiedText); 197 | Ok(DescribeStatementResponse::new(param_types, schema)) 198 | } 199 | 200 | async fn do_describe_portal( 201 | &self, 202 | _client: &mut C, 203 | portal: &Portal, 204 | ) -> PgWireResult 205 | where 206 | C: ClientInfo + Unpin + Send + Sync, 207 | { 208 | println!("describe: {:?}", portal); 209 | let schema = self.schema(&portal.result_column_format); 210 | Ok(DescribePortalResponse::new(schema)) 211 | } 212 | } 213 | 214 | struct DummyDatabaseFactory(Arc); 215 | 216 | impl PgWireServerHandlers for DummyDatabaseFactory { 217 | type StartupHandler = 218 | SASLScramAuthStartupHandler; 219 | type SimpleQueryHandler = DummyDatabase; 220 | type ExtendedQueryHandler = DummyDatabase; 221 | type CopyHandler = NoopCopyHandler; 222 | type ErrorHandler = NoopErrorHandler; 223 | 224 | fn simple_query_handler(&self) -> Arc { 225 | self.0.clone() 226 | } 227 | 228 | fn extended_query_handler(&self) -> Arc { 229 | self.0.clone() 230 | } 231 | 232 | fn startup_handler(&self) -> Arc { 233 | let mut authenticator = SASLScramAuthStartupHandler::new( 234 | Arc::new(DummyAuthSource), 235 | Arc::new(DefaultServerParameterProvider::default()), 236 | ); 237 | authenticator.set_iterations(ITERATIONS); 238 | 239 | Arc::new(authenticator) 240 | } 241 | 242 | fn copy_handler(&self) -> Arc { 243 | Arc::new(NoopCopyHandler) 244 | } 245 | 246 | fn error_handler(&self) -> Arc { 247 | Arc::new(NoopErrorHandler) 248 | } 249 | } 250 | 251 | fn setup_tls() -> Result { 252 | let cert = certs(&mut BufReader::new(File::open( 253 | "../../examples/ssl/server.crt", 254 | )?)) 255 | .collect::, IOError>>()?; 256 | 257 | let key = pkcs8_private_keys(&mut BufReader::new(File::open( 258 | "../../examples/ssl/server.key", 259 | )?)) 260 | .map(|key| key.map(PrivateKeyDer::from)) 261 | .collect::, IOError>>()? 
262 | .remove(0); 263 | 264 | let mut config = ServerConfig::builder() 265 | .with_no_client_auth() 266 | .with_single_cert(cert, key) 267 | .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?; 268 | 269 | config.alpn_protocols = vec![b"postgresql".to_vec()]; 270 | 271 | Ok(TlsAcceptor::from(Arc::new(config))) 272 | } 273 | 274 | #[tokio::main] 275 | pub async fn main() { 276 | let factory = Arc::new(DummyDatabaseFactory(Arc::new(DummyDatabase::default()))); 277 | 278 | let server_addr = "127.0.0.1:5432"; 279 | let tls_acceptor = setup_tls().unwrap(); 280 | let listener = TcpListener::bind(server_addr).await.unwrap(); 281 | println!("Listening to {}", server_addr); 282 | loop { 283 | let incoming_socket = listener.accept().await.unwrap(); 284 | let tls_acceptor_ref = tls_acceptor.clone(); 285 | let factory_ref = factory.clone(); 286 | 287 | tokio::spawn(async move { 288 | process_socket(incoming_socket.0, Some(tls_acceptor_ref), factory_ref).await 289 | }); 290 | } 291 | } 292 | -------------------------------------------------------------------------------- /tests-integration/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | cd tests-integration 5 | 6 | ## start test server 7 | pushd test-server 8 | cargo build 9 | ../../target/debug/test-server & 10 | popd 11 | 12 | ## run rust-client 13 | pushd rust-client 14 | cargo run 15 | popd 16 | 17 | ## run python-clients 18 | pushd python 19 | python client2.py 20 | python client3.py 21 | popd 22 | 23 | ### jdbc 24 | pushd jdbc 25 | bb test.bb 26 | popd 27 | 28 | ### node 29 | pushd nodejs 30 | npm install 31 | npm run test 32 | popd 33 | 34 | ### golang 35 | pushd go 36 | go run client.go 37 | popd 38 | --------------------------------------------------------------------------------