├── .gitignore ├── .ocamlformat ├── CHANGES.md ├── CONTRIBUTING.md ├── LICENSE.md ├── Makefile ├── README.md ├── dune ├── dune-project ├── postgres_async.opam ├── protocol ├── backend.ml ├── backend.mli ├── column_metadata.ml ├── column_metadata.mli ├── column_metadata_intf.ml ├── dune ├── frontend.ml ├── frontend.mli ├── import.ml ├── postgres_async_protocol.ml ├── shared.ml ├── shared.mli ├── types.ml └── types.mli ├── src ├── command_complete.ml ├── command_complete.mli ├── command_complete_intf.ml ├── dune ├── message_reading_intf.ml ├── or_pgasync_error.ml ├── or_pgasync_error.mli ├── pgasync_error.ml ├── pgasync_error.mli ├── postgres_async.ml ├── postgres_async.mli ├── postgres_async_intf.ml ├── query_sequencer.ml ├── query_sequencer.mli ├── row_handle.ml ├── row_handle.mli ├── ssl_mode.ml ├── ssl_mode.mli ├── string_escaping.ml └── string_escaping.mli └── test ├── dune ├── harness.ml ├── harness.mli ├── postgres_async_tests.ml ├── server-leaf_certificate.crt ├── server-leaf_certificate.pem ├── server-leaf_key.key ├── test_cancellation.ml ├── test_cancellation.mli ├── test_connect.ml ├── test_connect.mli ├── test_copy_in.ml ├── test_copy_in.mli ├── test_copy_out.ml ├── test_copy_out.mli ├── test_error_code.ml ├── test_error_code.mli ├── test_notify.ml ├── test_notify.mli ├── test_protocol_round_trip.ml ├── test_protocol_round_trip.mli ├── test_query.ml ├── test_query.mli ├── test_runtime_parameters.ml ├── test_runtime_parameters.mli ├── test_server_failure.ml ├── test_server_failure.mli ├── test_simple_query.ml ├── test_simple_query.mli ├── test_smoke.ml ├── test_smoke.mli ├── test_ssl.ml ├── test_ssl.mli ├── utils.ml └── utils.mli /.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | *.install 3 | *.merlin 4 | _opam 5 | 6 | -------------------------------------------------------------------------------- /.ocamlformat: -------------------------------------------------------------------------------- 1 | 
profile=janestreet 2 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | ## Release v0.17.0 2 | 3 | - `Postgres_async` now returns appropriate `sql_state_code`s for connection-related error states, such as the connection being unexpectedly closed. 4 | 5 | - `Postgres_async` gained an implementation of the SimpleQuery message flow, which is currently exposed under `Postgres_async.Private`, as it is relatively new and not extensively tested. If you wish to try it, the main entry points are `simple_query` and `execute_simple`. 6 | 7 | - `Postgres_async.Private.pg_cancel` is a best-effort attempt to send an out-of-band message equivalent to invoking `pg_cancel_backend()` for the given connection. 8 | 9 | - `Postgres_async.close` now accepts an optional `try_cancel_statement_before_close` 10 | parameter to try and issue an out-of-band cancel request for the currently running 11 | statement (if any) before closing the connection. 12 | 13 | - For `Postgres_async.copy_in_rows` and `String_escaping.Copy_in`, the `column_names` 14 | parameter type changed from `string array` to `string list`. 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | This repository contains open source software that is developed and 2 | maintained by [Jane Street][js]. 3 | 4 | Contributions to this project are welcome and should be submitted via 5 | GitHub pull requests. 6 | 7 | Signing contributions 8 | --------------------- 9 | 10 | We require that you sign your contributions. Your signature certifies 11 | that you wrote the patch or otherwise have the right to pass it on as 12 | an open-source patch. 
The rules are pretty simple: if you can certify 13 | the below (from [developercertificate.org][dco]): 14 | 15 | ``` 16 | Developer Certificate of Origin 17 | Version 1.1 18 | 19 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 20 | 1 Letterman Drive 21 | Suite D4700 22 | San Francisco, CA, 94129 23 | 24 | Everyone is permitted to copy and distribute verbatim copies of this 25 | license document, but changing it is not allowed. 26 | 27 | 28 | Developer's Certificate of Origin 1.1 29 | 30 | By making a contribution to this project, I certify that: 31 | 32 | (a) The contribution was created in whole or in part by me and I 33 | have the right to submit it under the open source license 34 | indicated in the file; or 35 | 36 | (b) The contribution is based upon previous work that, to the best 37 | of my knowledge, is covered under an appropriate open source 38 | license and I have the right under that license to submit that 39 | work with modifications, whether created in whole or in part 40 | by me, under the same open source license (unless I am 41 | permitted to submit under a different license), as indicated 42 | in the file; or 43 | 44 | (c) The contribution was provided directly to me by some other 45 | person who certified (a), (b) or (c) and I have not modified 46 | it. 47 | 48 | (d) I understand and agree that this project and the contribution 49 | are public and that a record of the contribution (including all 50 | personal information I submit with it, including my sign-off) is 51 | maintained indefinitely and may be redistributed consistent with 52 | this project or the open source license(s) involved. 53 | ``` 54 | 55 | Then you just add a line to every git commit message: 56 | 57 | ``` 58 | Signed-off-by: Joe Smith <joe.smith@email.com> 59 | ``` 60 | 61 | Use your real name (sorry, no pseudonyms or anonymous contributions.) 62 | 63 | If you set your `user.name` and `user.email` git configs, you can sign 64 | your commit automatically with `git commit -s`. 
65 | 66 | [dco]: http://developercertificate.org/ 67 | [js]: https://opensource.janestreet.com/ 68 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019--2025 Jane Street Group, LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | INSTALL_ARGS := $(if $(PREFIX),--prefix $(PREFIX),) 2 | 3 | default: 4 | dune build 5 | 6 | install: 7 | dune install $(INSTALL_ARGS) 8 | 9 | uninstall: 10 | dune uninstall $(INSTALL_ARGS) 11 | 12 | reinstall: uninstall install 13 | 14 | clean: 15 | dune clean 16 | 17 | .PHONY: default install uninstall reinstall clean 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Postgres_async 2 | -------------- 3 | `postgres_async` is an OCaml PostgreSQL client that implements the PostgreSQL protocol 4 | rather than binding to the libpq C library. It provides support for regular queries 5 | (including support for 'parameters': `SELECT * WHERE a = $1`) and `COPY IN` mode. The 6 | interface presented is minimal to keep the library simple for now, though in the future 7 | a layer on top may add convenience functions. 8 | 9 | To get started, have a look at the Postgres_async module interface in 10 | `src/postgres_async.mli`. 
11 | -------------------------------------------------------------------------------- /dune: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/janestreet/postgres_async/55d83400bae826efb90def299675e558bf37420e/dune -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 3.17) 2 | -------------------------------------------------------------------------------- /postgres_async.opam: -------------------------------------------------------------------------------- 1 | opam-version: "2.0" 2 | maintainer: "Jane Street developers" 3 | authors: ["Jane Street Group, LLC"] 4 | homepage: "https://github.com/janestreet/postgres_async" 5 | bug-reports: "https://github.com/janestreet/postgres_async/issues" 6 | dev-repo: "git+https://github.com/janestreet/postgres_async.git" 7 | doc: "https://ocaml.janestreet.com/ocaml-core/latest/doc/postgres_async/index.html" 8 | license: "MIT" 9 | build: [ 10 | ["dune" "build" "-p" name "-j" jobs] 11 | ] 12 | depends: [ 13 | "ocaml" {>= "5.1.0"} 14 | "async" 15 | "async_kernel" 16 | "async_ssl" 17 | "async_unix" 18 | "core" 19 | "core_kernel" 20 | "ppx_jane" 21 | "dune" {>= "3.17.0"} 22 | ] 23 | available: arch != "arm32" & arch != "x86_32" 24 | synopsis: "OCaml/async implementation of the postgres protocol (i.e., does not use C-bindings to libpq)" 25 | description: " 26 | postgres_async is an OCaml PostgreSQL client that implements the PostgreSQL 27 | protocol rather than binding to the libpq C library. It provides support for 28 | regular queries (including support for 'parameters': \"SELECT * WHERE a = $1\") 29 | and COPY IN mode. The interface presented is minimal to keep the library simple 30 | for now, though in the future a layer on top may add convenience functions. 
31 | " 32 | -------------------------------------------------------------------------------- /protocol/backend.mli: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | open! Import 4 | (** The kinds of backend message that [constructor_of_char] and [focus_on_message] can report. *) 5 | type constructor = 6 | | AuthenticationRequest 7 | | BackendKeyData 8 | | BindComplete 9 | | CloseComplete 10 | | CommandComplete 11 | | CopyData 12 | | CopyDone 13 | | CopyInResponse 14 | | CopyOutResponse 15 | | CopyBothResponse 16 | | DataRow 17 | | EmptyQueryResponse 18 | | ErrorResponse 19 | | FunctionCallResponse 20 | | NoData 21 | | NoticeResponse 22 | | NotificationResponse 23 | | ParameterDescription 24 | | ParameterStatus 25 | | ParseComplete 26 | | PortalSuspended 27 | | ReadyForQuery 28 | | RowDescription 29 | [@@deriving compare, equal, sexp] 30 | (** The failure cases carried in the [Error] side of [constructor_of_char] and [focus_on_message]. *) 31 | type focus_on_message_error = 32 | | Unknown_message_type of char 33 | | Iobuf_too_short_for_header 34 | | Iobuf_too_short_for_message of { message_length : int } 35 | | Nonsense_message_length of int 36 | (** Maps a single char to a [constructor], or to a [focus_on_message_error]. NOTE(review): presumably [Unknown_message_type] for unrecognised chars; confirm in backend.ml. *) 37 | val constructor_of_char : char -> (constructor, focus_on_message_error) Result.t 38 | 39 | val focus_on_message 40 | : ([> read ], seek) Iobuf.t 41 | -> (constructor, focus_on_message_error) Result.t 42 | (** Field tags that are paired with string values in [ErrorResponse.t.all_fields] and [NoticeResponse.t.all_fields]. 
*) 43 | module Error_or_notice_field : sig 44 | type other = private char 45 | 46 | type t = 47 | | Severity 48 | | Severity_non_localised 49 | | Code 50 | | Message 51 | | Detail 52 | | Hint 53 | | Position 54 | | Internal_position 55 | | Internal_query 56 | | Where 57 | | Schema 58 | | Table 59 | | Column 60 | | Data_type 61 | | Constraint 62 | | File 63 | | Line 64 | | Routine 65 | | Other of other 66 | [@@deriving sexp_of, equal] 67 | end 68 | 69 | module ErrorResponse : sig 70 | (** In the protocol, the [Code] field is mandatory, so we also extract it to a separate 71 | non-optional record label. It will still appear in the [all_fields] list. 
*) 72 | type t = 73 | { error_code : string 74 | ; all_fields : (Error_or_notice_field.t * string) list 75 | } 76 | [@@deriving sexp_of] 77 | (** Parses a [t] from the iobuf, reporting failure through [Or_error.t]. *) 78 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 79 | end 80 | 81 | module NoticeResponse : sig 82 | (** In the protocol, the [Code] field is mandatory, so we also extract it to a separate 83 | non-optional record label. It will still appear in the [all_fields] list. *) 84 | type t = 85 | { error_code : string 86 | ; all_fields : (Error_or_notice_field.t * string) list 87 | } 88 | [@@deriving sexp_of] 89 | 90 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 91 | end 92 | 93 | module AuthenticationRequest : sig 94 | (** [Ok] is not actually a request. It means that auth has succeeded. *) 95 | type t = 96 | | Ok 97 | | KerberosV5 98 | | CleartextPassword 99 | | MD5Password of { salt : string } 100 | | SCMCredential 101 | | GSS 102 | | SSPI 103 | | GSSContinue of { data : string } 104 | [@@deriving sexp] 105 | 106 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 107 | end 108 | (* NOTE(review): [ParameterDescription.t] is presumably one type oid per statement parameter; confirm against backend.ml. *) 109 | module ParameterDescription : sig 110 | type t = int array 111 | 112 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 113 | end 114 | 115 | module ParameterStatus : sig 116 | type t = 117 | { key : string 118 | ; data : string 119 | } 120 | [@@deriving sexp_of] 121 | 122 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 123 | end 124 | 125 | module BackendKeyData : sig 126 | type t = Types.backend_key 127 | 128 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 129 | end 130 | 131 | module NotificationResponse : sig 132 | type t = 133 | { pid : Pid.t 134 | ; channel : Types.Notification_channel.t 135 | ; payload : string 136 | } 137 | [@@deriving sexp_of] 138 | 139 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 140 | end 141 | (** The three states a ReadyForQuery message can carry. NOTE(review): presumably the transaction status of the backend; confirm. *) 142 | module ReadyForQuery : sig 143 | type t = 144 | | Idle 145 | | In_transaction 146 | | In_failed_transaction 147 | [@@deriving sexp_of] 148 | 149 | val 
consume : ([> read ], seek) Iobuf.t -> t Or_error.t 150 | end 151 | (* The next five messages expose [consume : _ -> unit]: their signatures carry no parsed value and no error case. *) 152 | module ParseComplete : sig 153 | val consume : ([> read ], seek) Iobuf.t -> unit 154 | end 155 | 156 | module BindComplete : sig 157 | val consume : ([> read ], seek) Iobuf.t -> unit 158 | end 159 | 160 | module NoData : sig 161 | val consume : ([> read ], seek) Iobuf.t -> unit 162 | end 163 | 164 | module EmptyQueryResponse : sig 165 | val consume : ([> read ], seek) Iobuf.t -> unit 166 | end 167 | 168 | module CloseComplete : sig 169 | val consume : ([> read ], seek) Iobuf.t -> unit 170 | end 171 | 172 | module RowDescription : sig 173 | (** Technically [format] could be [`Binary], but since [Frontend.Bind] doesn't ever ask 174 | for binary output right now, it's impossible to receive it from the server, and 175 | [consume] will reject it for simplicity. *) 176 | type t = Column_metadata.t iarray 177 | 178 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 179 | end 180 | (* NOTE(review): [DataRow.t] is presumably one entry per column, with [None] standing for SQL NULL; confirm in backend.ml. *) 181 | module DataRow : sig 182 | type t = string option iarray 183 | 184 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 185 | val skip : ([> read ], seek) Iobuf.t -> unit (* NOTE(review): presumably advances past a DataRow without building [t]; confirm in backend.ml. *) 186 | end 187 | 188 | module type CopyResponse = sig 189 | (** Unlike in [RowDescription], it is possible to receive [`Binary] here because someone 190 | could put that option in their COPY query. [Postgres_async] will then abort the 191 | copy. 
*) 192 | type column = 193 | { name : string 194 | ; format : [ `Text | `Binary ] 195 | } 196 | (* NOTE(review): [column] is declared above but is not referenced by [t] in this signature ([t] stores bare formats in [column_formats]); confirm whether it is still needed. *) 197 | type t = 198 | { overall_format : [ `Text | `Binary ] 199 | ; num_columns : int 200 | ; column_formats : [ `Text | `Binary ] array 201 | } 202 | [@@deriving compare, sexp_of] 203 | 204 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 205 | end 206 | (** The three COPY response messages all share the [CopyResponse] signature. *) 207 | module CopyInResponse : CopyResponse 208 | module CopyOutResponse : CopyResponse 209 | module CopyBothResponse : CopyResponse 210 | 211 | module CommandComplete : sig 212 | type t = string 213 | 214 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 215 | end 216 | (** Serialisers: each function takes the destination [Writer.t] first and, for messages with a payload, the value to write; all return [unit]. *) 217 | module Writer : sig 218 | val auth_message : Writer.t -> AuthenticationRequest.t -> unit 219 | val ready_for_query : Writer.t -> ReadyForQuery.t -> unit 220 | val error_response : Writer.t -> ErrorResponse.t -> unit 221 | val backend_key : Writer.t -> Types.backend_key -> unit 222 | val parameter_description : Writer.t -> ParameterDescription.t -> unit 223 | val parameter_status : Writer.t -> ParameterStatus.t -> unit 224 | val command_complete : Writer.t -> CommandComplete.t -> unit 225 | val data_row : Writer.t -> DataRow.t -> unit 226 | val notice_response : Writer.t -> NoticeResponse.t -> unit 227 | val notification_response : Writer.t -> NotificationResponse.t -> unit 228 | val copy_data : Writer.t -> Shared.CopyData.t -> unit 229 | val copy_done : Writer.t -> unit 230 | val bind_complete : Writer.t -> unit 231 | val close_complete : Writer.t -> unit 232 | val empty_query_response : Writer.t -> unit 233 | val parse_complete : Writer.t -> unit 234 | val no_data : Writer.t -> unit 235 | end 236 | -------------------------------------------------------------------------------- /protocol/column_metadata.ml: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | (* Record describing one result column; [format] only admits [`Text]. *) 3 | type t = 4 | { name : string 5 | ; format : [ `Text ] 6 | ; pg_type_oid : int 7 | } 8 | [@@deriving fields ~getters ~iterators:create, sexp_of] 9 | (* [create] is the labelled constructor generated by [@@deriving fields ~iterators:create]. *) 10 | let create = Fields.create 11 | 12 | module type Public = Column_metadata_intf.Public with type t = t 13 | -------------------------------------------------------------------------------- /protocol/column_metadata.mli: -------------------------------------------------------------------------------- 1 | include Column_metadata_intf.Column_metadata (** @inline *) 2 | -------------------------------------------------------------------------------- /protocol/column_metadata_intf.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | 3 | module type Public = sig 4 | (** Contains information on the name and type of a column in query results *) 5 | type t 6 | (** The name of the column. *) 7 | val name : t -> string 8 | 9 | (** Oid of the type of data in the column. To get full type information for some 10 | [pg_type_oid t = K], [select * from pg_type where oid = K]. *) 11 | val pg_type_oid : t -> int 12 | end 13 | 14 | module type Column_metadata = sig 15 | type t [@@deriving sexp_of] 16 | (** Builds a [t] from its three labelled components. *) 17 | val create : name:string -> format:[ `Text ] -> pg_type_oid:int -> t 18 | 19 | include Public with type t := t 20 | 21 | module type Public = Public with type t = t 22 | end 23 | -------------------------------------------------------------------------------- /protocol/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name postgres_async_protocol) 3 | (public_name postgres_async.protocol) 4 | (libraries async core_kernel.bus core core_kernel.iobuf) 5 | (preprocess 6 | (pps ppx_jane))) 7 | -------------------------------------------------------------------------------- /protocol/frontend.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | open! 
Import 4 | (* Message with no data of its own: [t = unit], and the payload is a fixed 4 bytes (the two int16 values written by [fill]). *) 5 | module SSLRequest = struct 6 | let message_type_char = None 7 | (* [None] above: no message-type character is associated with this message. *) 8 | type t = unit 9 | 10 | let payload_length () = 4 11 | let validate_exn () = () 12 | 13 | (* These values look like dummy ones, but are the ones given in the postgres 14 | spec; the goal is to never collide with any protocol versions. *) 15 | let fill () iobuf = 16 | Iobuf.Fill.int16_be_trunc iobuf 1234; 17 | Iobuf.Fill.int16_be_trunc iobuf 5679 18 | ;; 19 | end 20 | (* Startup packet: a map from parameter names to parameter values (see [type t] further down in this module). *) 21 | module StartupMessage = struct 22 | let message_type_char = None 23 | 24 | module Parameter = struct 25 | module Name = struct 26 | include Shared.Null_terminated_string.Nonempty 27 | (* Raises through [Validate.maybe_raise] when [validate] rejects the name; the error is labelled with the string parameter. *) 28 | let validate_exn t = Validate.maybe_raise (Validate.name "parameter" (validate t)) 29 | let user = "user" 30 | let database = "database" 31 | let replication = "replication" 32 | let options = "options" 33 | 34 | (** "In addition to the above, other parameters may be listed. Parameter names 35 | beginning with _pq_. are reserved for use as protocol extensions, while others 36 | are treated as run-time parameters to be set at backend start time." *) 37 | let reserved_protocol_extension_prefix = "_pq_." 
38 | (* True for names that start with [reserved_protocol_extension_prefix]. *) 39 | let is_protocol_extension = 40 | String.is_prefix ~prefix:reserved_protocol_extension_prefix 41 | ;; 42 | end 43 | 44 | module Value = Shared.Null_terminated_string 45 | (* Converts between a list of arguments and a single string: [encode] escapes each argument with [String.Escaping.escape] (escapeworthy = every whitespace char, escape char = backslash) and joins them with single spaces; [decode] reverses this. *) 46 | module Options = struct 47 | let escape_char = '\\' 48 | let escapeworthy = lazy (List.filter [%all: Char.t] ~f:Char.is_whitespace) 49 | (* Raises [Invalid_argument] for an empty argument or one containing a NUL byte. *) 50 | let encode = 51 | let escape = 52 | lazy 53 | (unstage 54 | (String.Escaping.escape ~escapeworthy:(force escapeworthy) ~escape_char)) 55 | in 56 | fun options -> 57 | List.map options ~f:(fun s -> 58 | if String.is_empty s 59 | then invalid_arg "cannot encode empty arguments" 60 | else if String.mem s '\x00' 61 | then invalid_arg "cannot encode null characters" 62 | else force escape s) 63 | |> String.concat ~sep:" " 64 | ;; 65 | (* Splits on the escapeworthy (whitespace) chars via [String.Escaping.split_on_chars], drops empty pieces, and unescapes the rest. *) 66 | let decode = 67 | let unescape = lazy (unstage (String.Escaping.unescape ~escape_char)) in 68 | fun s -> 69 | String.Escaping.split_on_chars s ~on:(force escapeworthy) ~escape_char 70 | |> List.filter_map ~f:(function 71 | | "" -> None 72 | | s -> Some (force unescape s)) 73 | ;; 74 | 75 | let%test_unit "round-trip" = 76 | let generator = 77 | String.gen_nonempty' 78 | Quickcheck.Generator.( 79 | union 80 | [ of_list (force escapeworthy) 81 | ; return escape_char 82 | ; Char.gen_uniform_inclusive '\x01' Char.max_value 83 | ]) 84 | in 85 | Quickcheck.test (List.gen_non_empty generator) ~f:(fun original -> 86 | [%test_result: string list] (decode (encode original)) ~expect:original) 87 | ;; 88 | end 89 | end 90 | 91 | (** The protocol version number. The most significant 16 bits are the major version 92 | number (3 for the protocol described here). The least significant 16 bits are the 93 | minor version number (0 for the protocol described here). *) 94 | let this_protocol = 0x00030000 95 | 96 | (** postgres sets an arbitrary limit on startup packet length to prevent DoS. This limit 97 | has been unchanged from 2003-2024 so it seems pretty reasonable to hard code here. 
*) 98 | let max_startup_packet_length = 10_000 99 | 100 | type t = Parameter.Value.t Map.M(Parameter.Name).t 101 | [@@deriving compare, quickcheck, sexp_of] 102 | (* Overrides the derived generator so that every generated map contains the [Parameter.Name.user] key. *) 103 | let quickcheck_generator = 104 | let open Quickcheck.Generator.Let_syntax in 105 | let%map user = [%quickcheck.generator: Parameter.Value.t] 106 | and t = [%quickcheck.generator: t] in 107 | Map.set t ~key:Parameter.Name.user ~data:user 108 | ;; 109 | 110 | let find = Map.find 111 | 112 | (** "user" is the only required parameter *) 113 | let user t = 114 | let key = Parameter.Name.user in 115 | match find t key with 116 | | Some user -> user 117 | | None -> raise_s [%sexp (key : Parameter.Name.t), "is missing from startup message"] 118 | ;; 119 | (* The database parameter when present, otherwise the user name (raising if that is missing too). *) 120 | let database_defaulting_to_user t = 121 | match find t Parameter.Name.database with 122 | | Some database -> database 123 | | None -> user t 124 | ;; 125 | 126 | let options t = 127 | match Map.find t Parameter.Name.options with 128 | | None -> [] 129 | | Some options -> Parameter.Options.decode options 130 | ;; 131 | (* [t] minus the four well-known keys and minus every protocol-extension key. *) 132 | let runtime_parameters t = 133 | List.fold 134 | Parameter.Name.[ user; database; options; replication ] 135 | ~init:t 136 | ~f:Map.remove 137 | |> Map.filter_keys ~f:(Fn.non Parameter.Name.is_protocol_extension) 138 | ;; 139 | 140 | let protocol_extensions t = Map.filter_keys t ~f:Parameter.Name.is_protocol_extension 141 | 142 | let payload_length t = 143 | (* protocol version number: *) 144 | 4 145 | (* parameters: *) 146 | + Map.sumi 147 | (module Int) 148 | t 149 | ~f:(fun ~key ~data -> 150 | Parameter.Name.payload_length key + Parameter.Value.payload_length data) 151 | (* trailing null byte *) 152 | + 1 153 | ;; 154 | (* Raises if the user parameter is missing, if any name or value fails validation, or if the payload exceeds [max_startup_packet_length]. *) 155 | let validate_exn t = 156 | let (_ : string) = 157 | (* ensure the required parameter is present *) 158 | user t 159 | in 160 | Map.iteri t ~f:(fun ~key:field_name ~data -> 161 | Parameter.Name.validate_exn field_name; 162 | Validate.maybe_raise (Validate.name field_name (Parameter.Value.validate data))); 
163 | if payload_length t > max_startup_packet_length 164 | then ( (* report the key whose key+value length is largest, to help locate the offender *) 165 | let largest_field = 166 | Map.to_alist t 167 | |> List.max_elt 168 | ~compare: 169 | (Comparable.lift [%compare: int] ~f:(fun (key, data) -> 170 | String.length key + String.length data)) 171 | |> Option.value_exn 172 | |> fst 173 | in 174 | raise_s [%sexp "StartupMessage is too large", ~~(largest_field : string)]) 175 | ;; 176 | (* Validates [t] with [validate_exn] and returns it unchanged. *) 177 | let of_parameters_exn t = 178 | validate_exn t; 179 | t 180 | ;; 181 | (* Builds a validated startup message. Raises if [runtime_parameters] contains one of the well-known keys or a protocol-extension key, or if [protocol_extensions] contains a key without the extension prefix. *) 182 | let create_exn 183 | ~user 184 | ?database 185 | ?replication 186 | ?options 187 | ?runtime_parameters:(unvalidated_runtime_parameters = String.Map.empty) 188 | ?(protocol_extensions = String.Map.empty) 189 | () 190 | = 191 | let runtime_parameters = runtime_parameters unvalidated_runtime_parameters in 192 | let () = 193 | let invalid_runtime_parameters = 194 | Set.diff 195 | (Map.key_set unvalidated_runtime_parameters) 196 | (Map.key_set runtime_parameters) 197 | in 198 | if not (Set.is_empty invalid_runtime_parameters) 199 | then raise_s [%sexp ~~(invalid_runtime_parameters : String.Set.t)] 200 | in 201 | let () = 202 | let invalid_protocol_extensions = 203 | Map.key_set protocol_extensions 204 | |> Set.filter ~f:(Fn.non Parameter.Name.is_protocol_extension) 205 | in 206 | if not (Set.is_empty invalid_protocol_extensions) 207 | then raise_s [%sexp ~~(invalid_protocol_extensions : String.Set.t)] 208 | in 209 | Map.merge_disjoint_exn runtime_parameters protocol_extensions 210 | |> Map.add_exn ~key:Parameter.Name.user ~data:user 211 | |> (match database with 212 | | None -> Fn.id 213 | | Some data -> Map.add_exn ~key:Parameter.Name.database ~data) 214 | |> (match options with 215 | | None -> Fn.id 216 | | Some options -> 217 | Map.add_exn ~key:Parameter.Name.options ~data:(Parameter.Options.encode options)) 218 | |> (match replication with 219 | | None -> Fn.id 220 | | Some data -> Map.add_exn ~key:Parameter.Name.replication ~data) 221 | |> of_parameters_exn 222 | ;; 223 | 224 | let fill 
t iobuf = 225 | validate_exn t; 226 | Iobuf.Fill.int32_be_trunc iobuf this_protocol; 227 | Map.iteri t ~f:(fun ~key ~data -> 228 | Parameter.Name.fill iobuf key; 229 | Parameter.Value.fill iobuf data); 230 | Iobuf.Fill.char iobuf '\x00' 231 | ;; 232 | (* Rejects buffers longer than [max_startup_packet_length], checks the protocol version, then reads name/value pairs until an empty name is consumed; the accumulated map is validated before being returned. *) 233 | let consume_exn iobuf = 234 | let rec consume_parameters ~iobuf ~parameters = 235 | match Parameter.Name.consume_exn iobuf with 236 | | Error `Empty_string -> of_parameters_exn parameters 237 | | Ok key -> 238 | let data = Parameter.Value.consume_exn iobuf in 239 | consume_parameters ~iobuf ~parameters:(Map.set parameters ~key ~data) 240 | in 241 | if Iobuf.length iobuf > max_startup_packet_length 242 | then failwith "StartupMessage is too large"; 243 | let protocol = Iobuf.Consume.int32_be iobuf in 244 | [%test_result: int] protocol ~expect:this_protocol; 245 | consume_parameters ~iobuf ~parameters:String.Map.empty [@nontail] 246 | ;; 247 | 248 | let%test_unit "round-trip" = 249 | Quickcheck.test [%quickcheck.generator: t] ~f:(fun original -> 250 | let len = payload_length original in 251 | let iobuf = Iobuf.create ~len in 252 | fill original iobuf; 253 | assert (Iobuf.is_empty iobuf); 254 | Iobuf.flip_lo iobuf; 255 | let consumed = consume_exn iobuf in 256 | assert (Iobuf.is_empty iobuf); 257 | [%test_result: t] consumed ~expect:original) 258 | ;; 259 | (* Non-raising wrapper around [consume_exn]: any exception becomes a tagged [Error]. *) 260 | let consume iobuf = 261 | match consume_exn iobuf with 262 | | t -> Ok t 263 | | exception exn -> 264 | Or_error.of_exn exn |> Or_error.tag ~tag:"Failed to parse StartupMessage" 265 | ;; 266 | end 267 | (* Message with type char p, carrying either a null-terminated string or a raw blob. *) 268 | module PasswordMessage = struct 269 | let message_type_char = Some 'p' 270 | 271 | type t = 272 | | Cleartext_or_md5_hex of string 273 | | Gss_binary_blob of string 274 | (* Only the null-terminated variant is validated; the blob variant is accepted as-is. *) 275 | let validate_exn = function 276 | | Cleartext_or_md5_hex password -> 277 | Shared.validate_null_terminated_exn ~field_name:"password" password 278 | | Gss_binary_blob _ -> () 279 | ;; 280 | 281 | let payload_length = function 282 | | Cleartext_or_md5_hex password -> 
String.length password + 1 283 | | Gss_binary_blob blob -> String.length blob 284 | ;; 285 | 286 | let fill t iobuf = 287 | match t with 288 | | Cleartext_or_md5_hex password -> Shared.fill_null_terminated iobuf password 289 | | Gss_binary_blob blob -> Iobuf.Fill.stringo iobuf blob 290 | ;; 291 | (* Reads exactly [length] bytes from the iobuf and wraps them as a [Gss_binary_blob]. *) 292 | let consume_krb_exn iobuf ~length = 293 | let blob = Iobuf.Consume.string iobuf ~len:length ~str_pos:0 in 294 | Gss_binary_blob blob 295 | ;; 296 | 297 | let consume_krb iobuf ~length = 298 | match consume_krb_exn iobuf ~length with 299 | | exception exn -> 300 | error_s [%message "Failed to parse expected GSS PasswordMessage" (exn : Exn.t)] 301 | | t -> Ok t 302 | ;; 303 | (* Reads a string with [Shared.consume_cstring_exn] and wraps it as [Cleartext_or_md5_hex]; exceptions become [Error]. *) 304 | let consume_password iobuf = 305 | match Shared.consume_cstring_exn iobuf with 306 | | exception exn -> 307 | error_s [%message "Failed to parse expected PasswordMessage" (exn : Exn.t)] 308 | | str -> Ok (Cleartext_or_md5_hex str) 309 | ;; 310 | end 311 | 312 | module Parse = struct 313 | let message_type_char = Some 'P' 314 | 315 | type t = 316 | { destination : Types.Statement_name.t 317 | ; query : string 318 | } 319 | 320 | let validate_exn t = Shared.validate_null_terminated_exn t.query ~field_name:"query" 321 | (* Two strings, each followed by one terminator byte, plus the 2-byte parameter-type count written by [fill]. *) 322 | let payload_length t = 323 | String.length (Types.Statement_name.to_string t.destination) 324 | + 1 325 | + String.length t.query 326 | + 1 327 | + 2 328 | ;; 329 | 330 | let fill t iobuf = 331 | Shared.fill_null_terminated iobuf (Types.Statement_name.to_string t.destination); 332 | Shared.fill_null_terminated iobuf t.query; 333 | (* zero parameter types: *) 334 | Iobuf.Fill.int16_be_trunc iobuf 0 335 | ;; 336 | end 337 | 338 | module Bind = struct 339 | let message_type_char = Some 'B' 340 | 341 | type t = 342 | { destination : Types.Portal_name.t 343 | ; statement : Types.Statement_name.t 344 | ; parameters : string option array 345 | } 346 | (* No validation is performed on [Bind] messages here. *) 347 | let validate_exn (_ : t) = () 348 | 349 | let payload_length t = 350 | let parameter_length = function 351 | | None -> 0 352 
| | Some s -> String.length s 353 | in 354 | (* destination and terminator *) 355 | String.length (Types.Portal_name.to_string t.destination) 356 | + 1 357 | (* statement and terminator *) 358 | + String.length (Types.Statement_name.to_string t.statement) 359 | + 1 360 | (* # parameter format codes = 1: *) 361 | + 2 362 | (* single parameter format code: *) 363 | + 2 364 | (* # parameters: *) 365 | + 2 366 | (* parameter sizes: *) 367 | + (4 * Array.length t.parameters) (* parameters *) 368 | + Array.sum (module Int) t.parameters ~f:parameter_length 369 | (* # result format codes = 1: *) 370 | + 2 371 | (* single result format code: *) 372 | + 2 373 | ;; 374 | (* Writes the fields in the same order in which [payload_length] counts them. *) 375 | let fill t iobuf = 376 | Shared.fill_null_terminated iobuf (Types.Portal_name.to_string t.destination); 377 | Shared.fill_null_terminated iobuf (Types.Statement_name.to_string t.statement); 378 | (* 1 parameter format code: *) 379 | Iobuf.Fill.int16_be_trunc iobuf 1; 380 | (* all parameters are text: *) 381 | Iobuf.Fill.int16_be_trunc iobuf 0; 382 | let num_parameters = Array.length t.parameters in 383 | Shared.fill_uint16_be iobuf num_parameters; 384 | for idx = 0 to num_parameters - 1 do 385 | match t.parameters.(idx) with 386 | | None -> Shared.fill_int32_be iobuf (-1) (* a [None] parameter is written as length -1 with no bytes following *) 387 | | Some str -> 388 | Shared.fill_int32_be iobuf (String.length str); 389 | Iobuf.Fill.stringo iobuf str 390 | done; 391 | (* 1 result format code: *) 392 | Iobuf.Fill.int16_be_trunc iobuf 1; 393 | (* all results are text: *) 394 | Iobuf.Fill.int16_be_trunc iobuf 0 395 | ;; 396 | end 397 | 398 | module Execute = struct 399 | let message_type_char = Some 'E' 400 | 401 | type num_rows = 402 | | Unlimited 403 | | Limit of int 404 | 405 | type t = 406 | { portal : Types.Portal_name.t 407 | ; limit : num_rows 408 | } 409 | (* A [Limit] must be strictly positive; [Unlimited] is always accepted. *) 410 | let validate_exn t = 411 | match t.limit with 412 | | Unlimited -> () 413 | | Limit n -> if n <= 0 then failwith "When provided, num rows limit must be positive" 414 | ;; 415 | 416 | let payload_length t = 
+String.length (Types.Portal_name.to_string t.portal) + 1 + 4 417 | 418 | let fill t iobuf = 419 | Shared.fill_null_terminated iobuf (Types.Portal_name.to_string t.portal); 420 | let limit = 421 | match t.limit with 422 | | Unlimited -> 0 423 | | Limit n -> n 424 | in 425 | Shared.fill_int32_be iobuf limit 426 | ;; 427 | end 428 | 429 | module Statement_or_portal_action = struct 430 | type t = 431 | | Statement of Types.Statement_name.t 432 | | Portal of Types.Portal_name.t 433 | 434 | let validate_exn (_ : t) = () 435 | 436 | let payload_length t = 437 | let str = 438 | match t with 439 | | Statement s -> Types.Statement_name.to_string s 440 | | Portal s -> Types.Portal_name.to_string s 441 | in 442 | 1 + String.length str + 1 443 | ;; 444 | 445 | let fill t iobuf = 446 | match t with 447 | | Statement s -> 448 | Iobuf.Fill.char iobuf 'S'; 449 | Shared.fill_null_terminated iobuf (Types.Statement_name.to_string s) 450 | | Portal p -> 451 | Iobuf.Fill.char iobuf 'P'; 452 | Shared.fill_null_terminated iobuf (Types.Portal_name.to_string p) 453 | ;; 454 | end 455 | 456 | module Describe = struct 457 | let message_type_char = Some 'D' 458 | 459 | include Statement_or_portal_action 460 | end 461 | 462 | module Close = struct 463 | let message_type_char = Some 'C' 464 | 465 | include Statement_or_portal_action 466 | end 467 | 468 | module CopyFail = struct 469 | let message_type_char = Some 'f' 470 | 471 | type t = { reason : string } 472 | 473 | let validate_exn t = Shared.validate_null_terminated_exn t.reason ~field_name:"reason" 474 | let payload_length t = String.length t.reason + 1 475 | let fill t iobuf = Shared.fill_null_terminated iobuf t.reason 476 | end 477 | 478 | module Query = struct 479 | type t = string 480 | 481 | let message_type_char = Some 'Q' 482 | let consume_exn = Shared.consume_cstring_exn 483 | 484 | let consume iobuf = 485 | match consume_exn iobuf with 486 | | exception exn -> error_s [%message "Failed to parse Query" (exn : Exn.t)] 487 | | t -> 
Ok t 488 | ;; 489 | 490 | let validate_exn (_ : t) = () 491 | let payload_length t = 1 + String.length t 492 | let fill t iobuf = Shared.fill_null_terminated iobuf t 493 | end 494 | (* CancelRequest has no message type byte ([message_type_char = None]); its payload is the cancel request code followed by the pid and secret, 4 bytes each. *) 495 | module CancelRequest = struct 496 | let message_type_char = None 497 | 498 | type t = 499 | { pid : int 500 | ; secret : int 501 | } 502 | 503 | let validate_exn (_ : t) = () 504 | 505 | let payload_length (_ : t) = 506 | (* Cancel request code: 80877102 = (1234 lsl 16) lor 5678 *) 507 | 4 508 | + (* Pid *) 509 | 4 510 | + (* Secret *) 511 | 4 512 | ;; 513 | 514 | let fill t iobuf = 515 | Iobuf.Fill.int32_be_trunc iobuf 80877102; 516 | Iobuf.Fill.int32_be_trunc iobuf t.pid; 517 | Iobuf.Fill.int32_be_trunc iobuf t.secret 518 | ;; 519 | 520 | let consume_exn iobuf = 521 | let pid = Iobuf.Consume.int32_be iobuf in 522 | let secret = Iobuf.Consume.int32_be iobuf in 523 | { pid; secret } 524 | ;; 525 | 526 | let consume iobuf = 527 | match consume_exn iobuf with 528 | | exception exn -> error_s [%message "Failed to parse CancelRequest" (exn : Exn.t)] 529 | | t -> Ok t 530 | ;; 531 | end 532 | 533 | module No_arg : sig 534 | val flush : string 535 | val sync : string 536 | val copy_done : string 537 | val terminate : string 538 | end = struct 539 | include Shared.No_arg 540 | 541 | let flush = gen ~constructor:'H' 542 | let sync = gen ~constructor:'S' 543 | let copy_done = gen ~constructor:'c' 544 | let terminate = gen ~constructor:'X' 545 | end 546 | 547 | include No_arg 548 | 549 | module Writer = struct 550 | open Async 551 | 552 | let write_message = Shared.write_message 553 | let ssl_request = Staged.unstage (write_message (module SSLRequest)) 554 | let startup_message = Staged.unstage (write_message (module StartupMessage)) 555 | let password_message = Staged.unstage (write_message (module PasswordMessage)) 556 | let parse = Staged.unstage (write_message (module Parse)) 557 | let bind = Staged.unstage (write_message (module Bind)) 558 | let close = Staged.unstage (write_message (module Close)) 559 | let
query = Staged.unstage (write_message (module Query)) 560 | let describe = Staged.unstage (write_message (module Describe)) 561 | let execute = Staged.unstage (write_message (module Execute)) 562 | let copy_fail = Staged.unstage (write_message (module CopyFail)) 563 | let copy_data = Staged.unstage (write_message (module Shared.CopyData)) 564 | let cancel_request = Staged.unstage (write_message (module CancelRequest)) 565 | let flush writer = Writer.write writer No_arg.flush 566 | let sync writer = Writer.write writer No_arg.sync 567 | let copy_done writer = Writer.write writer No_arg.copy_done 568 | let terminate writer = Writer.write writer No_arg.terminate 569 | end 570 | -------------------------------------------------------------------------------- /protocol/frontend.mli: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | open! Import 4 | 5 | module StartupMessage : sig 6 | module Parameter : sig 7 | module Name : sig 8 | (** The currently recognized parameter names. The following descriptions are copied 9 | from 10 | {{:https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE} 11 | the PostgreSQL docs}: *) 12 | 13 | (** The database user name to connect as. Required; there is no default. *) 14 | val user : string 15 | 16 | (** The database to connect to. Defaults to the user name. *) 17 | val database : string 18 | 19 | (** Command-line arguments for the backend. (This is deprecated in favor of setting 20 | individual run-time parameters.) See the [Options] module for an implementation 21 | of this. *) 22 | val options : string 23 | 24 | (** Used to connect in streaming replication mode, where a small set of replication 25 | commands can be issued instead of SQL statements. Value can be true, false, or 26 | database, and the default is false. 
See 27 | {{:https://www.postgresql.org/docs/13/protocol-replication.html} Section 52.4} 28 | for details. *) 29 | val replication : string 30 | end 31 | 32 | module Options : sig 33 | (** Command-line arguments for the backend are delimited by spaces in the value sent 34 | to the server. These functions convert between a list of arguments and the 35 | equivalent parameter value to send to the server *) 36 | 37 | val encode : string list -> string 38 | val decode : string -> string list 39 | end 40 | end 41 | 42 | (** Must contain a "user" parameter. Keys may not be empty and no strings may contain a 43 | null byte. *) 44 | type t = private string String.Map.t [@@deriving compare, sexp_of] 45 | 46 | val find : t -> string -> string option 47 | 48 | (** user is required, all other parameters are optional - they can be accessed via 49 | [find] *) 50 | val user : t -> string 51 | 52 | (** This replicates postgres' behavior. The database is optional but if the field is 53 | missing then postgres will default to user. *) 54 | val database_defaulting_to_user : t -> string 55 | 56 | (** [options] defaults to the empty list *) 57 | val options : t -> string list 58 | 59 | (** These settings will be applied during backend start (after parsing the command-line 60 | arguments if any) and will act as session defaults. 
*) 61 | val runtime_parameters : t -> string String.Map.t 62 | 63 | val protocol_extensions : t -> string String.Map.t 64 | 65 | val create_exn 66 | : user:string 67 | -> ?database:string 68 | -> ?replication:string 69 | -> ?options:string list 70 | -> ?runtime_parameters:string String.Map.t 71 | -> ?protocol_extensions:string String.Map.t 72 | -> unit 73 | -> t 74 | 75 | val of_parameters_exn : string String.Map.t -> t 76 | val consume : (t Or_error.t, [> read ], seek) Iobuf.Consume.t 77 | end 78 | 79 | module PasswordMessage : sig 80 | type t = 81 | | Cleartext_or_md5_hex of string 82 | | Gss_binary_blob of string 83 | 84 | val consume_krb : ([> read ], seek) Iobuf.t -> length:int -> t Or_error.t 85 | val consume_password : ([> read ], seek) Iobuf.t -> t Or_error.t 86 | end 87 | 88 | module Parse : sig 89 | type t = 90 | { destination : Types.Statement_name.t 91 | ; query : string 92 | } 93 | end 94 | 95 | module Bind : sig 96 | type t = 97 | { destination : Types.Portal_name.t 98 | ; statement : Types.Statement_name.t 99 | ; parameters : string option array 100 | } 101 | end 102 | 103 | module Execute : sig 104 | type num_rows = 105 | | Unlimited 106 | | Limit of int 107 | 108 | type t = 109 | { portal : Types.Portal_name.t 110 | ; limit : num_rows 111 | } 112 | end 113 | 114 | module Describe : sig 115 | type t = 116 | | Statement of Types.Statement_name.t 117 | | Portal of Types.Portal_name.t 118 | end 119 | 120 | module Close : sig 121 | type t = 122 | | Statement of Types.Statement_name.t 123 | | Portal of Types.Portal_name.t 124 | end 125 | 126 | module CopyFail : sig 127 | type t = { reason : string } 128 | end 129 | 130 | module Query : sig 131 | type t = string 132 | 133 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 134 | end 135 | 136 | module CancelRequest : sig 137 | type t = 138 | { pid : int 139 | ; secret : int 140 | } 141 | 142 | val consume : ([> read ], seek) Iobuf.t -> t Or_error.t 143 | end 144 | 145 | module Writer : sig 146 | 
open Async 147 | 148 | val ssl_request : Writer.t -> unit -> unit 149 | val startup_message : Writer.t -> StartupMessage.t -> unit 150 | val password_message : Writer.t -> PasswordMessage.t -> unit 151 | val parse : Writer.t -> Parse.t -> unit 152 | val bind : Writer.t -> Bind.t -> unit 153 | val close : Writer.t -> Close.t -> unit 154 | val query : Writer.t -> Query.t -> unit 155 | val describe : Writer.t -> Describe.t -> unit 156 | val execute : Writer.t -> Execute.t -> unit 157 | val copy_fail : Writer.t -> CopyFail.t -> unit 158 | val copy_data : Writer.t -> Shared.CopyData.t -> unit 159 | val cancel_request : Writer.t -> CancelRequest.t -> unit 160 | val flush : Writer.t -> unit 161 | val sync : Writer.t -> unit 162 | val copy_done : Writer.t -> unit 163 | val terminate : Writer.t -> unit 164 | end 165 | -------------------------------------------------------------------------------- /protocol/import.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | 4 | type seek = Iobuf.seek 5 | -------------------------------------------------------------------------------- /protocol/postgres_async_protocol.ml: -------------------------------------------------------------------------------- 1 | (** See https://www.postgresql.org/docs/13/protocol-message-formats.html for details *) 2 | 3 | include Types 4 | module Types = Types 5 | module Shared = Shared 6 | module Backend = Backend 7 | module Column_metadata = Column_metadata 8 | module Frontend = Frontend 9 | -------------------------------------------------------------------------------- /protocol/shared.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | open! Import 4 | 5 | (* When [Word_size.word_size = W32], [int] can take at most 31 bits, so the max is 2**30. 6 | 7 | [Iobuf.Consume.int32_be] silently truncates. 
This is a shame (ideally it would return 8 | a [Int63.t]; see the comment at the top of [Int32.t], but otherwise an [Int32.t] would 9 | certainly suffice) and makes it quite annoying to safely implement something that reads 10 | 32 bit ints on a 32 bit ocaml platform. 11 | 12 | Looking below, the 32 bit ints are 13 | 14 | - [message_length]: we'd refuse to read messages larger than 2**30 anyway, 15 | - [num_fields] (in a row): guaranteed to be < 1600 by postgres 16 | (https://www.postgresql.org/docs/11/ddl-basics.html), 17 | - [pid]: on linux, less than 2**22 (man 5 proc), 18 | - [secret] (from [BackendKeyData]): no guarantees. 19 | 20 | so, for now it seems safe enough to stumble on in 32-bit-mode even though iobuf 21 | would silently truncate the ints. This is unsatisfying (besides not supporting reading 22 | [secret]) because if we've made a mistake, or have a bug, we'd rather crash on the 23 | protocol error than truncate. 24 | 25 | We'll revisit it if someone wants it. *) 26 | 27 | module Non_zero_char = struct 28 | type t = char [@@deriving compare, equal, hash, quickcheck, sexp] 29 | 30 | let quickcheck_generator = Char.gen_uniform_inclusive '\x01' Char.max_value 31 | end 32 | 33 | module Null_terminated_string = struct 34 | module T = struct 35 | type t = string [@@deriving compare, equal, hash, quickcheck, sexp] 36 | type comparator_witness = String.comparator_witness 37 | 38 | let comparator = String.comparator 39 | let quickcheck_generator = String.gen' [%quickcheck.generator: Non_zero_char.t] 40 | 41 | let validate t = 42 | match String.mem t '\x00' with 43 | | false -> Validate.pass 44 | | true -> Validate.fail "string may not contain nulls" 45 | ;; 46 | 47 | let payload_length t = String.length t + 1 48 | 49 | let fill iobuf t = 50 | Iobuf.Fill.stringo iobuf t; 51 | Iobuf.Fill.char iobuf '\x00' 52 | ;; 53 | 54 | let consume_exn iobuf = 55 | match Iobuf.Peek.index iobuf '\x00' with 56 | | None -> 57 | raise 58 | (Not_found_s 59 | [%sexp 60 |
"Null_terminated_string.consume_exn: could not find terminating null byte"]) 61 | | Some len -> 62 | let t = Iobuf.Consume.stringo iobuf ~len in 63 | [%test_result: char] (Iobuf.Consume.char iobuf) ~expect:'\x00'; 64 | t 65 | ;; 66 | end 67 | 68 | include T 69 | 70 | module Nonempty = struct 71 | include T 72 | 73 | let quickcheck_generator = 74 | String.gen_nonempty' [%quickcheck.generator: Non_zero_char.t] 75 | ;; 76 | 77 | let validate t = 78 | if String.is_empty t then Validate.fail "string may not be empty" else validate t 79 | ;; 80 | 81 | let consume_exn iobuf = 82 | match consume_exn iobuf with 83 | | "" -> Error `Empty_string 84 | | s -> Ok s 85 | ;; 86 | end 87 | end 88 | 89 | let validate_null_terminated_exn ~field_name str = 90 | if String.mem str '\x00' 91 | then raise_s [%message "String may not contain nulls" field_name str] 92 | ;; 93 | 94 | let fill_null_terminated iobuf str = 95 | Iobuf.Fill.stringo iobuf str; 96 | Iobuf.Fill.char iobuf '\x00' 97 | ;; 98 | 99 | (* Type int16 could be used as both unsigned and signed, depending on the context (See 100 | pqPutInt/pqGetInt in the interfaces/libpq/fe-misc.c). 101 | 102 | For example, data type size in RowDescription could be negative, but parameter count 103 | in Bind message could only be positive, and so its maximum value is 65535.
Our 104 | implementation currently only needs unsigned int16 values *) 105 | let uint16_min = 0 106 | let uint16_max = 65535 107 | 108 | let int32_min = 109 | match Word_size.word_size with 110 | | W64 -> Int32.to_int_exn Int32.min_value 111 | | W32 -> Int.min_value 112 | ;; 113 | 114 | let int32_max = 115 | match Word_size.word_size with 116 | | W64 -> Int32.to_int_exn Int32.max_value 117 | | W32 -> Int.max_value 118 | ;; 119 | 120 | let () = 121 | match Word_size.word_size with 122 | | W64 -> 123 | assert (String.equal (Int.to_string int32_min) "-2147483648"); 124 | assert (String.equal (Int.to_string int32_max) "2147483647") 125 | | W32 -> 126 | assert (String.equal (Int.to_string int32_min) "-1073741824"); 127 | assert (String.equal (Int.to_string int32_max) "1073741823") 128 | ;; 129 | 130 | let[@inline always] fill_uint16_be iobuf value = 131 | match uint16_min <= value && value <= uint16_max with 132 | | true -> Iobuf.Fill.uint16_be_trunc iobuf value 133 | | false -> failwithf "uint16 out of range: %i" value () 134 | ;; 135 | 136 | let[@inline always] fill_int32_be iobuf value = 137 | match int32_min <= value && value <= int32_max with 138 | | true -> Iobuf.Fill.int32_be_trunc iobuf value 139 | | false -> failwithf "int32 out of range: %i" value () 140 | ;; 141 | (* [find_null_exn iobuf] returns the position of the first null byte in the window of [iobuf]. The bounds check runs before [Iobuf.Peek.char] so that a window with no null byte (including an empty one) raises the [Failure] below instead of an out-of-bounds peek error. *) 142 | let find_null_exn iobuf = 143 | let rec loop ~iobuf ~length ~pos = 144 | if pos >= length 145 | then failwith "find_null_exn could not find \\x00" 146 | else if Char.( = ) (Iobuf.Peek.char iobuf ~pos) '\x00' 147 | then pos 148 | else loop ~iobuf ~length ~pos:(pos + 1) 149 | in 150 | loop ~iobuf ~length:(Iobuf.length iobuf) ~pos:0 151 | ;; 152 | 153 | let consume_cstring_exn iobuf = 154 | let len = find_null_exn iobuf in 155 | let res = Iobuf.Consume.string iobuf ~len ~str_pos:0 in 156 | let zero = Iobuf.Consume.char iobuf in 157 | assert (Char.( = ) zero '\x00'); 158 | res 159 | ;; 160 | 161 | module type Message_type = sig 162 | val message_type_char : char option 163 | 164 | type t 165 |
166 | val validate_exn : t -> unit 167 | val payload_length : t -> int 168 | val fill : t -> (read_write, Iobuf.seek) Iobuf.t -> unit 169 | end 170 | 171 | type 'a with_computed_length = 172 | { payload_length : int 173 | ; value : 'a 174 | } 175 | 176 | let write_message (type a) (module M : Message_type with type t = a) = 177 | let full_length { payload_length; _ } = 178 | match M.message_type_char with 179 | | None -> payload_length + 4 180 | | Some _ -> payload_length + 5 181 | in 182 | let blit_to_bigstring with_computed_length bigstring ~pos = 183 | let iobuf = 184 | Iobuf.of_bigstring bigstring ~pos ~len:(full_length with_computed_length) 185 | in 186 | (match M.message_type_char with 187 | | None -> () 188 | | Some c -> Iobuf.Fill.char iobuf c); 189 | let { payload_length; value } = with_computed_length in 190 | fill_int32_be iobuf (payload_length + 4); 191 | M.fill value iobuf; 192 | match Iobuf.is_empty iobuf with 193 | | true -> () 194 | | false -> failwith "postgres message filler lied about length" 195 | in 196 | Staged.stage (fun writer value -> 197 | M.validate_exn value; 198 | let payload_length = M.payload_length value in 199 | Async.Writer.write_gen_whole 200 | writer 201 | { payload_length; value } 202 | ~length:full_length 203 | ~blit_to_bigstring) 204 | ;; 205 | 206 | module No_arg = struct 207 | let gen ~constructor = 208 | let tmp = Iobuf.create ~len:5 in 209 | Iobuf.Poke.char tmp ~pos:0 constructor; 210 | (* fine to use Iobuf's int32 function, as [4] is clearly in range. *) 211 | Iobuf.Poke.int32_be_trunc tmp ~pos:1 4; 212 | Iobuf.to_string tmp 213 | ;; 214 | end 215 | 216 | (* Both the backend and frontend use the same format for [CopyData] and 217 | [CopyDone] messages, hence they are placed in [Shared]. *) 218 | module CopyData = struct 219 | let message_type_char = Some 'd' 220 | 221 | type t = string 222 | 223 | (* After [focus_on_message] seeks over the type and length, 'CopyData' 224 | messages are simply just the payload bytes. 
*) 225 | let skip iobuf = Iobuf.advance iobuf (Iobuf.length iobuf) 226 | let validate_exn (_ : t) = () 227 | let payload_length t = String.length t 228 | let fill t iobuf = Iobuf.Fill.stringo iobuf t 229 | end 230 | 231 | module CopyDone = struct 232 | let consume (_ : _ Iobuf.t) = () 233 | end 234 | -------------------------------------------------------------------------------- /protocol/shared.mli: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | open! Import 4 | 5 | module Non_zero_char : sig 6 | type t = char [@@deriving compare, equal, hash, quickcheck, sexp] 7 | end 8 | 9 | module Null_terminated_string : sig 10 | type t = string [@@deriving compare, equal, hash, quickcheck, sexp] 11 | type comparator_witness = String.comparator_witness 12 | 13 | include Comparator.S with type t := t and type comparator_witness := comparator_witness 14 | 15 | val validate : t Validate.check 16 | val payload_length : t -> int 17 | val fill : (t, read_write, seek) Iobuf.Fill.t 18 | val consume_exn : (t, [> read ], seek) Iobuf.Consume.t 19 | 20 | module Nonempty : sig 21 | type nonrec t = t [@@deriving compare, equal, hash, quickcheck, sexp] 22 | 23 | include Comparator.S with type t := t and type comparator_witness = comparator_witness 24 | 25 | val validate : t -> Validate.t 26 | val payload_length : t -> int 27 | val fill : (t, read_write, seek) Iobuf.Fill.t 28 | val consume_exn : ((t, [> `Empty_string ]) result, [> read ], seek) Iobuf.Consume.t 29 | end 30 | end 31 | 32 | val validate_null_terminated_exn : field_name:string -> string -> unit 33 | val fill_null_terminated : (read_write, seek) Iobuf.t -> string -> unit 34 | val uint16_min : int 35 | val uint16_max : int 36 | val int32_min : int 37 | val int32_max : int 38 | val fill_uint16_be : (read_write, seek) Iobuf.t -> int -> unit 39 | val fill_int32_be : (read_write, seek) Iobuf.t -> int -> unit 40 | val find_null_exn : ([> read ], 'a) Iobuf.t -> int 41 
| val consume_cstring_exn : ([> read ], seek) Iobuf.t -> string 42 | 43 | module type Message_type = sig 44 | val message_type_char : char option 45 | 46 | type t 47 | 48 | val validate_exn : t -> unit 49 | val payload_length : t -> int 50 | val fill : t -> (read_write, seek) Iobuf.t -> unit 51 | end 52 | 53 | val write_message 54 | : (module Message_type with type t = 'a) 55 | -> (Writer.t -> 'a -> unit) Staged.t 56 | 57 | module No_arg : sig 58 | val gen : constructor:char -> string 59 | end 60 | 61 | module CopyData : sig 62 | include Message_type with type t = string 63 | 64 | val message_type_char : char option 65 | val skip : ([> read ], seek) Iobuf.t -> unit 66 | end 67 | 68 | module CopyDone : sig 69 | val consume : ([> read ], seek) Iobuf.t -> unit 70 | end 71 | -------------------------------------------------------------------------------- /protocol/types.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open! Int.Replace_polymorphic_compare 3 | 4 | type backend_key = 5 | { pid : int 6 | ; secret : int 7 | } 8 | [@@deriving sexp_of] 9 | 10 | module type Named_or_unnamed = sig 11 | type t 12 | 13 | val unnamed : t 14 | val create_named_exn : string -> t 15 | val to_string : t -> string 16 | end 17 | 18 | module Named_or_unnamed = struct 19 | type t = string 20 | 21 | let unnamed = "" 22 | 23 | let create_named_exn s = 24 | if String.is_empty s 25 | then failwith "Named_or_unnamed.create_named_exn got an empty string"; 26 | if String.mem s '\x00' 27 | then failwith "Named_or_unnamed.create_named_exn: string must not contain \\x00"; 28 | s 29 | ;; 30 | 31 | let to_string (t : t) : string = t 32 | end 33 | 34 | module Statement_name = Named_or_unnamed 35 | module Portal_name = Named_or_unnamed 36 | module Notification_channel = String 37 | -------------------------------------------------------------------------------- /protocol/types.mli: 
-------------------------------------------------------------------------------- 1 | open Core 2 | 3 | type backend_key = 4 | { pid : int 5 | (** The type of [pid] is intentionally [int] instead of [Core.Pid.t]. 6 | 7 | When the client is connected directly to a database cluster this field is expected 8 | to be a Linux PID, in which case the assumptions in internal validation of 9 | [Core.Pid] are valid: in particular, that a PID must be a positive number. 10 | 11 | In the case that a client is connected to a 12 | {{:https://www.pgbouncer.org/} pgbouncer} proxy the PID supplied (and tracked) by 13 | the proxy is random. This is done to allow cancel requests per proxied client. 14 | 15 | The postgres protocol spec states that this field is "Int32: The process ID of this 16 | backend", so it's a little unclear as to whether or not pgbouncer's behaviour is a 17 | violation of the spec. In any case, it is what it is, and this field is an int 18 | rather than a [Pid.t] in order to support connecting via pgbouncer. 19 | 20 | This may also be true of other proxies or similar software. *) 21 | ; secret : int 22 | } 23 | [@@deriving sexp_of] 24 | 25 | module type Named_or_unnamed = sig 26 | type t 27 | 28 | val unnamed : t 29 | 30 | (** The provided string must be nonempty and not contain nulls. *) 31 | val create_named_exn : string -> t 32 | 33 | (** [to_string unnamed = ""] *) 34 | val to_string : t -> string 35 | end 36 | 37 | module Statement_name : Named_or_unnamed 38 | module Portal_name : Named_or_unnamed 39 | 40 | module Notification_channel : sig 41 | type t [@@deriving sexp_of] 42 | 43 | include Stringable with type t := t 44 | include Hashable with type t := t 45 | end 46 | -------------------------------------------------------------------------------- /src/command_complete.ml: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | 3 | type t = 4 | { tag : string 5 | ; rows : int option 6 | } 7 | [@@deriving fields ~getters ~iterators:create, sexp_of] 8 | 9 | let create = Fields.create 10 | let empty = { tag = ""; rows = None } 11 | 12 | module type Public = Command_complete_intf.Public with type t = t 13 | -------------------------------------------------------------------------------- /src/command_complete.mli: -------------------------------------------------------------------------------- 1 | include Command_complete_intf.Command_complete (** @inline *) 2 | -------------------------------------------------------------------------------- /src/command_complete_intf.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | 3 | module type Public = sig 4 | type t [@@deriving sexp_of] 5 | 6 | (** [tag] is a short description of the command that was completed, like "SELECT", 7 | "BEGIN", "INSERT" or "CREATE TABLE", see description of the CommandComplete message 8 | in https://www.postgresql.org/docs/current/protocol-message-formats.html *) 9 | val tag : t -> string 10 | 11 | (** [rows] returns the affected rows count for the completed command, if available and 12 | was provided in the CommandComplete message *) 13 | val rows : t -> int option 14 | end 15 | 16 | module type Command_complete = sig 17 | include Public 18 | 19 | val create : tag:string -> rows:int option -> t 20 | 21 | (** The value we return for the empty query *) 22 | val empty : t 23 | 24 | module type Public = Public with type t = t 25 | end 26 | -------------------------------------------------------------------------------- /src/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name postgres_async) 3 | (public_name postgres_async) 4 | (libraries async async_ssl async_unix core_kernel.bus core 5 | async_kernel.eager_deferred core_kernel.iobuf postgres_async_protocol) 6 | (preprocess 7 | (pps ppx_jane))) 8 | 
-------------------------------------------------------------------------------- /src/message_reading_intf.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open Async 3 | 4 | module type S = sig 5 | type t 6 | 7 | type 'a handle_message_result = 8 | | Stop of 'a 9 | | Continue 10 | | Protocol_error of Pgasync_error.t 11 | 12 | (** [read_messages] will handle and dispatch the three asynchronous message types 13 | for you; you should never see them. 14 | 15 | [handle_message] is given a message type constructor, and an iobuf windowed on the 16 | payload of the message (that is, the message-type-specific bytes; the window does 17 | not include the type or message length header, they have already been consumed). 18 | 19 | [handle_message] must consume (as in [Iobuf.Consume]) all of the bytes of the 20 | message. *) 21 | type 'a handle_message := 22 | Postgres_async_protocol.Backend.constructor -> (read, Iobuf.seek) Iobuf.t -> 'a 23 | 24 | type 'a read_messages_result = 25 | | Connection_closed of Pgasync_error.t 26 | | Done of 'a 27 | 28 | (** NoticeResponse, ParameterStatus and NotificationResponse are 'asynchronous 29 | messages' and are not associated with a specific request-response conversation. They 30 | can happen at any time. 31 | 32 | [read_messages] handles these for you, and does not show them to your 33 | [handle_message] callback.
*) 34 | val read_messages 35 | : ?pushback:(unit -> unit Deferred.t) 36 | -> t 37 | -> handle_message:'a handle_message_result handle_message 38 | -> 'a read_messages_result Deferred.t 39 | 40 | (** See [read_messages], except [handle_message] returns the pushback instead of a 41 | separate function *) 42 | val read_messages' 43 | : t 44 | -> handle_message:'a handle_message_result Deferred.t handle_message 45 | -> 'a read_messages_result Deferred.t 46 | 47 | (** If a message arrives while no request-response conversation (query or otherwise) is 48 | going on, use [consume_one_asynchronous_message] to eat it. 49 | 50 | If the message is one of those asynchronous messages, it will be handled. If it is 51 | some other message type, that is a protocol error and the connection will be closed. 52 | If the reader is actually at EOF, the connection will be closed with an error. *) 53 | val consume_one_asynchronous_message : t -> unit read_messages_result Deferred.t 54 | end 55 | -------------------------------------------------------------------------------- /src/or_pgasync_error.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | 3 | type 'a t = ('a, Pgasync_error.t) Result.t [@@deriving sexp_of] 4 | 5 | let to_or_error = Result.map_error ~f:Pgasync_error.to_error 6 | 7 | let ok_exn = function 8 | | Ok x -> x 9 | | Error e -> Pgasync_error.raise e 10 | ;; 11 | 12 | let error_s ?error_code s = Error (Pgasync_error.create_s ?error_code s) 13 | let error_string ?error_code s : _ t = Error (Pgasync_error.of_string ?error_code s) 14 | let errorf ?error_code fmt = ksprintf (error_string ?error_code) fmt 15 | let of_exn ?error_code e = Error (Pgasync_error.of_exn ?error_code e) 16 | -------------------------------------------------------------------------------- /src/or_pgasync_error.mli: -------------------------------------------------------------------------------- 1 | open Core 2 | 3 | type 'a t = ('a, Pgasync_error.t) 
Result.t [@@deriving sexp_of] 4 | 5 | val to_or_error : 'a t -> 'a Or_error.t 6 | val ok_exn : 'a t -> 'a 7 | val error_s : ?error_code:Pgasync_error.Sqlstate.t -> Sexp.t -> _ t 8 | val error_string : ?error_code:Pgasync_error.Sqlstate.t -> string -> _ t 9 | val of_exn : ?error_code:Pgasync_error.Sqlstate.t -> Exn.t -> _ t 10 | 11 | val errorf 12 | : ?error_code:Pgasync_error.Sqlstate.t 13 | -> ('a, unit, string, 'b t) format4 14 | -> 'a 15 | -------------------------------------------------------------------------------- /src/pgasync_error.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | module ErrorResponse = Postgres_async_protocol.Backend.ErrorResponse 3 | module Postgres_field = Postgres_async_protocol.Backend.Error_or_notice_field 4 | 5 | (* The MLI introduces the [Pgasync_error] type; it's our place to store the generic 6 | error, and the error code _if_ we know it. 7 | 8 | Now, on the subject of error handling more generally, there are a few axes on which 9 | one must make a decision here. 10 | 11 | API 12 | --- 13 | First and foremost, we need to decide whether or not we should expose a single api, 14 | where all the result types are [Or_pgasync_error.t], or should we expose a more 15 | convenient API that uses [Or_error.t], and an expert API that uses 16 | [Or_pgasync_error.t]? We do the latter, because the vast majority of our users don't 17 | care for the [postgres_error_code] and do use [Error.t] everywhere else. 18 | 19 | Internal storage of connection-killing errors 20 | --------------------------------------------- 21 | But more subtly, once you've agreed to provide [postgres_error_code], you need to be 22 | very careful about when you provide it. If you return an error from [query] with 23 | [postgres_error_code = Some _] in it, then the user will quite reasonably assume that 24 | the query failed with that error. 
But that might not be true: if the connection 25 | previously asynchronously failed due to some error (say, the connection closed), 26 | we're going to return that error to the user. 27 | 28 | This is a problem, because the error that closed the connection might imply that 29 | a specific thing is wrong with their query, when actually it is not. I don't have 30 | a great example, but suppose that there existed an error code relating to replication 31 | that might cause the backend to die. 32 | 33 | If the backend asynchronously dies, we'll close [t] and stash the reason it closed 34 | inside [t.state]. If someone then tries to use [query t], we'll fetch the reason 35 | out of [t.state], and return that as an error. But if we did this naively, then it 36 | would look to the user like their specific query conflicted with replication, and 37 | they might retry it, when actually that was unrelated. 38 | 39 | The important thing here is: we should only return a [postgres_error_code] that has 40 | semantic meaning to the user _if_ it relates to the operation they just performed, 41 | not some operation that happened previously. 42 | 43 | You could imagine two different ways of trying to achieve this: 44 | 45 | + Only ever stash [Error.t]s inside [state] (or similar). This ensures we'll never 46 | fetch a [postgres_error_code] out of storage and give it to the user. 47 | 48 | But this is tricky, because now we have a mixture of different error types in the 49 | library, which is super annoying and messy, and you have to be careful about when you 50 | use one vs. the other, and you have to keep converting back and forth. 
51 | 52 | Furthermore, you need to be careful that a refactor doesn't cause an error relating to 53 | a specific query to be stashed and immediately retrieved, as that might erase the 54 | error code, or to be passed through some generic function that erases it. 55 | 56 | + Use [Pgasync_error.t] everywhere within the library, but before every operation 57 | like [query], examine [t.state] to see if the connection is already dead, and make 58 | sure that when you return the stashed [Pgasync_error.t], you erase the code from it 59 | (and tag it with a message like "query issued on closed connection; original error 60 | was foo" too). 61 | 62 | We go with the latter. 63 | 64 | This is precisely achieved by the first line of [parse_and_start_executing_query]. 65 | 66 | We're using the simplifying argument that even if an error gets stashed in [t.state] 67 | and then returned to the user at the end of [query], that error _wasn't_ there when 68 | the query was started, so it's reasonable to associate it with the [query], which 69 | sounds fine. 
*) 70 | 71 | module Sqlstate = struct 72 | type t = string [@@deriving compare, equal, hash, sexp_of] 73 | 74 | let cardinality_violation = "21000" 75 | let invalid_authorization_specification = "28000" 76 | let invalid_password = "28P01" 77 | let connection_exception = "08000" 78 | let sqlclient_unable_to_establish_sqlconnection = "08001" 79 | let connection_does_not_exist = "08003" 80 | let sqlserver_rejected_establishment_of_sqlconnection = "08004" 81 | let connection_failure = "08006" 82 | let protocol_violation = "08P01" 83 | let object_not_in_prerequisite_state = "55000" 84 | let undefined_object = "42704" 85 | let wrong_object_type = "42809" 86 | let syntax_error = "42601" 87 | end 88 | 89 | type t = 90 | { error : Error.t 91 | ; server_error : ErrorResponse.t option 92 | } 93 | 94 | let sexp_of_t t = [%sexp (t.error : Error.t)] 95 | 96 | let of_error ?error_code e = 97 | { error = e 98 | ; server_error = 99 | Option.map error_code ~f:(fun error_code : ErrorResponse.t -> 100 | { error_code; all_fields = [] }) 101 | } 102 | ;; 103 | 104 | let of_exn ?error_code e = of_error ?error_code (Error.of_exn e) 105 | let of_string ?error_code s = of_error ?error_code (Error.of_string s) 106 | let create_s ?error_code s = of_error ?error_code (Error.create_s s) 107 | 108 | let of_error_response (error_response : ErrorResponse.t) = 109 | let error = 110 | (* We omit some of the particularly noisy and uninteresting fields from the error 111 | message that will be displayed to users. 112 | 113 | Note that as-per [ErrorResponse.t]'s docstring, [Code] is included in this 114 | list. 
*) 115 | let interesting_fields = 116 | List.filter error_response.all_fields ~f:(fun (field, value) -> 117 | match field with 118 | | File | Line | Routine | Severity_non_localised -> false 119 | | Severity -> 120 | (* ERROR is the normal case for an error message, so just omit it *) 121 | String.( <> ) value "ERROR" 122 | | _ -> true) 123 | in 124 | Error.create_s [%sexp (interesting_fields : (Postgres_field.t * string) list)] 125 | in 126 | { error; server_error = Some error_response } 127 | ;; 128 | 129 | let tag t ~tag = { t with error = Error.tag t.error ~tag } 130 | let to_error t = t.error 131 | 132 | let postgres_error_code t = 133 | match t.server_error with 134 | | None -> None 135 | | Some { error_code; _ } -> Some error_code 136 | ;; 137 | 138 | let postgres_field t field = 139 | match t.server_error with 140 | | None -> None 141 | | Some { all_fields; _ } -> 142 | List.Assoc.find all_fields field ~equal:[%equal: Postgres_field.t] 143 | ;; 144 | 145 | let raise t = Error.raise t.error 146 | 147 | (* Truncate queries longer than this *) 148 | let max_query_length = ref 2048 149 | let max_parameter_length = ref 128 150 | let max_parameters = ref 16 151 | 152 | let set_error_reporting_limits 153 | ?(query_length = 2048) 154 | ?(parameter_length = 128) 155 | ?(parameters = 16) 156 | () 157 | = 158 | max_query_length := query_length; 159 | max_parameter_length := parameter_length; 160 | max_parameters := parameters 161 | ;; 162 | 163 | let query_tag ?parameters ~query_string t = 164 | lazy 165 | (let position = 166 | (* Postgres reports 1-based position *) 167 | Option.value_map ~default:1 ~f:Int.of_string (postgres_field t Position) - 1 168 | in 169 | (* We want to display the long query like this: 170 | 171 | [error position] 172 | 173 | where combined length of lead + prefix + suffix is less than 174 | max_query_length *) 175 | let half = !max_query_length / 2 in 176 | let prefix_start = max 0 (position - half) in 177 | let query_length = String.length 
query_string in 178 | let suffix_end = min query_length (position + half) in 179 | let prefix = 180 | (* end position 0 means whole string, we don't want this *) 181 | if position = 0 then "" else String.slice query_string prefix_start position 182 | in 183 | let suffix = String.slice query_string position suffix_end in 184 | let query = 185 | String.concat 186 | ~sep:"" 187 | [ (if prefix_start > 0 then "... " else "") 188 | ; prefix 189 | ; suffix 190 | ; (if suffix_end < query_length then " ..." else "") 191 | ] 192 | in 193 | let parameters = 194 | (* Limit the number of parameters: keep the first [max_parameters - 1] values and use the last slot for a note counting every parameter that was dropped *) 195 | Option.bind parameters ~f:(fun parameters -> 196 | if Array.is_empty parameters 197 | then None 198 | else if Array.length parameters > !max_parameters 199 | then 200 | Some 201 | (Array.init !max_parameters ~f:(fun idx -> 202 | if idx = !max_parameters - 1 203 | then ( 204 | let num_omitted = Array.length parameters - (!max_parameters - 1) in 205 | Some [%string "remaining %{num_omitted#Int} parameter(s) omitted"]) 206 | else parameters.(idx))) 207 | else Some parameters) 208 | in 209 | let parameters = 210 | (* Limit the length of parameters *) 211 | Option.map parameters ~f:(fun parameters -> 212 | Array.map 213 | parameters 214 | ~f: 215 | (Option.map ~f:(fun param -> 216 | if String.length param > !max_parameter_length 217 | then String.prefix param !max_parameter_length ^ "..." 218 | else param))) 219 | in 220 | [%message (query : string) (parameters : (string option array option[@sexp.option]))]) 221 | ;; 222 | 223 | let tag_by_query ?parameters ~query_string t = 224 | { t with error = Error.tag_s_lazy t.error ~tag:(query_tag ?parameters ~query_string t) } 225 | ;; 226 | -------------------------------------------------------------------------------- /src/pgasync_error.mli: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | module ErrorResponse = Postgres_async_protocol.Backend.ErrorResponse 3 | module Postgres_field = Postgres_async_protocol.Backend.Error_or_notice_field 4 | 5 | module Sqlstate : sig 6 | (** PostgreSQL Error Codes. 7 | 8 | Excerpt from 9 | {{:https://www.postgresql.org/docs/current/errcodes-appendix.html} PostgreSQL Error 10 | Codes}: 11 | 12 | "All messages emitted by the PostgreSQL server are assigned five-character error 13 | codes that follow the SQL standard's conventions for “SQLSTATE” codes. Applications 14 | that need to know which error condition has occurred should usually test the error 15 | code, rather than looking at the textual error message. The error codes are less 16 | likely to change across PostgreSQL releases, and also are not subject to change due 17 | to localization of error messages. Note that some, but not all, of the error codes 18 | produced by PostgreSQL are defined by the SQL standard; some additional error codes 19 | for conditions not defined by the standard have been invented or borrowed from other 20 | databases." 21 | 22 | We replicate some of these error codes here for [Postgres_async] to report in its 23 | own error codes when appropriate. 
24 | 25 | See also: 26 | {{:https://github.com/postgres/postgres/blob/master/src/backend/utils/errcodes.txt} 27 | src/backend/utils/errcodes.txt} *) 28 | 29 | type t = private string [@@deriving compare, equal, hash, sexp_of] 30 | 31 | val cardinality_violation : t 32 | val connection_exception : t 33 | val sqlclient_unable_to_establish_sqlconnection : t 34 | val connection_does_not_exist : t 35 | val sqlserver_rejected_establishment_of_sqlconnection : t 36 | val connection_failure : t 37 | val protocol_violation : t 38 | val invalid_password : t 39 | val invalid_authorization_specification : t 40 | val object_not_in_prerequisite_state : t 41 | val undefined_object : t 42 | val wrong_object_type : t 43 | val syntax_error : t 44 | end 45 | 46 | type t [@@deriving sexp_of] 47 | 48 | val of_error : ?error_code:Sqlstate.t -> Error.t -> t 49 | val of_exn : ?error_code:Sqlstate.t -> exn -> t 50 | val of_string : ?error_code:Sqlstate.t -> string -> t 51 | val create_s : ?error_code:Sqlstate.t -> Sexp.t -> t 52 | val of_error_response : ErrorResponse.t -> t 53 | val tag : t -> tag:string -> t 54 | val to_error : t -> Error.t 55 | 56 | (** This is the SQLSTATE *) 57 | val postgres_error_code : t -> string option 58 | 59 | val postgres_field : t -> Postgres_field.t -> string option 60 | val raise : t -> _ 61 | val max_query_length : int ref 62 | val max_parameter_length : int ref 63 | val max_parameters : int ref 64 | 65 | val set_error_reporting_limits 66 | : ?query_length:int 67 | -> ?parameter_length:int 68 | -> ?parameters:int 69 | -> unit 70 | -> unit 71 | 72 | val tag_by_query : ?parameters:string option array -> query_string:string -> t -> t 73 | -------------------------------------------------------------------------------- /src/postgres_async.mli: -------------------------------------------------------------------------------- 1 | include Postgres_async_intf.Postgres_async (** @inline *) 2 | 
-------------------------------------------------------------------------------- /src/postgres_async_intf.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | module Protocol = Postgres_async_protocol 4 | module Types = Protocol.Types 5 | module Column_metadata = Protocol.Column_metadata 6 | 7 | module type S = sig 8 | (** In order to provide an [Expert] interface that uses [Pgasync_error.t] to represent 9 | its errors, alongside a normal interface that just uses Core's [Error.t], the 10 | interface is defined in this [module type S], with a type [error] that we erase when 11 | we include it in [postgres_async.mli]. *) 12 | type error 13 | 14 | type command_complete 15 | type t [@@deriving sexp_of] 16 | 17 | (** [gss_krb_token] will be sent in response to a server's request to initiate GSSAPI 18 | authentication. We don't support GSS/SSPI authentication that requires multiple 19 | steps; if the server sends us a "GSSContinue" message in response to 20 | [gss_krb_token], login will fail. Kerberos should not require this. 21 | 22 | [ssl_mode] defaults to [Ssl_mode.Disable]. 23 | 24 | [buffer_age_limit] sets the age limit on the outgoing Writer.t. The default limit is 25 | set to [`Unlimited] to avoid application crashes when the database is loaded. 26 | Default age limit on the Writer.t is 2 minutes, and you might want to use that value 27 | to preserve the existing behavior. 28 | 29 | [buffer_byte_limit] is only used during a COPY---it pauses inserting more rows and 30 | flushes the entire Writer.t if its buffer contains this many bytes or more. 
*) 31 | val connect 32 | : ?interrupt:unit Deferred.t 33 | -> ?ssl_mode:Ssl_mode.t 34 | -> ?server:_ Tcp.Where_to_connect.t 35 | -> ?user:string 36 | -> ?password:string 37 | -> ?gss_krb_token:string 38 | -> ?buffer_age_limit:Writer.buffer_age_limit 39 | -> ?buffer_byte_limit:Byte_units.t 40 | -> ?max_message_length:Byte_units.t 41 | -> database:string 42 | -> ?replication:string 43 | -> unit 44 | -> (t, error) Result.t Deferred.t 45 | 46 | (** [close] returns an error if there were any problems gracefully tearing down the 47 | connection. For sure, when it is determined, the connection is gone. 48 | 49 | [try_cancel_statement_before_close] defaults to false. If set to true, [close] will 50 | first attempt to cancel any query in progress on [t] before closing the connection, 51 | which provides more 'graceful' closing behavior. *) 52 | val close 53 | : ?try_cancel_statement_before_close:bool 54 | -> t 55 | -> (unit, error) Result.t Deferred.t 56 | 57 | val close_finished : t -> (unit, error) Result.t Deferred.t 58 | 59 | type state = 60 | | Open 61 | | Closing 62 | | Failed of 63 | { error : error 64 | ; resources_released : bool 65 | } 66 | | Closed_gracefully 67 | [@@deriving sexp_of] 68 | 69 | val status : t -> state 70 | 71 | val with_connection 72 | : ?interrupt:unit Deferred.t 73 | -> ?ssl_mode:Ssl_mode.t 74 | -> ?server:_ Tcp.Where_to_connect.t 75 | -> ?user:string 76 | -> ?password:string 77 | -> ?gss_krb_token:string 78 | -> ?buffer_age_limit:Async_unix.Writer.buffer_age_limit 79 | -> ?buffer_byte_limit:Byte_units.t 80 | -> ?max_message_length:Byte_units.t 81 | -> ?try_cancel_statement_before_close:bool 82 | -> database:string 83 | -> ?replication:string 84 | -> on_handler_exception:[ `Raise ] 85 | -> (t -> 'res Deferred.t) 86 | -> ('res, error) Result.t Deferred.t 87 | 88 | (** [handle_columns] can provide column information even if 0 rows are found. 
89 | [handle_columns] is guaranteed to be called before the first invocation of 90 | [handle_row] *) 91 | type 'handle_row with_query_args := 92 | t 93 | -> ?parameters:string option array 94 | -> ?pushback:(unit -> unit Deferred.t) 95 | -> ?handle_columns:(Column_metadata.t iarray -> unit) 96 | -> string 97 | -> handle_row:'handle_row 98 | -> (command_complete, error) Result.t Deferred.t 99 | 100 | val query 101 | : (column_names:string iarray -> values:string option iarray -> unit) with_query_args 102 | 103 | val query_zero_copy : (Row_handle.t -> unit) with_query_args 104 | 105 | val query_expect_no_data 106 | : t 107 | -> ?parameters:string option array 108 | -> string 109 | -> (command_complete, error) Result.t Deferred.t 110 | 111 | type 'a feed_data_result = 112 | | Abort of { reason : string } 113 | | Wait of unit Deferred.t 114 | | Data of 'a 115 | | Finished 116 | 117 | val copy_in_raw 118 | : t 119 | -> ?parameters:string option array 120 | -> string 121 | -> feed_data:(unit -> string feed_data_result) 122 | -> (command_complete, error) Result.t Deferred.t 123 | 124 | (** Note that [table_name] and [column_names] must be escaped before calling 125 | [copy_in_rows]. *) 126 | val copy_in_rows 127 | : ?schema_name:string 128 | -> t 129 | -> table_name:string 130 | -> column_names:string list 131 | -> feed_data:(unit -> string option array feed_data_result) 132 | -> (command_complete, error) Result.t Deferred.t 133 | 134 | (** [listen_to_notifications] executes a query to subscribe you to notifications on 135 | [channel] (i.e., "LISTEN $channel") and stores [f] inside [t], calling it when the 136 | server sends us any such notifications. 137 | 138 | Calling it multiple times is fine: the "LISTEN" query is idempotent, and both 139 | callbacks will be stored in [t]. 140 | 141 | However, be careful. The interaction between postgres notifications and transactions 142 | is very subtle. 
Here are but some of the things you need to bear in mind: 143 | 144 | - LISTEN executed during a transaction that is rolled back has no effect. As such, 145 | if you're in the middle of a transaction when you call [listen_to_notifications] 146 | and then roll said transaction back, [f] will be stored in [t] but you will not 147 | actually receive any notifications from the server. 148 | 149 | - Notifications that happen while you are in a transaction are only delivered by the 150 | server after the end of the transaction. In particular, if you're doing a big 151 | [query] and you're pushing-back on the server, you're also potentially delaying 152 | delivery of notifications. 153 | 154 | - You need to pay attention to [close_finished] in case the server kicks you off. 155 | 156 | - The postgres protocol has no heartbeats. If the server disappears in a 157 | particularly bad way it might be a while before we notice. The empty query makes a 158 | rather effective heartbeat (i.e. [query_expect_no_data t ""]), but this is your 159 | responsibility if you want it. *) 160 | val listen_to_notifications 161 | : t 162 | -> channel:string 163 | -> f:(pid:Pid.t -> payload:string -> unit) 164 | -> (command_complete, error) Result.t Deferred.t 165 | end 166 | 167 | module type Pgasync_error = sig 168 | (** The type [Pgasync_error.t] is used for all errors in this library. 169 | 170 | [to_error] returns a regular [Core.Error.t] that fully describes what went wrong 171 | (including but not limited to a postgres error code, if we have one), and if you only 172 | need to show the error to someone/raise/fail then using that completely suffices. 173 | 174 | _If_ the error is due to an [ErrorResponse] from the server _and_ we successfully 175 | parsed an error code out of it, you can retrieve that via [postgres_error_code]. 176 | 177 | Note that either of those conditions might be false, e.g. 
the error might be due to 178 | TCP connection failure, so we certainly won't have any postgres error code; you 179 | shouldn't take the name of the type to mean "postgres error", rather "postgres async 180 | error"---any error this library can produce. *) 181 | type t = Pgasync_error.t [@@deriving sexp_of] 182 | 183 | module Postgres_field = Pgasync_error.Postgres_field 184 | 185 | val to_error : t -> Error.t 186 | val postgres_error_code : t -> string option 187 | val raise : t -> 'a 188 | val postgres_field : t -> Postgres_field.t -> string option 189 | 190 | (** Queries and query parameters are included in the errors. To keep the length of error 191 | messages in check, this information is abbreviated: 192 | - long queries are truncated 193 | - long parameter values are truncated 194 | - only a limited number of parameters are displayed 195 | 196 | You can change those global limits via [set_error_reporting_limits] *) 197 | val set_error_reporting_limits 198 | : ?query_length:int (** default: 2048 *) 199 | -> ?parameter_length:int (** default: 128 *) 200 | -> ?parameters:int (** default: 16 *) 201 | -> unit 202 | -> unit 203 | end 204 | 205 | module type Private = sig 206 | type t 207 | 208 | module Protocol = Protocol 209 | module Types = Types 210 | module Query_sequencer = Query_sequencer 211 | module Pgasync_error = Pgasync_error 212 | module Row_handle = Row_handle 213 | 214 | val pgasync_error_of_error : Error.t -> Pgasync_error.t 215 | val pq_cancel : t -> unit Or_error.t Deferred.t 216 | val runtime_parameters : t -> string String.Map.t 217 | val failed : t -> Pgasync_error.t -> unit 218 | val close_started : t -> unit Deferred.t 219 | 220 | (** Parses the server_version runtime parameters that we received from the server during 221 | connection startup. It should correspond to [server_version_num] on the server. 
*) 222 | val server_version : t -> (int, [ `Unknown ]) result 223 | 224 | module Simple_query_result : sig 225 | (** - Completed_with_no_warnings : everything worked as expected. The list will 226 | contain one Command_complete.t for each statement executed 227 | - Completed_with_warnings : the query ran successfully, but the query tried to do 228 | something unsupported client-side (e.g. COPY TO STDOUT). The list of errors will 229 | not be empty. 230 | - Connection_error : The underlying connection died at some point during query 231 | execution or parsing results. Query may or may not have taken effect on server. 232 | - Driver_error : Postgres_async received an unexpected protocol message from the 233 | server. Query may or may not have taken effect on server. 234 | - Failed : Got error from server, query did not take effect on server *) 235 | 236 | type t = 237 | | Completed_with_no_warnings of Command_complete.t list 238 | | Completed_with_warnings of (Command_complete.t list * Error.t list) 239 | | Failed of Pgasync_error.t 240 | | Connection_error of Pgasync_error.t 241 | | Driver_error of Pgasync_error.t 242 | 243 | val to_or_pgasync_error : t -> Command_complete.t list Or_pgasync_error.t 244 | end 245 | 246 | (** Executes a query according to the Postgres Simple Query protocol. As specified in 247 | the protocol, multiple commands can be chained together using semicolons and 248 | executed in one operation. If not already in a transaction, the query is treated as 249 | transaction and all commands within it are executed atomically. 250 | 251 | [handle_columns] is called on each column row description message. Note that with 252 | simple_query, it's possible that multiple row description messages are sent (as a 253 | simple query can contain multiple statements that return rows). 254 | 255 | [handle_row] is called for every row returned by the query. 
256 | 257 | If a [Pgasync_error.t] is returned, this indicates that the connection was closed 258 | during processing. The query may or may not have successfully run from the server's 259 | perspective. 260 | 261 | Queries containing COPY FROM STDIN will fail as this function does not support this 262 | operation. Queries containing COPY TO STDOUT will succeed, but the copydata will not 263 | be delivered to [handle_row], and a warning will appear in [Simple_query_result] *) 264 | val simple_query 265 | : ?pushback:(unit -> unit Deferred.t) 266 | -> ?handle_columns:(Column_metadata.t iarray -> unit) 267 | -> t 268 | -> string 269 | -> handle_row:(column_names:string iarray -> values:string option iarray -> unit) 270 | -> Simple_query_result.t Deferred.t 271 | 272 | (** Executes a query that should not return any rows using the Postgres Simple Query 273 | Protocol. As with [simple_query], multiple commands can be chained together using 274 | semicolons. Inherits the transaction behavior of [simple_query]. 275 | 276 | If any of the queries fails, or returns at least one row, an error will be returned 277 | and the transaction will be aborted. 278 | 279 | As with [simple_query], queries containing COPY FROM STDIN will fail. *) 280 | val execute_simple : t -> string -> Simple_query_result.t Deferred.t 281 | 282 | module Without_background_asynchronous_message_handling : sig 283 | type t 284 | 285 | (** Creates a TCP connection to the server, returning the Reader and Writer after the 286 | login message flow has been completed successfully. 
Will not have asynchronous 287 | message handling running *) 288 | val login_and_get_raw 289 | : ?interrupt:unit Deferred.t 290 | -> ?ssl_mode:Ssl_mode.t 291 | -> ?server:[< Socket.Address.t ] Tcp.Where_to_connect.t 292 | -> ?password:string 293 | -> ?gss_krb_token:string 294 | -> ?buffer_age_limit:[ `At_most of Core_private.Span_float.t | `Unlimited ] 295 | -> ?buffer_byte_limit:Byte_units.t 296 | -> ?max_message_length:Byte_units.t 297 | -> startup_message:Protocol.Frontend.StartupMessage.t 298 | -> unit 299 | -> t Or_pgasync_error.t Deferred.t 300 | 301 | val reader : t -> Reader.t 302 | val writer : t -> Writer.t 303 | val backend_key : t -> Types.backend_key option 304 | val runtime_parameters : t -> string String.Map.t 305 | val pq_cancel : t -> unit Deferred.Or_error.t 306 | end 307 | 308 | val iter_copy_out 309 | : t 310 | -> query_string:string 311 | -> f:((read, Iobuf.seek) Iobuf.t -> unit Deferred.t) 312 | -> Command_complete.t Or_pgasync_error.t Deferred.t 313 | 314 | (** Access to the underlying [Reader]. It is generally not safe to interact with this 315 | for a [t] that corresponds to a "normal" database connection, as this can put [t] 316 | into an unexpected state. Exposed only for use in specialized db connection use 317 | cases (e.g. replication) *) 318 | module Message_reading : Message_reading_intf.S with type t := t 319 | 320 | (** Access to the underlying [Writer] that sends bytes to the database. It is generally 321 | not safe to interact with the [Writer] here for a [t] that corresponds to a "normal" 322 | database connection, as this can put [t] into an unexpected state. Exposed only for 323 | use in specialized db connection use cases (e.g. replication) *) 324 | val writer : t -> Writer.t 325 | 326 | (** Internal state of [t], it is generally not safe to interact with the 327 | [Query_sequencer] here for a [t] that corresponds to a "normal" database connection, 328 | as this can put [t] into an unexpected state. 
Exposed only for use in specialized db 329 | connection use cases (e.g. replication) *) 330 | val query_sequencer : t -> Query_sequencer.t 331 | end 332 | 333 | module type Postgres_async = sig 334 | module Pgasync_error : Pgasync_error 335 | 336 | module Or_pgasync_error : sig 337 | type 'a t = ('a, Pgasync_error.t) Result.t [@@deriving sexp_of] 338 | 339 | val to_or_error : 'a t -> 'a Or_error.t 340 | val ok_exn : 'a t -> 'a 341 | end 342 | 343 | module Column_metadata : Column_metadata.Public 344 | module Command_complete : Command_complete.Public 345 | module Row_handle = Row_handle 346 | module Ssl_mode = Ssl_mode 347 | 348 | type t [@@deriving sexp_of] 349 | 350 | module type S := S with type t := t 351 | 352 | (** @open *) 353 | include S with type error := Error.t and type command_complete := unit 354 | 355 | (** The [Expert] module provides versions of all the same functions that instead return 356 | [Or_pgasync_error.t]s. 357 | 358 | Note that [t] and [Expert.t] is the same type, so you can mix-and-match depending on 359 | whether you want to try and inspect the error code of a specific failure or not. 
*) 360 | module Expert : S with type error := Pgasync_error.t and type command_complete := unit 361 | 362 | (** The [With_command_complete] module provides access to the contents of 363 | CommandComplete message, that includes a string tag that describes the command that 364 | was executed, and optional row count *) 365 | module With_command_complete : 366 | S with type error := Error.t and type command_complete := Command_complete.t 367 | 368 | module Expert_with_command_complete : 369 | S with type error := Pgasync_error.t and type command_complete := Command_complete.t 370 | 371 | module Private : Private with type t := t 372 | end 373 | -------------------------------------------------------------------------------- /src/query_sequencer.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open Async 3 | open! Int.Replace_polymorphic_compare 4 | 5 | (* Why don't we implement this using [Sequencer] (and thereby get its nice exception 6 | handling, notably)? 7 | 8 | Well, you end up writing some loop like this 9 | 10 | {[ 11 | let rec run t = 12 | match Throttle.num_jobs_running s with 13 | | 0 -> when_idle etc. 14 | | _ -> 15 | (* wait for job to finish. *) 16 | run t 17 | ]} 18 | 19 | the problem is that "wait for job to finish" is hard. You can enqueue a job in the 20 | sequencer, or use [Throttle.prior_jobs_done], but both of those things modify 21 | [Throttle.num_jobs_running] and there are no guarantees about the race between 22 | [num_jobs_running] changing back to [0] and the result of the dummy job you enqueued 23 | becoming determined. In practice async-callbacks waiting on dummy job's deferred run 24 | before [num_jobs_running] decreases, and you end up in an infinite loop. 25 | 26 | We could probably hack around that with [Scheduler.yield_until_no_jobs_remain ()] or 27 | something but besides being slow, I claim this ultimately ends up being far more 28 | complicated than that which you see below. 
*) 29 | 30 | type job = 31 | | Job : 32 | { start : unit Ivar.t 33 | ; finished : 'a Deferred.t 34 | } 35 | -> job 36 | [@@deriving sexp_of] 37 | 38 | type when_idle_next_step = 39 | | Call_me_when_idle_again 40 | | Finished 41 | 42 | type t = 43 | { jobs_waiting : job Queue.t 44 | ; any_work_added : (unit, read_write) Bvar.t 45 | ; mutable when_idle : (unit -> when_idle_next_step Deferred.t) option 46 | } 47 | [@@deriving sexp_of] 48 | 49 | let rec run t = 50 | let%bind () = 51 | match Queue.dequeue t.jobs_waiting with 52 | | Some (Job { start; finished }) -> 53 | Ivar.fill_exn start (); 54 | let%bind _ = finished in 55 | return () 56 | | None -> 57 | (match t.when_idle with 58 | | None -> Bvar.wait t.any_work_added 59 | | Some func -> 60 | (match%bind func () with 61 | | Finished -> 62 | t.when_idle <- None; 63 | return () 64 | | Call_me_when_idle_again -> return ())) 65 | in 66 | run t 67 | ;; 68 | 69 | let create () = 70 | let t = 71 | { jobs_waiting = Queue.create (); any_work_added = Bvar.create (); when_idle = None } 72 | in 73 | don't_wait_for (run t); 74 | t 75 | ;; 76 | 77 | let enqueue t job : _ Deferred.t = 78 | let start = Ivar.create () in 79 | let finished = 80 | let%bind () = Ivar.read start in 81 | job () 82 | in 83 | Queue.enqueue t.jobs_waiting (Job { start; finished }); 84 | Bvar.broadcast t.any_work_added (); 85 | finished 86 | ;; 87 | 88 | let when_idle t callback = 89 | (* Only one idle callback may be stored at a time; storing one broadcasts on [any_work_added] so that a [run] loop blocked in [Bvar.wait] wakes up and sees it. *) match t.when_idle with 90 | | Some _ -> failwith "Query_sequencer.when_idle: already have a callback" 91 | | None -> 92 | t.when_idle <- Some callback; 93 | Bvar.broadcast t.any_work_added () 94 | ;; 95 | 96 | let rec other_jobs_are_waiting t = 97 | match Queue.is_empty t.jobs_waiting with 98 | | false -> return () 99 | | true -> 100 | let%bind () = Bvar.wait t.any_work_added in 101 | other_jobs_are_waiting t 102 | ;; 103 | -------------------------------------------------------------------------------- /src/query_sequencer.mli: 
-------------------------------------------------------------------------------- 1 | open! Core 2 | open Async 3 | 4 | (** This module is like [Sequencer], except it provides the [when_idle] function. 5 | 6 | [when_idle] allows you to perform some task while the sequencer is empty (i.e., no job 7 | running and nothing waiting to start). Your callback is called repeatedly while the 8 | queue is empty. 9 | 10 | Jobs will not be permitted to start until the [when_idle] callback returns, so that 11 | they may both use some shared resource without stepping on each other. Your callback 12 | should use [other_jobs_are_waiting] to know when is the right time to interrupt doing 13 | whatever it is doing and return. 14 | 15 | If it returns [Call_me_when_idle_again] early (i.e., before [other_jobs_are_waiting] 16 | is determined), then it will be immediately called again. 17 | 18 | If it returns [Finished], the when-idle callback will be deleted from [t]. 19 | 20 | Unlike [Sequencer], this module does nothing smart with exceptions; we provide no 21 | promises as to which monitor they will go to. This is because [Postgres_async] does 22 | not use/raise exceptions (bugs aside). *) 23 | 24 | type t [@@deriving sexp_of] 25 | 26 | val create : unit -> t 27 | val enqueue : t -> (unit -> 'a Deferred.t) -> 'a Deferred.t 28 | 29 | type when_idle_next_step = 30 | | Call_me_when_idle_again 31 | | Finished 32 | 33 | (** At most one 'when_idle' can be active at once; the previous callback must have 34 | returned [Finished] before a new one can be installed. *) 35 | val when_idle : t -> (unit -> when_idle_next_step Deferred.t) -> unit 36 | 37 | val other_jobs_are_waiting : t -> unit Deferred.t 38 | -------------------------------------------------------------------------------- /src/row_handle.ml: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | module Column_metadata = Postgres_async_protocol.Column_metadata 3 | 4 | type seek = Iobuf.seek 5 | type no_seek = Iobuf.no_seek 6 | 7 | module Function = struct 8 | type t = 9 | | Create 10 | | Next 11 | | Foldi_or_iteri 12 | end 13 | 14 | type t = 15 | { columns : Column_metadata.t iarray [@globalized] 16 | ; datarow : (read, seek) Iobuf.t (** This is internal to the socket's reader *) 17 | ; mutable last_function : Function.t 18 | } 19 | 20 | module Private = struct 21 | let create columns ~datarow = 22 | let num_fields = Iobuf.Consume.uint16_be datarow in 23 | if num_fields <> Iarray.length columns 24 | then 25 | raise_s 26 | [%message 27 | "number of columns in DataRow message did not match RowDescription" 28 | ~row_description:(columns : Column_metadata.t iarray) 29 | (num_fields : int)] 30 | else { columns; datarow = Iobuf.read_only__local datarow; last_function = Create } 31 | ;; 32 | end 33 | 34 | let columns t = t.columns 35 | 36 | let unchecked_next { datarow; _ } ~(f : (read, no_seek) Iobuf.t option -> _) = 37 | let len = Iobuf.Consume.int32_be datarow in 38 | if len = -1 39 | then f None 40 | else ( 41 | let hi_bound = Iobuf.Hi_bound.window datarow in 42 | (* Narrow the window to just this one column. *) 43 | Iobuf.resize datarow ~len; 44 | let result = f (Some (Iobuf.no_seek__local datarow)) in 45 | Iobuf.bounded_flip_hi datarow hi_bound; 46 | (* Set the window to begin at the next column's length and end at the end of the row. 
*) 47 | result) 48 | ;; 49 | 50 | let next t ~(f : (read, no_seek) Iobuf.t option -> _) = 51 | let () = 52 | match t.last_function with 53 | | Create -> t.last_function <- Next 54 | | Next -> () 55 | | Foldi_or_iteri -> 56 | raise_s 57 | [%sexp 58 | "cannot call [Row_handle.next] after calling \ 59 | [Row_handle.foldi]/[Row_handle.iteri]"] 60 | in 61 | if Iobuf.is_empty t.datarow 62 | then None 63 | else 64 | Some 65 | (let len = Iobuf.Consume.int32_be t.datarow in 66 | if len = -1 67 | then f None 68 | else ( 69 | let hi_bound = Iobuf.Hi_bound.window t.datarow in 70 | (* Narrow the window to just this one column. *) 71 | Iobuf.resize t.datarow ~len; 72 | Exn.protect 73 | ~f:(fun () -> f (Some (Iobuf.no_seek__local t.datarow)) [@nontail]) 74 | ~finally:(fun () -> Iobuf.bounded_flip_hi t.datarow hi_bound))) 75 | ;; 76 | 77 | let foldi t ~init ~f = 78 | match t.last_function with 79 | | Next | Foldi_or_iteri -> 80 | raise_s 81 | [%sexp 82 | "cannot call [Row_handle.foldi]/[Row_handle.iteri] once any column has been \ 83 | consumed"] 84 | | Create -> 85 | t.last_function <- Foldi_or_iteri; 86 | Iarray.fold t.columns ~init ~f:(fun acc column -> 87 | unchecked_next t ~f:(fun value -> f ~column ~value acc) [@nontail]) 88 | [@nontail] 89 | ;; 90 | 91 | let iteri t ~f = 92 | foldi t ~init:() ~f:(fun ~column ~value () -> f ~column ~value [@nontail]) [@nontail] 93 | ;; 94 | -------------------------------------------------------------------------------- /src/row_handle.mli: -------------------------------------------------------------------------------- 1 | open! Core 2 | module Column_metadata := Postgres_async_protocol.Column_metadata 3 | 4 | type seek := Iobuf.seek 5 | type no_seek := Iobuf.no_seek 6 | 7 | (** The safest interfaces to this module are [iteri] and [foldi]. They should generally be 8 | preferred unless you have some compiler-aided check to ensure you are accessing 9 | columns in the correct order. 
10 | 11 | If you are generating code, or otherwise doing something where you define both the 12 | query and handler together so that you can statically guarantee the number and order 13 | of result columns without even inspecting the [columns]: Use [unchecked_next] for best 14 | performance. 15 | 16 | If you can't use [foldi] or [iteri], but you don't have a nice static safety 17 | guarantee, there is [next], which at least ensures you don't attempt to read past the 18 | end of the row. This should probably only be used when you're hand-writing both the 19 | query and the row handler. 20 | 21 | Important note for users of [next] and [unchecked_next]: 22 | 23 | This row handle contains a reference to the raw internal reader buffer of the socket 24 | reader, so references to underlying buffers should not be held once functions return. 25 | 26 | This also means that consumers must consume *every* column of the row, otherwise 27 | [Postgres_async] will raise a protocol error. *) 28 | 29 | type t 30 | 31 | val columns : t -> Column_metadata.t iarray 32 | 33 | (** Consume the next column of the row. If there are no remaining columns, return [None]. 34 | 35 | If you need to seek in [value], use [Iobuf.sub_shared__local]. *) 36 | val next : t -> f:((read, no_seek) Iobuf.t option -> 'a) -> 'a option 37 | 38 | (** Like [next], but without the check that there are columns remaining in the row, nor a 39 | check to prevent you from calling [foldi]/[iteri]. 40 | 41 | If you can guarantee that you call [unchecked_next] exactly once per column, this is 42 | safe. *) 43 | val unchecked_next : t -> f:((read, no_seek) Iobuf.t option -> 'a) -> 'a 44 | 45 | (** [foldi] is a convenience alias for [unchecked_next] in [Array.iter (columns t)]. 46 | Calling [foldi] after any other method (besides [columns]) is an error and will 47 | raise.
*) 48 | val foldi 49 | : t 50 | -> init:'acc 51 | -> f:(column:Column_metadata.t -> value:(read, no_seek) Iobuf.t option -> 'acc -> 'acc) 52 | -> 'acc 53 | 54 | (** see [foldi] *) 55 | val iteri 56 | : t 57 | -> f:(column:Column_metadata.t -> value:(read, no_seek) Iobuf.t option -> unit) 58 | -> unit 59 | 60 | (** / **) 61 | 62 | module Private : sig 63 | val create : Column_metadata.t iarray -> datarow:([> read ], seek) Iobuf.t -> t 64 | end 65 | -------------------------------------------------------------------------------- /src/ssl_mode.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open! Int.Replace_polymorphic_compare 3 | 4 | type t = 5 | | Disable 6 | | Prefer 7 | | Require 8 | [@@deriving sexp_of, variants] 9 | 10 | let of_libpq_string string = 11 | match String.lowercase string with 12 | | "disable" -> Some Disable 13 | | "prefer" -> Some Prefer 14 | | "require" -> Some Require 15 | | _ -> None 16 | ;; 17 | 18 | let to_libpq_string t = Variants.to_name t |> String.lowercase 19 | -------------------------------------------------------------------------------- /src/ssl_mode.mli: -------------------------------------------------------------------------------- 1 | open Core 2 | open! Int.Replace_polymorphic_compare 3 | 4 | (** [t] is a subset of the types supported by the 'sslmode' parameter in libpq (documented 5 | at https://www.postgresql.org/docs/current/libpq-ssl.html). 6 | 7 | We don't currently support verifying certificate signatures, so there's nothing 8 | analogous to the "verify-ca" or "verify-full" options here. We don't distinguish 9 | between "allow" and "prefer" (they seem to exactly match in terms of behavior). 10 | 11 | Under the hood, [Prefer] and [Require] will both result in the very first message sent 12 | to the server being the [SSLRequest] message, instead of the [StartupMessage]. The 13 | server will respond with whether or not it is able to support SSL. 
14 | 15 | Using [Require], [connect] will return an error if the server cannot support SSL. 16 | 17 | [Prefer] will SSL-wrap the connection if the server supports SSL, or will use a plain 18 | TCP connection if the server does not support SSL. 19 | 20 | [Disable] will always use the plain TCP connection, and will not send the [SSLRequest] 21 | message. *) 22 | 23 | type t = 24 | | Disable 25 | | Prefer 26 | | Require 27 | [@@deriving sexp_of] 28 | 29 | val to_libpq_string : t -> string 30 | val of_libpq_string : string -> t option 31 | -------------------------------------------------------------------------------- /src/string_escaping.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open! Int.Replace_polymorphic_compare 3 | 4 | let escape_identifier s = 5 | String.split s ~on:'.' 6 | |> List.map ~f:(fun s -> 7 | "\"" ^ String.substr_replace_all s ~pattern:"\"" ~with_:"\"\"" ^ "\"") 8 | |> String.concat ~sep:"." 9 | ;; 10 | 11 | (* temporary escape hatch in case we break someone's code *) 12 | let quote_table_name_requested = 13 | lazy (Option.is_some (Sys.getenv "POSTGRES_ASYNC_COPY_ESCAPE_NAMES")) 14 | ;; 15 | 16 | module Copy_in = struct 17 | let query ?schema_name ~table_name ~column_names () = 18 | let column_names = 19 | (if Lazy.force quote_table_name_requested 20 | then List.map column_names ~f:escape_identifier 21 | else column_names) 22 | |> String.concat ~sep:", " 23 | in 24 | let table_name = 25 | if Lazy.force quote_table_name_requested 26 | then escape_identifier table_name 27 | else table_name 28 | in 29 | let table_name = 30 | match schema_name with 31 | | None -> table_name 32 | | Some schema -> schema ^ "." 
^ table_name 33 | in 34 | [%string 35 | "COPY %{table_name} ( %{column_names} ) FROM STDIN ( FORMAT text, DELIMITER '\t')"] 36 | ;; 37 | 38 | let special_escape char = 39 | match char with 40 | | '\n' -> Some 'n' 41 | | '\r' -> Some 'r' 42 | | '\t' -> Some 't' 43 | | '\\' -> Some '\\' 44 | | _ -> None 45 | ;; 46 | 47 | let is_special c = Option.is_some (special_escape c) 48 | 49 | let row_to_string row = 50 | let row = 51 | Array.map row ~f:(fun s -> 52 | match s with 53 | | None -> None 54 | | Some s -> Some (s, String.count s ~f:is_special)) 55 | in 56 | let total_size = 57 | Array.fold row ~init:0 ~f:(fun acc s -> 58 | match s with 59 | | None -> acc + 3 60 | | Some (s, specials) -> acc + String.length s + specials + 1) 61 | in 62 | let data = Bytes.create total_size in 63 | let pos = 64 | Array.fold row ~init:0 ~f:(fun pos s -> 65 | let pos = 66 | match s with 67 | | None -> 68 | Bytes.From_string.blit ~src:"\\N" ~src_pos:0 ~dst:data ~dst_pos:pos ~len:2; 69 | pos + 2 70 | | Some (s, 0) -> 71 | let len = String.length s in 72 | Bytes.From_string.blit ~src:s ~src_pos:0 ~dst:data ~dst_pos:pos ~len; 73 | pos + len 74 | | Some (s, _) -> 75 | String.fold s ~init:pos ~f:(fun pos char -> 76 | match special_escape char with 77 | | None -> 78 | Bytes.set data pos char; 79 | pos + 1 80 | | Some char -> 81 | Bytes.set data pos '\\'; 82 | Bytes.set data (pos + 1) char; 83 | pos + 2) 84 | in 85 | Bytes.set data pos '\t'; 86 | pos + 1) 87 | in 88 | assert (pos = Bytes.length data); 89 | Bytes.set data (pos - 1) '\n'; 90 | Bytes.unsafe_to_string ~no_mutation_while_string_reachable:data 91 | ;; 92 | end 93 | 94 | module Listen = struct 95 | let query ~channel = sprintf !"LISTEN %{escape_identifier}" channel 96 | end 97 | -------------------------------------------------------------------------------- /src/string_escaping.mli: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | 3 | module Copy_in : sig 4 | (** Note that [schema_name], [table_name], and [column_names] must be escaped before 5 | calling [query]. *) 6 | val query 7 | : ?schema_name:string 8 | -> table_name:string 9 | -> column_names:string list 10 | -> unit 11 | -> string 12 | 13 | (** [row_to_string] includes the terminating '\n' *) 14 | val row_to_string : string option array -> string 15 | end 16 | 17 | module Listen : sig 18 | val query : channel:string -> string 19 | end 20 | 21 | (** No [escape_value] function is provided, because so far parameters have sufficed for 22 | putting values into query strings. *) 23 | -------------------------------------------------------------------------------- /test/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name postgres_async_tests) 3 | (libraries core async async_ssl core_unix expect_test_helpers_core 4 | core_unix.filename_unix core_kernel.iobuf postgres_async 5 | core_unix.signal_unix core_unix.sys_unix) 6 | (preprocess 7 | (pps ppx_jane))) 8 | -------------------------------------------------------------------------------- /test/harness.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | module Unix = Core_unix 3 | 4 | type t = 5 | { datadir : string 6 | ; server_pid : Pid.t 7 | ; socket_dir : string 8 | ; port : int 9 | } 10 | 11 | let postgres_bins = 12 | lazy 13 | (let candidates = [ "/usr/pgsql-12/bin"; "/usr/pgsql-13/bin" ] in 14 | List.find candidates ~f:(fun dir -> 15 | match Sys_unix.is_directory dir with 16 | | `Yes -> true 17 | | `No | `Unknown -> false) 18 | |> Option.value_exn ~message:"could not find a postgresql installation") 19 | ;; 20 | 21 | let pg_hba = 22 | [ "# TYPE DATABASE USER ADDRESS METHOD" 23 | ; "local all postgres trust" 24 | ; "host all postgres 127.0.0.1/32 trust" 25 | ; "local all +role_password_login md5" 26 | ; "host all +role_password_login 127.0.0.1/32 md5" 27 | ; "" 28 | ] 29 
| |> String.concat ~sep:"\n" 30 | ;; 31 | 32 | (* unix sockets must be under ~100 characters long, so we create [socket_dir] in /tmp *) 33 | let socket_dir = 34 | lazy (Filename_unix.temp_dir ~in_dir:"/tmp" "postgres-async-test-harness" "") 35 | ;; 36 | 37 | let tempfiles_dir = Filename.temp_dir_name 38 | 39 | let fork_redirect_exec ~prog ~args ~stdouterr_file = 40 | match Unix.fork () with 41 | | `In_the_parent pid -> pid 42 | | `In_the_child -> 43 | Unix.dup2 ~src:stdouterr_file ~dst:Unix.stdout (); 44 | Unix.dup2 ~src:stdouterr_file ~dst:Unix.stderr (); 45 | never_returns (Unix.exec ~prog ~argv:(prog :: args) ()) 46 | ;; 47 | 48 | let create ?(extra_server_args = []) () = 49 | let postgres_output_filename, postgres_output = 50 | Unix.mkstemp (tempfiles_dir ^/ "postgres-output") 51 | in 52 | let datadir = Unix.mkdtemp (tempfiles_dir ^/ "postgres-datadir") in 53 | let get_postgres_output () = In_channel.read_all postgres_output_filename in 54 | (* Ask the OS to assign us a temporary port. *) 55 | let temp_socket = Unix.socket ~domain:PF_INET ~kind:SOCK_STREAM ~protocol:0 () in 56 | Unix.bind temp_socket ~addr:(ADDR_INET (Unix.Inet_addr.of_string "0.0.0.0", 0)); 57 | let port = 58 | match Unix.getsockname temp_socket with 59 | | ADDR_UNIX _ -> assert false 60 | | ADDR_INET (_, port) -> port 61 | in 62 | (* Postgres binds with SO_REUSEADDR. 63 | 64 | In order to allow postgres to use this port, we must set SO_REUSEADDR on 65 | [temp_socket] too. 66 | 67 | By binding the same port on [127.0.0.2] and _then_ switching REUSEADDR on 68 | [temp_socket2] off, it will not be possible to bind to [0.0.0.0:port], since such 69 | a bind will conflict with the [127.0.0.2:port] socket. In particular, another 70 | instance of this test will not re-use the port. But also other processes asking for 71 | [0.0.0.0:ephemeral] will not be given [port]. 72 | 73 | A process asking for [127.0.0.1:ephemeral] without REUSEADDR set will not be handed our port.
74 | 75 | A process asking for an ephemeral port on [127.0.0.1] with REUSEADDR set could 76 | possibly be handed our port. It's very unlikely that any process actually attempts 77 | this. *) 78 | let temp_socket2 = Unix.socket ~domain:PF_INET ~kind:SOCK_STREAM ~protocol:0 () in 79 | Unix.setsockopt temp_socket SO_REUSEADDR true; 80 | Unix.setsockopt temp_socket2 SO_REUSEADDR true; 81 | Unix.bind temp_socket2 ~addr:(ADDR_INET (Unix.Inet_addr.of_string "127.0.0.2", port)); 82 | Unix.setsockopt temp_socket2 SO_REUSEADDR false; 83 | let initdb = 84 | let prog = force postgres_bins ^/ "initdb" in 85 | fork_redirect_exec 86 | ~prog 87 | ~args:[ "-D"; datadir; "-N" (* no sync *); "-U"; "postgres" ] 88 | ~stdouterr_file:postgres_output 89 | in 90 | match Unix.waitpid_exn initdb with 91 | | exception exn -> 92 | print_endline (get_postgres_output ()); 93 | raise exn 94 | | () -> 95 | Out_channel.write_all (datadir ^/ "pg_hba.conf") ~data:pg_hba; 96 | let socket_dir = force socket_dir in 97 | let server_pid = 98 | let prog = force postgres_bins ^/ "postgres" in 99 | let args = 100 | [ "-D" 101 | ; datadir 102 | ; "-c" 103 | ; "listen_addresses=127.0.0.1" 104 | ; "-c" 105 | ; sprintf "port=%i" port 106 | ; "-c" 107 | ; "unix_socket_directories=" ^ socket_dir 108 | ; "-c" 109 | ; "logging_collector=false" (* log to stdout *) 110 | ] 111 | @ extra_server_args 112 | in 113 | fork_redirect_exec ~prog ~args ~stdouterr_file:postgres_output 114 | in 115 | at_exit (fun () -> 116 | (* SIGQUIT = 'immediate shutdown' *) 117 | (match Signal_unix.send Signal.quit (`Pid server_pid) with 118 | | `No_such_process -> eprintf "in at-exit handler, postgres was not alive?\n%!" 119 | | `Ok -> 120 | (match Unix.waitpid_exn server_pid with 121 | | () -> () 122 | | exception exn -> 123 | eprintf !"in at-exit handler, waiting for postgres failed: %{Exn}\n%!" 
exn)); 124 | let pid = 125 | Unix.fork_exec 126 | () 127 | ~prog:"rm" 128 | ~argv:[ "rm"; "-rf"; "--"; datadir; socket_dir; postgres_output_filename ] 129 | in 130 | Unix.waitpid_exn pid); 131 | let rec wait_for_postgres ~timeout = 132 | match timeout < 0 with 133 | | true -> 134 | print_endline (get_postgres_output ()); 135 | failwith "timeout waiting for postgres to start" 136 | | false -> 137 | (match Unix.wait_nohang (`Pid server_pid) with 138 | | Some (_, exit_or_signal) -> 139 | print_endline (In_channel.read_all postgres_output_filename); 140 | raise_s 141 | [%message 142 | "postgres terminated early" (exit_or_signal : Unix.Exit_or_signal.t)] 143 | | None -> 144 | let output = get_postgres_output () in 145 | (match String.is_substring output ~substring:"ready to accept connections" with 146 | | false -> 147 | ignore (Unix.nanosleep 0.1 : float); 148 | wait_for_postgres ~timeout:(timeout - 1) 149 | | true -> ())) 150 | in 151 | wait_for_postgres ~timeout:100; 152 | { server_pid; datadir; socket_dir; port } 153 | ;; 154 | 155 | let create_database { socket_dir; port; _ } name = 156 | let pid = 157 | Unix.fork_exec 158 | () 159 | ~prog:(force postgres_bins ^/ "psql") 160 | ~argv: 161 | [ "psql" 162 | ; "-qX" 163 | ; "-h" 164 | ; socket_dir 165 | ; "-p" 166 | ; Int.to_string port 167 | ; "-U" 168 | ; "postgres" 169 | ; "-d" 170 | ; "postgres" 171 | ; "--set" 172 | ; "ON_ERROR_STOP" 173 | ; "-c" 174 | ; "CREATE DATABASE " ^ name 175 | ] 176 | in 177 | Unix.waitpid_exn pid 178 | ;; 179 | 180 | let pg_hba_filename { datadir; _ } = datadir ^/ "pg_hba.conf" 181 | let pg_ident_filename { datadir; _ } = datadir ^/ "pg_ident.conf" 182 | let sighup_server { server_pid; _ } = Signal_unix.send_exn Signal.hup (`Pid server_pid) 183 | let unix_socket_path { socket_dir; port; _ } = sprintf "%s/.s.PGSQL.%i" socket_dir port 184 | let port { port; _ } = port 185 | 186 | open Async 187 | 188 | let where_to_connect t = Tcp.Where_to_connect.of_file (unix_socket_path t) 189 | 190 | 
let with_connection_exn t ?(user = "postgres") ~database func = 191 | match%bind 192 | Postgres_async.with_connection 193 | ~user 194 | ~server:(where_to_connect t) 195 | ~database 196 | ~on_handler_exception:`Raise 197 | func 198 | with 199 | | Ok () -> return () 200 | | Error err -> Error.raise err 201 | ;; 202 | -------------------------------------------------------------------------------- /test/harness.mli: -------------------------------------------------------------------------------- 1 | open! Core 2 | 3 | type t 4 | 5 | val create : ?extra_server_args:string list -> unit -> t 6 | val create_database : t -> string -> unit 7 | val pg_hba_filename : t -> string 8 | val pg_ident_filename : t -> string 9 | val sighup_server : t -> unit 10 | val unix_socket_path : t -> string 11 | val port : t -> int 12 | 13 | open! Async 14 | 15 | val where_to_connect : t -> Socket.Address.Unix.t Tcp.Where_to_connect.t 16 | 17 | val with_connection_exn 18 | : t 19 | -> ?user:string (* default: the super user, postgres. 
*) 20 | -> database:string 21 | -> (Postgres_async.t -> unit Deferred.t) 22 | -> unit Deferred.t 23 | -------------------------------------------------------------------------------- /test/postgres_async_tests.ml: -------------------------------------------------------------------------------- 1 | module Harness = Harness 2 | module Test_cancellation = Test_cancellation 3 | module Test_connect = Test_connect 4 | module Test_copy_in = Test_copy_in 5 | module Test_copy_out = Test_copy_out 6 | module Test_error_code = Test_error_code 7 | module Test_notify = Test_notify 8 | module Test_protocol_round_trip = Test_protocol_round_trip 9 | module Test_query = Test_query 10 | module Test_runtime_parameters = Test_runtime_parameters 11 | module Test_server_failure = Test_server_failure 12 | module Test_simple_query = Test_simple_query 13 | module Test_smoke = Test_smoke 14 | module Test_ssl = Test_ssl 15 | module Utils = Utils 16 | -------------------------------------------------------------------------------- /test/server-leaf_certificate.crt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/janestreet/postgres_async/55d83400bae826efb90def299675e558bf37420e/test/server-leaf_certificate.crt -------------------------------------------------------------------------------- /test/server-leaf_certificate.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDUjCCAjqgAwIBAgIIObUxzB8SUIYwDQYJKoZIhvcNAQELBQAwHDEaMBgGA1UE 3 | AwwRaW50ZXJtZWRpYXRlLWNhLTEwHhcNMjIwMTA1MTcwNjUzWhcNMzIwMTAzMTcw 4 | NjUzWjAaMRgwFgYDVQQDDA93d3cuZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEB 5 | AQUAA4IBDwAwggEKAoIBAQC1XFfcPdeBIrrNowpCAyrIGQ09OUTU3oQEoA+/dnqt 6 | nWq4AbCPoXsgRydBUNkTni0zPCBdrgvybBkDNBZ4s7vj1TvCpKlyU7buU4PykK0q 7 | cL8IwfdBGx+yWBrZeFvM9ZqQ5PQ5VoFUwwW29n/zF9umH2PCf0EbC/zGNzoQdn5+ 8 | i4pC0flqSKiDYBOvTFARYhUesXmzWFBxNiuILsCEjIj8GGNQlHYsqz9oyfaVr+sz 9 | 
YdwLchfN6FSjRbQI9rcZrmzIxhgW+wfqDx6OOetgIDGkimzWStIiBStyHRXhf8BD 10 | E16K26YIOrPUqMqLQzcolk7V0FzmYJnYYsfjjXwoUNnXAgMBAAGjgZkwgZYwHQYD 11 | VR0OBBYEFA0B2NO7bnptnrvQC0TDf5ExAQDeMA8GA1UdDwEB/wQFAwMHoAAwHQYD 12 | VR0RAQH/BBMwEYIPd3d3LmV4YW1wbGUuY29tMAwGA1UdEwEB/wQCMAAwHwYDVR0j 13 | BBgwFoAUaaoOQLj01naZmdfRSVJTpHmng+gwFgYDVR0lAQH/BAwwCgYIKwYBBQUH 14 | AwEwDQYJKoZIhvcNAQELBQADggEBAKzs9nrSOU36Kw7pskJTNETfORE/a5F8ygtR 15 | bArfGywdQAD+cE8i7bbBfWmC4ee1bwHp7R5XPrtFwII1SbUkxGBWxlWPfrvlTr5r 16 | 2lOpNf8icVbynSPXGFBiJbbbFlXMPH76nvs2hL2Z+VXGMx2u7/+8aOpe35CZFAHE 17 | Im9YA8ZuXuCUE2LqMkzVCMiDlI4mXcn8D2uNCapUu237jW48X4g4WcCxFWZxanbd 18 | k23G9IhRjkx1TYvtF08Ls9Fe2QNmWk6kpiKTJ9U0QYtlX/WDHQeoct5jwXs7TMga 19 | MrRAvLUrdSPvAeWWDCHn1rv8EvjJfVqrduN63ifPoWD4Y2YW1ck= 20 | -----END CERTIFICATE----- 21 | -----BEGIN CERTIFICATE----- 22 | MIIC1TCCAb2gAwIBAgIJAI6auMzwmLkOMA0GCSqGSIb3DQEBCwUAMBIxEDAOBgNV 23 | BAMMB3Rlc3QtY2EwHhcNMjIwMTA1MTcwNjUzWhcNMzIwMTAzMTcwNjUzWjAcMRow 24 | GAYDVQQDDBFpbnRlcm1lZGlhdGUtY2EtMTCCASIwDQYJKoZIhvcNAQEBBQADggEP 25 | ADCCAQoCggEBAOfsnpHv8kTKUDSpkHahmWi7hAeIQBeqpywoSCkusKEuA+jhwc1E 26 | Lw2Ug17LkACe+GhQsYDAqBUs9RhKWZNoQ/CnNjemZPfd0qc1qWTx/acyAnvh2405 27 | on4RAG0x3KZ90ZjCM+aZ7DD71J0nU/MqrzVrUHYJjgczZuojp0J1DUDfPKawTrnk 28 | tg12n0eNV67cPWVl83EElUOc7R63c50B2nyjekc4/XkueDKBWR3SFpM4XbCmG1mq 29 | cHaT+OZoLOmyDTx7/rkliUKEt8NupncdjO2rCOaoirQ56lpB7fP4Qbjen0C72p1t 30 | n+hjaJnkn/RrfR074xMAyaT/sEYtRtA05xECAwEAAaMkMCIwDwYDVR0PAQH/BAUD 31 | AwdGADAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBXKtm1chCz 32 | UwdJVAJ456BU6zpRsoROdwY8knd8qsl+dM4B7TQgr1w/JGkL0wtx5BJFDif09tMI 33 | UFzeb5ThDFFAQsiT996+fU7Cj2CV0NWpeoMNnOa6bZU2xOhNycTXPasdhUjg032U 34 | 3IQ0p3d4wAIz9q3TIVBjvryB01SaceZo9BcuIC1Cqot+cNuZLKZa4XQSxCKkN05i 35 | 7+Hv4ow7xbUrObiFKaMWkFDSJ7PHy/uKX2qlgyE4ttVpCYKWNyizvwncPGGi7UU8 36 | aHkrwrmCUnN8bBBdcLBuXq/xRAhcpWQFZt8B7JmPorzDFHouI/5DEbNWxJbQ1cii 37 | 3jBkJdoAsATi 38 | -----END CERTIFICATE----- 39 | -----BEGIN CERTIFICATE----- 40 | 
MIICyjCCAbKgAwIBAgIIRhLuf2tMDjUwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE 41 | AwwHdGVzdC1jYTAeFw0yMjAxMDUxNzA2NTNaFw0zMjAxMDMxNzA2NTNaMBIxEDAO 42 | BgNVBAMMB3Rlc3QtY2EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD1 43 | gdMx/krRQez7TUWkYQr2mkqwc0pngjxUujtiJBsmHujJQxTV9AjXqetDqOKIrl9y 44 | //jNK3/cNU+TS69/YlsZJNTZ/XMxbSR3xRbWIV4QTrjrIskRX/Cm3fMgha/H0Vpa 45 | AcHek4qZPN1B2PObP47yJDRKvbMotn4nimxaaxiP7GRWRSVWywn42GwIQOkU9O4E 46 | 3GbbI10U3inN/PAPx6Hgkaz/gb4uAOjq/s5YHo/LEgE2EUzP5Qltbtc1awK3trFu 47 | Aiu7Nv1kLDQaa+U7LCArCHeAad2tJpxF+CjbCiXseMxej6CQPPmAAm+t4iLlbuE4 48 | ntWWFsDYT057PdNQ8y8DAgMBAAGjJDAiMA8GA1UdDwEB/wQFAwMHRgAwDwYDVR0T 49 | AQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAiBszuZrGeKxzq1ORQmOlZAJ8 50 | XEfG0T0mBkix135u9t5Aenp/eJb7XUTxWfhHb/VbNumvnFzOFnMTyLwE+iGwgn47 51 | 8eA0TWX9xoxlP1kBeEa8jzjLw9LhPFeZ+PoqfqsGxED9OUYrFZZvSc/ZZBI6Q0GL 52 | DRnnmwLjMSRz1sr7iuy8LBv/QNgVKeIWq/d2MgWUvv1KyhcYKlKCFm60/vLAedvt 53 | WipObqApPYYlOeWYKBWV2YqlKBMEwPRnTwic7hlyUwCXVFdQk+JaPuZZvqX+YHQu 54 | +XhtjPz7ty07FW8zSALQvNb3mHZ9cuU1fxN3Kn+aDh5iA7JRFdqJXLEo0rofAg== 55 | -----END CERTIFICATE----- 56 | 57 | -------------------------------------------------------------------------------- /test/server-leaf_key.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQC1XFfcPdeBIrrN 3 | owpCAyrIGQ09OUTU3oQEoA+/dnqtnWq4AbCPoXsgRydBUNkTni0zPCBdrgvybBkD 4 | NBZ4s7vj1TvCpKlyU7buU4PykK0qcL8IwfdBGx+yWBrZeFvM9ZqQ5PQ5VoFUwwW2 5 | 9n/zF9umH2PCf0EbC/zGNzoQdn5+i4pC0flqSKiDYBOvTFARYhUesXmzWFBxNiuI 6 | LsCEjIj8GGNQlHYsqz9oyfaVr+szYdwLchfN6FSjRbQI9rcZrmzIxhgW+wfqDx6O 7 | OetgIDGkimzWStIiBStyHRXhf8BDE16K26YIOrPUqMqLQzcolk7V0FzmYJnYYsfj 8 | jXwoUNnXAgMBAAECggEADkMIdVbdo0rGX65vFSx0YrW2bB0C5y4abQl83SSyVmqd 9 | ixvPtScEBjbiOBYyA+nTI+2aA/pQXf99lg1nqPurwY/+rS2174b0YnIdDoGtd6I6 10 | VTRGC9yLFdzGbvL/XeZNTqpA9GAm0OVHOUSek9DJlc48Uf3s2KiM97XXGJbtsQJV 11 | 4af1EBD+ntZaOfS2b2yAUfdkUn2KMJYntmugbqJ11MFfOYxMqagUELy6AsYFZlvk 
12 | 1A4wL+25pq2FMDO8YQmq9OUnyV+QA8nypT84TRoNaux0FXf6CCEhT11gcjV05/F8 13 | PPpI0AWLVnkN33gLAwucsZyCLBcr9CFwwsd/B0kr1QKBgQDvmRNoboWP+JN2XSt9 14 | Qosp3YQUsrvO/oxc9eW2YxYwqTmfczAdhnL29me5FAeei03D1T9K/707bROeNyFf 15 | wQVxdBtKXFtug9niC004mwIG94Ukkt/b8PDtPQ6y9kAakSWaEisCrMEusbgq18/M 16 | sHhiTnXsPTI8ZGeYbP4Z/aYYmwKBgQDBxqrGkc5/Gb1i9uliQcf5oa+9/hcAGRSo 17 | nhIeUdvu3fbpgtJoBTPe9dlN3mfFP7Ij0+odxYZN3rU6joZySQunD+8a3/IAn0yO 18 | YLn2//PxkqphVlNJmPE9zNkG7h1XoP2eenZSiz9VxwAf4EzC40fEPiKVY6BxJfcJ 19 | wqLxv10BdQKBgA0Ey0osXzOAdTrEOz22JOukbq3VPGE18ZiHf/DWF3mTaF8imiWw 20 | jYSfxOkIjpVtyk7uwl6n4Lde6Ob65eRXD52nimgS9qDdpzQiGxMNUSHhxylClclU 21 | oTKy056jxL3szxc9D3s4udJ4s6IYUeE0YYYt7zhj5tvjNMHSgkgVQTYrAoGAXP+Y 22 | 5HYD0dIrp3xy49pIPFFSA/AXX8+pr4c1kOGemRRkNQu5KX0duOrq4MlVqj/4oeNJ 23 | oAI1g9fXyIOwmNbfxc0K5y4FejD5z/cyKQ4MKKtIJDEHBfJmDU/r9LyAzpaQQefq 24 | M5Fq0yMPtzyx+nKT9eYQOPw4ezVKS5/jDfzWiSUCgYBDvWIggy1uNh8z0A5Z65d5 25 | 7gWqWy0gVccSKF24NUjtDmrLBhJluk8vYmF64fonahXAH2uZFWx/Svml64oK5JmJ 26 | gVodNrUCZ2Qq5/nt8NStgFHx5pylfEibLEuBpApbaHH8QA7s/aXFWhH47AGrDw5W 27 | u8AS9yo8G1A/XmUPUhccNg== 28 | -----END PRIVATE KEY----- 29 | 30 | -------------------------------------------------------------------------------- /test/test_cancellation.ml: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | open Async 3 | 4 | let () = Dynamic.set_root Backtrace.elide true 5 | let harness = lazy (Harness.create ()) 6 | 7 | let%expect_test ("Demonstrate that cancelling running queries is possible" 8 | [@tags "disabled"]) 9 | = 10 | let harness = force harness in 11 | let user = "postgres" in 12 | let connect () = 13 | Postgres_async.connect 14 | () 15 | ~server:(Harness.where_to_connect harness) 16 | ~user 17 | ~database:"postgres" 18 | >>| Or_error.ok_exn 19 | in 20 | let%bind connection = connect () in 21 | let%bind activity_connection = connect () in 22 | let query_to_cancel = 23 | Postgres_async.query 24 | connection 25 | "select 1 from pg_sleep(100)" 26 | ~handle_row:(fun ~column_names ~values -> 27 | raise_s 28 | [%message 29 | "Unexpectedly produced rows" 30 | (column_names : string iarray) 31 | (values : string option iarray)]) 32 | in 33 | let%bind () = 34 | Deferred.repeat_until_finished () (fun () -> 35 | let count = ref 0 in 36 | let%map () = 37 | Postgres_async.query 38 | activity_connection 39 | [%string 40 | "select * from pg_stat_activity where state = 'active' and query like \ 41 | '%pg_sleep%'"] 42 | ~handle_row:(fun ~column_names:_ ~values:_ -> incr count) 43 | >>| Or_error.ok_exn 44 | in 45 | match Int.equal !count 2 with 46 | | true -> `Finished () 47 | | false -> `Repeat ()) 48 | in 49 | let%bind cancel_result = Postgres_async.Private.pq_cancel connection in 50 | let%bind query_to_cancel = Clock_ns.with_timeout Time_ns.Span.second query_to_cancel in 51 | let%bind begin_result = Postgres_async.query_expect_no_data connection "BEGIN" in 52 | let%bind close_result = Postgres_async.close connection in 53 | print_s 54 | [%message 55 | (query_to_cancel : unit Or_error.t Clock_ns.Or_timeout.t) 56 | (cancel_result : unit Or_error.t) 57 | (begin_result : unit Or_error.t) 58 | (close_result : unit Or_error.t)]; 59 | [%expect 60 | {| 61 | ((query_to_cancel 62 | (Result 63 | (Error 64 | ("Error during query execution (despite parsing ok)" 65 | ((Code 
57014) (Message "canceling statement due to user request")))))) 66 | (cancel_result (Ok ())) (begin_result (Ok ())) (close_result (Ok ()))) 67 | |}]; 68 | [%expect {| |}]; 69 | return () 70 | ;; 71 | 72 | let%expect_test "Parse full backend key range" = 73 | let secret_sent = Int32.max_value |> Int32.to_int_exn in 74 | let pid_sent = 9999 in 75 | print_s [%message (pid_sent : int) (secret_sent : int)]; 76 | let buf = Iobuf.create ~len:8 in 77 | Iobuf.Fill.int32_be_trunc buf pid_sent; 78 | Iobuf.Fill.int32_be_trunc buf secret_sent; 79 | Iobuf.flip_lo buf; 80 | let backend_key_data = 81 | Postgres_async.Private.Protocol.Backend.BackendKeyData.consume buf |> Or_error.ok_exn 82 | in 83 | print_s [%sexp (backend_key_data : Postgres_async.Private.Protocol.Types.backend_key)]; 84 | [%expect 85 | {| 86 | ((pid_sent 9999) (secret_sent 2147483647)) 87 | ((pid 9999) (secret 2147483647)) 88 | |}]; 89 | Deferred.unit 90 | ;; 91 | -------------------------------------------------------------------------------- /test/test_cancellation.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_connect.ml: -------------------------------------------------------------------------------- 1 | open! 
Core 2 | open Async 3 | 4 | let () = Dynamic.set_root Backtrace.elide true 5 | let harness = lazy (Harness.create ()) 6 | 7 | let with_dummy_server func = 8 | let got_connection = Ivar.create () in 9 | let%bind server = 10 | Tcp.Server.create 11 | ~on_handler_error:`Raise 12 | Tcp.Where_to_listen.of_port_chosen_by_os 13 | (fun _ reader _ -> 14 | let%bind _ = Reader.peek reader ~len:1 in 15 | Ivar.fill_exn got_connection (); 16 | Reader.drain reader) 17 | in 18 | let got_connection = Ivar.read got_connection in 19 | let addr = 20 | Host_and_port.create ~host:"127.0.0.1" ~port:(Tcp.Server.listening_on server) 21 | in 22 | let finally () = Tcp.Server.close server in 23 | Monitor.protect ~run:`Now ~rest:`Raise ~finally (fun () -> func addr ~got_connection) 24 | ;; 25 | 26 | let%expect_test "interrupt - TCP" = 27 | let%bind () = 28 | with_dummy_server (fun addr ~got_connection:_ -> 29 | (* This is technically racey, because the connect call could in principle complete 30 | before the [choice] picks the interrupt. But it doesn't. *) 31 | match%bind 32 | Postgres_async.with_connection 33 | ~server:(Tcp.Where_to_connect.of_host_and_port addr) 34 | ~interrupt:(return ()) 35 | ~database:"dummy" 36 | ~on_handler_exception:`Raise 37 | (fun _ -> failwith "connection succeeded!?") 38 | with 39 | | Ok (_ : Nothing.t) -> . 40 | | Error err -> 41 | print_s [%sexp (err : Error.t)]; 42 | [%expect 43 | {| 44 | (monitor.ml.Error ("connection attempt aborted" 127.0.0.1:PORT) 45 | ("" "Caught by monitor try_with_or_error")) 46 | |}]; 47 | return ()) 48 | in 49 | let%bind () = 50 | with_dummy_server (fun addr ~got_connection -> 51 | match%bind 52 | Postgres_async.with_connection 53 | ~server:(Tcp.Where_to_connect.of_host_and_port addr) 54 | ~interrupt:got_connection 55 | ~database:"dummy" 56 | ~on_handler_exception:`Raise 57 | (fun _ -> failwith "connection succeeded!?") 58 | with 59 | | Ok (_ : Nothing.t) -> . 
60 | | Error err -> 61 | print_s [%sexp (err : Error.t)]; 62 | [%expect {| "login interrupted" |}]; 63 | return ()) 64 | in 65 | return () 66 | ;; 67 | 68 | let%expect_test "interrupt - SSL" = 69 | let%bind () = 70 | with_dummy_server (fun addr ~got_connection:_ -> 71 | (* This is technically racey, because the connect call could in principle complete 72 | before the [choice] picks the interrupt. But it doesn't. *) 73 | match%bind 74 | Postgres_async.with_connection 75 | ~server:(Tcp.Where_to_connect.of_host_and_port addr) 76 | ~ssl_mode:Prefer 77 | ~interrupt:(return ()) 78 | ~database:"dummy" 79 | ~on_handler_exception:`Raise 80 | (fun _ -> failwith "connection succeeded!?") 81 | with 82 | | Ok (_ : Nothing.t) -> . 83 | | Error err -> 84 | print_s [%sexp (err : Error.t)]; 85 | [%expect 86 | {| 87 | (monitor.ml.Error ("connection attempt aborted" 127.0.0.1:PORT) 88 | ("" "Caught by monitor try_with_or_error")) 89 | |}]; 90 | return ()) 91 | in 92 | let%bind () = 93 | with_dummy_server (fun addr ~got_connection -> 94 | match%bind 95 | Postgres_async.with_connection 96 | ~server:(Tcp.Where_to_connect.of_host_and_port addr) 97 | ~ssl_mode:Prefer 98 | ~interrupt:got_connection 99 | ~database:"dummy" 100 | ~on_handler_exception:`Raise 101 | (fun _ -> failwith "connection succeeded!?") 102 | with 103 | | Ok (_ : Nothing.t) -> . 
104 | | Error err -> 105 | print_s [%sexp (err : Error.t)]; 106 | [%expect {| "ssl negotiation interrupted" |}]; 107 | return ()) 108 | in 109 | return () 110 | ;; 111 | 112 | let try_login ?(user = "postgres") ?password ?(database = "postgres") harness = 113 | let get_user postgres = 114 | let u_d = Set_once.create () in 115 | let%bind result = 116 | Postgres_async.query 117 | postgres 118 | "SELECT CURRENT_USER, current_database()" 119 | ~handle_row:(fun ~column_names:_ ~values -> 120 | match Iarray.to_array values with 121 | | [| Some u; Some d |] -> Set_once.set_exn u_d (u, d) 122 | | _ -> failwith "bad query response") 123 | in 124 | Or_error.ok_exn result; 125 | return (Set_once.get_exn u_d) 126 | in 127 | let%bind result = 128 | Postgres_async.with_connection 129 | ~server:(Harness.where_to_connect harness) 130 | ~user 131 | ?password 132 | ~database 133 | ~on_handler_exception:`Raise 134 | get_user 135 | in 136 | (* we can't print any more than "login failed" because the error messages are not stable 137 | wrt. postgres versions. 
*) 138 | (match result with 139 | | Ok (u, d) -> printf "OK; user:%s database:%s\n" u d 140 | | Error _ -> printf "Login failed\n"); 141 | return () 142 | ;; 143 | 144 | let%expect_test "trust auth" = 145 | let harness = force harness in 146 | let%bind () = try_login harness in 147 | [%expect {| OK; user:postgres database:postgres |}]; 148 | let%bind () = try_login harness ~password:"nonsense" in 149 | [%expect {| OK; user:postgres database:postgres |}]; 150 | return () 151 | ;; 152 | 153 | let%expect_test "password authentication" = 154 | let harness = force harness in 155 | let%bind () = 156 | Harness.with_connection_exn harness ~database:"postgres" (fun postgres -> 157 | let q str = 158 | let%bind r = Postgres_async.query_expect_no_data postgres str in 159 | Or_error.ok_exn r; 160 | return () 161 | in 162 | let%bind () = q "CREATE ROLE role_password_login" in 163 | let%bind () = q "CREATE ROLE auth_test_1 LOGIN PASSWORD 'test-password'" in 164 | let%bind () = q "GRANT role_password_login TO auth_test_1" in 165 | return ()) 166 | in 167 | [%expect {| |}]; 168 | let%bind () = try_login harness ~user:"auth_test_1" ~password:"test-password" in 169 | [%expect {| OK; user:auth_test_1 database:postgres |}]; 170 | let%bind () = try_login harness ~user:"auth_test_1" ~password:"bad" in 171 | [%expect {| Login failed |}]; 172 | return () 173 | ;; 174 | 175 | let%expect_test "unix sockets & tcp" = 176 | let harness = force harness in 177 | let%bind result = 178 | Postgres_async.with_connection 179 | ~server:(Tcp.Where_to_connect.of_file (Harness.unix_socket_path harness)) 180 | ~user:"postgres" 181 | ~database:"postgres" 182 | ~on_handler_exception:`Raise 183 | (fun _ -> return ()) 184 | in 185 | Or_error.ok_exn result; 186 | let%bind result = 187 | let hap = Host_and_port.create ~host:"127.0.0.1" ~port:(Harness.port harness) in 188 | Postgres_async.with_connection 189 | ~server:(Tcp.Where_to_connect.of_host_and_port hap) 190 | ~user:"postgres" 191 | ~database:"postgres" 
192 | ~on_handler_exception:`Raise 193 | (fun _ -> return ()) 194 | in 195 | Or_error.ok_exn result; 196 | return () 197 | ;; 198 | 199 | let%expect_test "authentication method we don't support" = 200 | let%bind tcp_server = 201 | let handle_client _sock _reader writer = 202 | Writer.write writer "R\x00\x00\x00\x08\x00\x00\x00\x06"; 203 | Deferred.never () 204 | in 205 | Tcp.Server.create 206 | ~on_handler_error:`Raise 207 | Tcp.Where_to_listen.of_port_chosen_by_os 208 | handle_client 209 | in 210 | let where_to_connect = 211 | let port = Tcp.Server.listening_on tcp_server in 212 | Tcp.Where_to_connect.of_host_and_port (Host_and_port.create ~host:"localhost" ~port) 213 | in 214 | let%bind result = 215 | Postgres_async.with_connection 216 | ~server:where_to_connect 217 | ~user:"postgres" 218 | ~database:"postgres" 219 | ~on_handler_exception:`Raise 220 | (fun _ -> return ()) 221 | in 222 | print_s [%sexp (result : _ Or_error.t)]; 223 | [%expect {| (Error "Server wants unimplemented auth subtype: SCMCredential") |}]; 224 | return () 225 | ;; 226 | 227 | let%expect_test "connection refused" = 228 | (* bind, but don't listen or accept. 
*) 229 | let socket = Core_unix.socket ~domain:PF_INET ~kind:SOCK_STREAM ~protocol:0 () in 230 | Core_unix.bind socket ~addr:(ADDR_INET (Unix.Inet_addr.of_string "0.0.0.0", 0)); 231 | let where_to_connect = 232 | match Core_unix.getsockname socket with 233 | | ADDR_UNIX _ -> assert false 234 | | ADDR_INET (_, port) -> 235 | Tcp.Where_to_connect.of_host_and_port (Host_and_port.create ~host:"localhost" ~port) 236 | in 237 | let%bind result = 238 | Postgres_async.with_connection 239 | ~server:where_to_connect 240 | ~user:"postgres" 241 | ~database:"postgres" 242 | ~on_handler_exception:`Raise 243 | (fun _ -> return ()) 244 | in 245 | Core_unix.close socket; 246 | print_s [%sexp (result : _ Or_error.t)]; 247 | [%expect 248 | {| 249 | (Error 250 | (monitor.ml.Error 251 | (Unix.Unix_error "Connection refused" connect 127.0.0.1:PORT) 252 | ("" "Caught by monitor Tcp.close_sock_on_error"))) 253 | |}]; 254 | return () 255 | ;; 256 | 257 | let%expect_test "graceful close" = 258 | let harness = force harness in 259 | let%bind connection = 260 | Postgres_async.connect 261 | () 262 | ~server:(Harness.where_to_connect harness) 263 | ~user:"postgres" 264 | ~database:"postgres" 265 | in 266 | let connection = Or_error.ok_exn connection in 267 | let%bind result = Postgres_async.close connection in 268 | print_s [%sexp (result : unit Or_error.t)]; 269 | [%expect {| (Ok ()) |}]; 270 | return () 271 | ;; 272 | -------------------------------------------------------------------------------- /test/test_connect.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. 
*) 2 | -------------------------------------------------------------------------------- /test/test_copy_in.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open Async 3 | 4 | let with_connection_exn = 5 | let database = "test_copy_in" in 6 | let harness = 7 | lazy 8 | (let h = Harness.create () in 9 | Harness.create_database h database; 10 | h) 11 | in 12 | fun func -> Harness.with_connection_exn (force harness) ~database func 13 | ;; 14 | 15 | let create_table postgres name columns = 16 | let%bind result = 17 | Postgres_async.query_expect_no_data 18 | postgres 19 | (sprintf "CREATE TEMPORARY TABLE %s ( %s )" name (String.concat ~sep:"," columns)) 20 | in 21 | Or_error.ok_exn result; 22 | return () 23 | ;; 24 | 25 | let print_table postgres table = 26 | let%bind result = 27 | Postgres_async.query 28 | postgres 29 | (sprintf "SELECT * FROM %s ORDER BY y" table) 30 | ~handle_row:(fun ~column_names:_ ~values -> 31 | print_s [%sexp (values : string option iarray)]) 32 | in 33 | Or_error.ok_exn result; 34 | return () 35 | ;; 36 | 37 | let%expect_test "copy_in_rows" = 38 | with_connection_exn (fun postgres -> 39 | let%bind () = create_table postgres "x" [ "y integer primary key"; "z text" ] in 40 | [%expect {| |}]; 41 | let%bind result = 42 | let rows = 43 | Queue.of_list 44 | [ [| Some "one"; Some "1" |] 45 | ; [| None; Some "2" |] 46 | ; [| Some "three"; Some "3" |] 47 | ] 48 | in 49 | Postgres_async.copy_in_rows 50 | postgres 51 | ~table_name:"x" 52 | ~column_names:[ "z"; "y" ] 53 | ~feed_data:(fun () -> 54 | match Queue.dequeue rows with 55 | | None -> Finished 56 | | Some c -> Data c) 57 | in 58 | Or_error.ok_exn result; 59 | [%expect {| |}]; 60 | let%bind () = print_table postgres "x" in 61 | [%expect 62 | {| 63 | ((1) (one)) 64 | ((2) ()) 65 | ((3) (three)) 66 | |}]; 67 | return ()) 68 | ;; 69 | 70 | let%expect_test "copy_in_rows, schema" = 71 | with_connection_exn (fun postgres -> 72 | let%bind () = 73 | 
create_table postgres {|pg_temp."test.table"|} [ "y integer primary key"; "z text" ] 74 | in 75 | [%expect {| |}]; 76 | let%bind result = 77 | let rows = 78 | Queue.of_list 79 | [ [| Some "one"; Some "1" |] 80 | ; [| None; Some "2" |] 81 | ; [| Some "three"; Some "3" |] 82 | ] 83 | in 84 | Postgres_async.copy_in_rows 85 | postgres 86 | ~schema_name:"pg_temp" 87 | ~table_name:{|"test.table"|} 88 | ~column_names:[ "z"; "y" ] 89 | ~feed_data:(fun () -> 90 | match Queue.dequeue rows with 91 | | None -> Finished 92 | | Some c -> Data c) 93 | in 94 | Or_error.ok_exn result; 95 | [%expect {| |}]; 96 | let%bind () = print_table postgres {|pg_temp."test.table"|} in 97 | [%expect 98 | {| 99 | ((1) (one)) 100 | ((2) ()) 101 | ((3) (three)) 102 | |}]; 103 | return ()) 104 | ;; 105 | 106 | let%expect_test "copy_in_rows: nasty characters" = 107 | with_connection_exn (fun postgres -> 108 | let%bind () = 109 | create_table postgres "x" [ "y integer primary key"; "z text"; "w bytea" ] 110 | in 111 | [%expect {| |}]; 112 | let%bind result = 113 | let rows = 114 | Queue.of_list 115 | [ [| Some "1"; Some "\n"; None |] 116 | ; [| Some "2"; Some "\\N"; None |] 117 | ; [| Some "3"; Some "\t"; None |] 118 | ; [| Some "4"; Some "\\t"; None |] 119 | ; [| Some "5"; Some ","; None |] 120 | ; [| Some "6"; Some ""; None |] 121 | ; [| Some "7"; Some "\\"; None |] 122 | ; [| Some "8"; Some "\\x61"; None |] 123 | ; [| Some "9"; Some ""; None |] 124 | ; [| Some "10"; Some "\x00"; None |] 125 | ; [| Some "11"; None; Some "asdf" |] 126 | ; [| Some "12"; None; Some "\n" |] 127 | ; [| Some "13"; None; Some "\\x00" |] 128 | ; [| Some "14"; None; Some "\\x61" |] 129 | ] 130 | in 131 | Postgres_async.copy_in_rows 132 | postgres 133 | ~table_name:"x" 134 | ~column_names:[ "y"; "z"; "w" ] 135 | ~feed_data:(fun () -> 136 | match Queue.dequeue rows with 137 | | None -> Finished 138 | | Some c -> Data c) 139 | in 140 | Or_error.ok_exn result; 141 | [%expect {| |}]; 142 | let%bind () = print_table postgres "x" 
in 143 | [%expect 144 | {| 145 | ((1) ("\n") ()) 146 | ((2) ("\\N") ()) 147 | ((3) ("\t") ()) 148 | ((4) ("\\t") ()) 149 | ((5) (,) ()) 150 | ((6) ("") ()) 151 | ((7) ("\\") ()) 152 | ((8) ("\\x61") ()) 153 | ((9) ("") ()) 154 | ((10) ("") ()) 155 | ((11) () ("\\x61736466")) 156 | ((12) () ("\\x0a")) 157 | ((13) () ("\\x00")) 158 | ((14) () ("\\x61")) 159 | |}]; 160 | return ()) 161 | ;; 162 | 163 | let%expect_test "copy_in_rows: nasty column names" = 164 | with_connection_exn (fun postgres -> 165 | let%bind result = 166 | (* year is a keyword and must be quoted *) 167 | Postgres_async.query_expect_no_data 168 | postgres 169 | {| 170 | CREATE TEMPORARY TABLE "table-name " ( 171 | k integer primary key, 172 | "y space" text, 173 | "z""quote" text, 174 | "year" text, 175 | LOWERCASE1 text, 176 | "UPPERCASE2" text 177 | ) 178 | |} 179 | in 180 | Or_error.ok_exn result; 181 | [%expect {| |}]; 182 | let%bind result = 183 | let sent_row = ref false in 184 | Postgres_async.copy_in_rows 185 | postgres 186 | ~table_name:{|"table-name "|} 187 | ~column_names: 188 | [ "k"; {|"y space"|}; {|"z""quote"|}; "year"; "lowercase1"; {|"UPPERCASE2"|} ] 189 | ~feed_data:(fun () -> 190 | match !sent_row with 191 | | true -> Finished 192 | | false -> 193 | sent_row := true; 194 | Data [| Some "1"; Some "A"; Some "B"; Some "C"; Some "D"; Some "E" |]) 195 | in 196 | Or_error.ok_exn result; 197 | [%expect {| |}]; 198 | let%bind result = 199 | Postgres_async.query 200 | postgres 201 | {| SELECT * FROM "table-name " ORDER BY k |} 202 | ~handle_row:(fun ~column_names ~values -> 203 | Iarray.iter (Iarray.zip_exn column_names values) ~f:(fun (k, v) -> 204 | print_s [%sexp (k : string), (v : string option)])) 205 | in 206 | Or_error.ok_exn result; 207 | [%expect 208 | {| 209 | (k (1)) 210 | ("y space" (A)) 211 | ("z\"quote" (B)) 212 | (year (C)) 213 | (lowercase1 (D)) 214 | (UPPERCASE2 (E)) 215 | |}]; 216 | return ()) 217 | ;; 218 | 219 | let%expect_test "copy_in_rows: lots of data" = 220 | 
with_connection_exn (fun postgres -> 221 | let%bind () = create_table postgres "x" [ "y integer primary key"; "z text" ] in 222 | [%expect {| |}]; 223 | let%bind result = 224 | let counter = ref 0 in 225 | let sleeps = ref 0 in 226 | let one_kb = String.init 1024 ~f:(const 'a') in 227 | Postgres_async.copy_in_rows 228 | postgres 229 | ~table_name:"x" 230 | ~column_names:[ "y"; "z" ] 231 | ~feed_data:(fun () -> 232 | match !counter >= 8192 with 233 | | true -> Finished 234 | | false -> 235 | (match !counter / 256 < !sleeps with 236 | | true -> 237 | incr sleeps; 238 | Wait ((force Utils.do_an_epoll) ()) 239 | | false -> 240 | incr counter; 241 | Data [| Some (Int.to_string !counter); Some one_kb |])) 242 | in 243 | Or_error.ok_exn result; 244 | [%expect {| |}]; 245 | let%bind result = 246 | Postgres_async.query 247 | postgres 248 | {| SELECT COUNT(*), MIN(y), MAX(y), SUM(y), MIN(LENGTH(z)) FROM x |} 249 | ~handle_row:(fun ~column_names:_ ~values -> 250 | let values = Iarray.map ~f:(fun x -> Option.value_exn x) values in 251 | print_s [%sexp (values : string iarray)]) 252 | in 253 | Or_error.ok_exn result; 254 | [%expect {| (8192 1 8192 33558528 1024) |}]; 255 | print_s [%sexp (8192 * 8193 / 2 : int)]; 256 | [%expect {| 33558528 |}]; 257 | return ()) 258 | ;; 259 | 260 | let%expect_test "raw: weird chunking" = 261 | with_connection_exn (fun postgres -> 262 | let%bind () = 263 | create_table postgres "x" [ "y integer primary key"; "z text not null" ] 264 | in 265 | [%expect {| |}]; 266 | (* weird chunking is fine, it doesn't need to correspond to rows: *) 267 | let%bind result = 268 | let chunks = Queue.of_list [ "1\tone\n"; "2\t"; "two"; "\n" ] in 269 | Postgres_async.copy_in_raw postgres "COPY x (y, z) FROM STDIN" ~feed_data:(fun () -> 270 | match Queue.dequeue chunks with 271 | | None -> Finished 272 | | Some c -> Data c) 273 | in 274 | Or_error.ok_exn result; 275 | [%expect {| |}]; 276 | let%bind () = print_table postgres "x" in 277 | [%expect 278 | {| 279 | ((1) 
(one)) 280 | ((2) (two)) 281 | |}]; 282 | return ()) 283 | ;; 284 | 285 | let%expect_test "aborting" = 286 | with_connection_exn (fun postgres -> 287 | let%bind () = 288 | create_table postgres "x" [ "y integer primary key"; "z text not null" ] 289 | in 290 | [%expect {| |}]; 291 | let%bind result = 292 | let count = ref 0 in 293 | Postgres_async.copy_in_raw postgres "COPY x (y, z) FROM STDIN" ~feed_data:(fun () -> 294 | incr count; 295 | match !count with 296 | | 1 -> Data "1\tone\n" 297 | | 2 -> Wait (Scheduler.yield_until_no_jobs_remain ()) 298 | | 3 -> Data "2\ttwo\n" 299 | | 4 -> Abort { reason = "user reason" } 300 | | _ -> assert false) 301 | in 302 | (match result with 303 | | Ok () -> failwith "succeeded!?" 304 | | Error err -> 305 | let err = Utils.delete_unstable_bits_of_error [%sexp (err : Error.t)] in 306 | print_s err); 307 | (* 57014: query_cancelled. *) 308 | [%expect {| ((query "COPY x (y, z) FROM STDIN") ((Code 57014))) |}]; 309 | let%bind () = print_table postgres "x" in 310 | [%expect {| |}]; 311 | return ()) 312 | ;; 313 | 314 | let%expect_test "copy_in_rows with schema prefix" = 315 | with_connection_exn (fun postgres -> 316 | let%bind () = 317 | Postgres_async.query_expect_no_data postgres "CREATE SCHEMA my_schema" 318 | >>| Or_error.ok_exn 319 | in 320 | let%bind () = 321 | Postgres_async.query_expect_no_data postgres "CREATE TABLE my_schema.x (y integer)" 322 | >>| Or_error.ok_exn 323 | in 324 | [%expect {| |}]; 325 | let%bind result = 326 | let rows = Queue.of_list [ [| Some "1" |]; [| Some "2" |] ] in 327 | Postgres_async.copy_in_rows 328 | postgres 329 | ~table_name:"my_schema.x" 330 | ~column_names:[ "y" ] 331 | ~feed_data:(fun () -> 332 | match Queue.dequeue rows with 333 | | None -> Finished 334 | | Some c -> Data c) 335 | in 336 | Or_error.ok_exn result; 337 | [%expect {| |}]; 338 | let%bind () = 339 | Postgres_async.query 340 | postgres 341 | "SELECT * FROM my_schema.x ORDER BY y" 342 | ~handle_row:(fun ~column_names:_ ~values -> 
343 | print_s [%sexp (values : string option iarray)]) 344 | >>| Or_error.ok_exn 345 | in 346 | [%expect 347 | {| 348 | ((1)) 349 | ((2)) 350 | |}]; 351 | return ()) 352 | ;; 353 | 354 | let%expect_test "error handling" = 355 | with_connection_exn (fun postgres -> 356 | let%bind () = 357 | create_table postgres "x" [ "y integer primary key"; "z text not null" ] 358 | in 359 | [%expect {| |}]; 360 | let%bind result = 361 | let count = ref 0 in 362 | Postgres_async.copy_in_raw postgres "COPY x (y, z) FROM STDIN" ~feed_data:(fun () -> 363 | incr count; 364 | (* Some of the error handling happens after we have done waiting for the Wait 365 | message, so let's make sure we generate plenty of them and intersperse an error 366 | in between *) 367 | match !count with 368 | | 900 -> Data "not-a-number\ttwo\n" 369 | | n -> 370 | if n mod 2 = 0 371 | then Wait (Scheduler.yield_until_no_jobs_remain ()) 372 | else if n < 1000 373 | then Data (Int.to_string n ^ "\ttrue\n") 374 | else Finished) 375 | in 376 | (match result with 377 | | Ok () -> failwith "succeeded!?" 378 | | Error err -> 379 | let err = Utils.delete_unstable_bits_of_error [%sexp (err : Error.t)] in 380 | print_s err); 381 | (* 22P02: invalid syntax for integer. *) 382 | [%expect {| ((query "COPY x (y, z) FROM STDIN") ((Code 22P02))) |}]; 383 | let%bind () = print_table postgres "x" in 384 | [%expect {| |}]; 385 | return ()) 386 | ;; 387 | -------------------------------------------------------------------------------- /test/test_copy_in.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_copy_out.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! 
Async 3 | 4 | let setup = 5 | [ "CREATE EXTENSION IF NOT EXISTS hstore" 6 | ; "BEGIN" 7 | ; {| 8 | CREATE TABLE example_datatypes ( 9 | "null" text, 10 | "bit" bit, 11 | "bit varying" bit varying, 12 | "char" char, 13 | "int2" int2, 14 | "int4" int4, 15 | "int8" int8, 16 | "float4" float4, 17 | "float8" float8, 18 | "uuid" uuid, 19 | "numeric" numeric, 20 | "array" integer[], 21 | "cidr" cidr, 22 | "macaddr" macaddr, 23 | "jsonb" jsonb, 24 | "hstore" hstore, 25 | "text" text, 26 | "bytea" bytea, 27 | "varchar" varchar(20) 28 | )|} 29 | ; {| 30 | INSERT INTO example_datatypes ( 31 | "null", "bit", "bit varying", "char", "int2", "int4", "int8", "float4", "float8", 32 | "uuid", "numeric", "array", "cidr", "macaddr", "jsonb", "hstore", 33 | "text", "bytea", "varchar" 34 | ) VALUES ( 35 | NULL, 36 | B'1', 37 | B'11010001000000010000000000000001000000000000000000000000000000010', 38 | 'A', 39 | 32767, 40 | 2147483647, 41 | 9223372036854775807, 42 | 2.7182817, 43 | 3.14159265358979::float8, 44 | '123e4567-e89b-12d3-a456-426614174000', 45 | 123.45, 46 | ARRAY[1, 2, 3], 47 | '192.168.100.128/25', 48 | '08:00:2b:01:02:03', 49 | '{"key": "value", "nested": {"array": [1, 2, 3]}}'::jsonb, 50 | '"key1"=>"value1", "key2"=>"value2"'::hstore, 51 | 'This is a text field', 52 | E'\\x00DEADBEEF00'::bytea, 53 | 'VARCHAR example' 54 | ) 55 | |} 56 | ; "COMMIT" 57 | ] 58 | ;; 59 | 60 | let with_connection_exn = 61 | let database = "test_copy_out" in 62 | let harness = 63 | lazy 64 | (let h = Harness.create () in 65 | Harness.create_database h database; 66 | h) 67 | in 68 | fun func -> Harness.with_connection_exn (force harness) ~database func 69 | ;; 70 | 71 | let copy_out postgres query_string = 72 | let messages = Queue.create () in 73 | let%map command_complete = 74 | Postgres_async.Private.iter_copy_out postgres ~query_string ~f:(fun iobuf -> 75 | Queue.enqueue messages (Iobuf.Consume.stringo iobuf); 76 | return ()) 77 | >>| Postgres_async.Or_pgasync_error.ok_exn 78 | in 79 | print_s 
[%sexp (command_complete : Postgres_async.Command_complete.t)]; 80 | messages 81 | ;; 82 | 83 | let%expect_test "setup" = 84 | with_connection_exn (fun postgres -> 85 | Deferred.List.iter ~how:`Sequential setup ~f:(fun command -> 86 | Postgres_async.query_expect_no_data postgres command >>| ok_exn)) 87 | ;; 88 | 89 | let%expect_test "empty" = 90 | with_connection_exn (fun postgres -> 91 | let%bind messages = copy_out postgres "COPY (SELECT) TO STDOUT" in 92 | [%expect {| ((tag "COPY 1") (rows ())) |}]; 93 | Queue.iter messages ~f:(fun m -> print_endline (String.Hexdump.to_string_hum m)); 94 | [%expect {| 00000000 0a |.| |}]; 95 | let%bind messages = copy_out postgres "COPY (SELECT) TO STDOUT (FORMAT binary)" in 96 | [%expect {| ((tag "COPY 1") (rows ())) |}]; 97 | Queue.iter messages ~f:(fun m -> print_endline (String.Hexdump.to_string_hum m)); 98 | [%expect 99 | {| 100 | 00000000 50 47 43 4f 50 59 0a ff 0d 0a 00 00 00 00 00 00 |PGCOPY..........| 101 | 00000010 00 00 00 00 00 |.....| 102 | 00000000 ff ff |..| 103 | |}]; 104 | return ()) 105 | ;; 106 | 107 | let%expect_test "example" = 108 | with_connection_exn (fun postgres -> 109 | let%bind messages = copy_out postgres "COPY example_datatypes TO STDOUT" in 110 | [%expect {| ((tag "COPY 1") (rows ())) |}]; 111 | Queue.iter messages ~f:(fun m -> print_endline (String.Hexdump.to_string_hum m)); 112 | [%expect 113 | {xxx| 114 | 00000000 5c 4e 09 31 09 31 31 30 31 30 30 30 31 30 30 30 |\N.1.11010001000| 115 | 00000010 30 30 30 30 31 30 30 30 30 30 30 30 30 30 30 30 |0000100000000000| 116 | 00000020 30 30 30 30 31 30 30 30 30 30 30 30 30 30 30 30 |0000100000000000| 117 | 00000030 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 |0000000000000000| 118 | 00000040 30 30 30 30 31 30 09 41 09 33 32 37 36 37 09 32 |000010.A.32767.2| 119 | 00000050 31 34 37 34 38 33 36 34 37 09 39 32 32 33 33 37 |147483647.922337| 120 | 00000060 32 30 33 36 38 35 34 37 37 35 38 30 37 09 32 2e |2036854775807.2.| 121 | 00000070 37 31 38 32 38 31 
37 09 33 2e 31 34 31 35 39 32 |7182817.3.141592| 122 | 00000080 36 35 33 35 38 39 37 39 09 31 32 33 65 34 35 36 |65358979.123e456| 123 | 00000090 37 2d 65 38 39 62 2d 31 32 64 33 2d 61 34 35 36 |7-e89b-12d3-a456| 124 | 000000a0 2d 34 32 36 36 31 34 31 37 34 30 30 30 09 31 32 |-426614174000.12| 125 | 000000b0 33 2e 34 35 09 7b 31 2c 32 2c 33 7d 09 31 39 32 |3.45.{1,2,3}.192| 126 | 000000c0 2e 31 36 38 2e 31 30 30 2e 31 32 38 2f 32 35 09 |.168.100.128/25.| 127 | 000000d0 30 38 3a 30 30 3a 32 62 3a 30 31 3a 30 32 3a 30 |08:00:2b:01:02:0| 128 | 000000e0 33 09 7b 22 6b 65 79 22 3a 20 22 76 61 6c 75 65 |3.{"key": "value| 129 | 000000f0 22 2c 20 22 6e 65 73 74 65 64 22 3a 20 7b 22 61 |", "nested": {"a| 130 | 00000100 72 72 61 79 22 3a 20 5b 31 2c 20 32 2c 20 33 5d |rray": [1, 2, 3]| 131 | 00000110 7d 7d 09 22 6b 65 79 31 22 3d 3e 22 76 61 6c 75 |}}."key1"=>"valu| 132 | 00000120 65 31 22 2c 20 22 6b 65 79 32 22 3d 3e 22 76 61 |e1", "key2"=>"va| 133 | 00000130 6c 75 65 32 22 09 54 68 69 73 20 69 73 20 61 20 |lue2".This is a | 134 | 00000140 74 65 78 74 20 66 69 65 6c 64 09 5c 5c 78 30 30 |text field.\\x00| 135 | 00000150 64 65 61 64 62 65 65 66 30 30 09 56 41 52 43 48 |deadbeef00.VARCH| 136 | 00000160 41 52 20 65 78 61 6d 70 6c 65 0a |AR example.| 137 | |xxx}]; 138 | let%bind messages = 139 | copy_out postgres "COPY example_datatypes TO STDOUT (FORMAT binary)" 140 | in 141 | [%expect {| ((tag "COPY 1") (rows ())) |}]; 142 | Queue.iter messages ~f:(fun m -> print_endline (String.Hexdump.to_string_hum m)); 143 | [%expect 144 | {| 145 | 00000000 50 47 43 4f 50 59 0a ff 0d 0a 00 00 00 00 00 00 |PGCOPY..........| 146 | 00000010 00 00 00 00 13 ff ff ff ff 00 00 00 05 00 00 00 |................| 147 | 00000020 01 80 00 00 00 0d 00 00 00 41 d1 01 00 01 00 00 |.........A......| 148 | 00000030 00 01 00 00 00 00 01 41 00 00 00 02 7f ff 00 00 |.......A........| 149 | 00000040 00 04 7f ff ff ff 00 00 00 08 7f ff ff ff ff ff |................| 150 | 00000050 ff ff 00 00 00 04 40 2d f8 54 
00 00 00 08 40 09 |......@-.T....@.| 151 | 00000060 21 fb 54 44 2d 11 00 00 00 10 12 3e 45 67 e8 9b |!.TD-......>Eg..| 152 | 00000070 12 d3 a4 56 42 66 14 17 40 00 00 00 00 0c 00 02 |...VBf..@.......| 153 | 00000080 00 00 00 00 00 02 00 7b 11 94 00 00 00 2c 00 00 |.......{.....,..| 154 | 00000090 00 01 00 00 00 00 00 00 00 17 00 00 00 03 00 00 |................| 155 | 000000a0 00 01 00 00 00 04 00 00 00 01 00 00 00 04 00 00 |................| 156 | 000000b0 00 02 00 00 00 04 00 00 00 03 00 00 00 08 02 19 |................| 157 | 000000c0 01 04 c0 a8 64 80 00 00 00 06 08 00 2b 01 02 03 |....d.......+...| 158 | 000000d0 00 00 00 31 01 7b 22 6b 65 79 22 3a 20 22 76 61 |...1.{"key": "va| 159 | 000000e0 6c 75 65 22 2c 20 22 6e 65 73 74 65 64 22 3a 20 |lue", "nested": | 160 | 000000f0 7b 22 61 72 72 61 79 22 3a 20 5b 31 2c 20 32 2c |{"array": [1, 2,| 161 | 00000100 20 33 5d 7d 7d 00 00 00 28 00 00 00 02 00 00 00 | 3]}}...(.......| 162 | 00000110 04 6b 65 79 31 00 00 00 06 76 61 6c 75 65 31 00 |.key1....value1.| 163 | 00000120 00 00 04 6b 65 79 32 00 00 00 06 76 61 6c 75 65 |...key2....value| 164 | 00000130 32 00 00 00 14 54 68 69 73 20 69 73 20 61 20 74 |2....This is a t| 165 | 00000140 65 78 74 20 66 69 65 6c 64 00 00 00 06 00 de ad |ext field.......| 166 | 00000150 be ef 00 00 00 00 0f 56 41 52 43 48 41 52 20 65 |.......VARCHAR e| 167 | 00000160 78 61 6d 70 6c 65 |xample| 168 | 00000000 ff ff |..| 169 | |}]; 170 | return ()) 171 | ;; 172 | 173 | let show_query postgres query = 174 | Postgres_async.query postgres query ~handle_row:(fun ~column_names ~values -> 175 | let values = Iarray.zip_exn column_names values in 176 | print_s [%sexp (values : (string * string option) iarray)]) 177 | >>| ok_exn 178 | ;; 179 | 180 | let%expect_test "round-trip" = 181 | with_connection_exn (fun postgres -> 182 | let%bind () = 183 | let%bind messages = copy_out postgres "COPY example_datatypes TO STDOUT" in 184 | [%expect {| ((tag "COPY 1") (rows ())) |}]; 185 | 
Postgres_async.copy_in_raw 186 | postgres 187 | "COPY example_datatypes FROM STDIN" 188 | ~feed_data:(fun () -> 189 | match Queue.dequeue messages with 190 | | None -> Finished 191 | | Some x -> Data x) 192 | >>| ok_exn 193 | in 194 | let%bind () = 195 | let%bind messages = 196 | copy_out postgres "COPY example_datatypes TO STDOUT (FORMAT binary)" 197 | in 198 | [%expect {| ((tag "COPY 2") (rows ())) |}]; 199 | Postgres_async.copy_in_raw 200 | postgres 201 | "COPY example_datatypes FROM STDIN (FORMAT binary)" 202 | ~feed_data:(fun () -> 203 | match Queue.dequeue messages with 204 | | None -> Finished 205 | | Some x -> Data x) 206 | >>| ok_exn 207 | in 208 | let%bind () = show_query postgres "SELECT count(*) FROM example_datatypes" in 209 | [%expect {| ((count (4))) |}]; 210 | let%bind () = 211 | show_query 212 | postgres 213 | "SELECT count(*) FROM (SELECT DISTINCT * from example_datatypes) as x" 214 | in 215 | (* we re-inserted the exact same row *) 216 | [%expect {| ((count (1))) |}]; 217 | return ()) 218 | ;; 219 | -------------------------------------------------------------------------------- /test/test_copy_out.mli: -------------------------------------------------------------------------------- 1 | (*_ This signature is deliberately empty. 
*) 2 | -------------------------------------------------------------------------------- /test/test_error_code.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open Async 3 | open Expect_test_helpers_core 4 | 5 | let with_connection = 6 | let database = "test_error_code" in 7 | let harness = 8 | lazy 9 | (let h = Harness.create () in 10 | Harness.create_database h database; 11 | h) 12 | in 13 | fun func -> 14 | Postgres_async.Expert.with_connection 15 | ~user:"postgres" 16 | ~server:(Harness.where_to_connect (force harness)) 17 | ~database 18 | ~on_handler_exception:`Raise 19 | func 20 | ;; 21 | 22 | let print_or_pgasync_error or_pgasync_error = 23 | match or_pgasync_error with 24 | | Ok () -> print_s [%sexp Ok ()] 25 | | Error error -> 26 | let error_code = Postgres_async.Pgasync_error.postgres_error_code error in 27 | let severity = Postgres_async.Pgasync_error.postgres_field error Severity in 28 | let as_error = 29 | Postgres_async.Pgasync_error.to_error error 30 | |> [%sexp_of: Error.t] 31 | |> Utils.delete_unstable_bits_of_error 32 | in 33 | print_s 34 | [%message 35 | "Error" 36 | (error_code : string option) 37 | (severity : string option) 38 | (as_error : Sexp.t)] 39 | ;; 40 | 41 | let%expect_test "deadlock_detected has error_code=40P01" = 42 | let%bind connection_result = 43 | with_connection (fun postgres -> 44 | let%bind result = 45 | Postgres_async.Expert.query_expect_no_data 46 | postgres 47 | "DO $$ BEGIN RAISE deadlock_detected; END; $$" 48 | in 49 | print_or_pgasync_error result; 50 | [%expect 51 | {| 52 | (Error 53 | (error_code (40P01)) 54 | (severity (ERROR)) 55 | (as_error ( 56 | (query "DO $$ BEGIN RAISE deadlock_detected; END; $$") 57 | ("Postgres Server Error (state=Executing)" ((Code 40P01)))))) 58 | |}]; 59 | return ()) 60 | in 61 | print_or_pgasync_error connection_result; 62 | [%expect {| (Ok ()) |}]; 63 | return () 64 | ;; 65 | 66 | let%expect_test "error_code is erased from the result 
of query against a dead connection" = 67 | (* note, by the way, that in this test we're mixing [Postgres_async.foo] functions (via 68 | [Utils.pg_backend_pid]) and [Postgres_async.Expert.foo] functions, as one is allowed 69 | to do. *) 70 | let%bind result = 71 | with_connection (fun postgres -> 72 | let%bind backend_pid = Utils.pg_backend_pid postgres in 73 | let%bind result = 74 | with_connection (fun postgres2 -> 75 | let%bind result = 76 | Postgres_async.query 77 | postgres2 78 | "SELECT pg_terminate_backend($1)" 79 | ~parameters:[| Some backend_pid |] 80 | ~handle_row:(fun ~column_names:_ ~values:_ -> ()) 81 | in 82 | Or_error.ok_exn result; 83 | return ()) 84 | in 85 | Postgres_async.Or_pgasync_error.ok_exn result; 86 | (* The close-finished error does have the error code, since it was this error code 87 | that caused the problem. *) 88 | let%bind result = Postgres_async.Expert.close_finished postgres in 89 | print_or_pgasync_error result; 90 | [%expect 91 | {| 92 | (Error 93 | (error_code (57P01)) 94 | (severity (FATAL)) 95 | (as_error ( 96 | "ErrorResponse received asynchronously, assuming connection is dead" 97 | ((Severity FATAL) 98 | (Code 57P01))))) 99 | |}]; 100 | (* Attempting to issue new queries against the connection produces an error that 101 | specifies what the original error was, but does not claim the error code, since 102 | this error is not directly attributable to this query and it would be misleading 103 | to claim that it was the error code of this query (see comment in 104 | postgres_async.ml). 
*) 105 | let%bind result = Postgres_async.Expert.query_expect_no_data postgres "" in 106 | print_or_pgasync_error result; 107 | [%expect 108 | {| 109 | (Error 110 | (error_code (08003)) 111 | (severity ()) 112 | (as_error ( 113 | (query "") 114 | ("query issued against previously-failed connection" 115 | (original_error ( 116 | "ErrorResponse received asynchronously, assuming connection is dead" 117 | ((Severity FATAL) 118 | (Code 57P01)))))))) 119 | |}]; 120 | return ()) 121 | in 122 | print_or_pgasync_error result; 123 | [%expect {| (Ok ()) |}]; 124 | return () 125 | ;; 126 | 127 | let%expect_test "reporting syntax errors for short queries includes full query" = 128 | let%bind connection_result = 129 | with_connection (fun postgres -> 130 | let%bind result = 131 | Postgres_async.Expert.query_expect_no_data postgres "select foo, bar from baz" 132 | in 133 | print_or_pgasync_error result; 134 | [%expect 135 | {| 136 | (Error 137 | (error_code (42P01)) 138 | (severity (ERROR)) 139 | (as_error ( 140 | (query "select foo, bar from baz") 141 | ("Postgres Server Error (state=Parsing)" ((Code 42P01)))))) 142 | |}]; 143 | return ()) 144 | in 145 | print_or_pgasync_error connection_result; 146 | [%expect {| (Ok ()) |}]; 147 | return () 148 | ;; 149 | 150 | let%expect_test "reporting syntax errors for long queries shrinks context in a \ 151 | meaningful way" 152 | = 153 | let%bind connection_result = 154 | with_connection (fun postgres -> 155 | let nums = List.init 10000 ~f:Int.to_string |> String.concat ~sep:", " in 156 | let long_query = "select $1, " ^ nums ^ ", foobar, " ^ nums in 157 | let%bind result = 158 | Postgres_async.Expert.query_expect_no_data 159 | ~parameters:[| Some nums |] 160 | postgres 161 | long_query 162 | in 163 | print_or_pgasync_error result; 164 | [%expect 165 | {| 166 | (Error 167 | (error_code (42703)) 168 | (severity (ERROR)) 169 | (as_error ( 170 | ((query 171 | "... 
29, 9830, 9831, 9832, 9833, 9834, 9835, 9836, 9837, 9838, 9839, 9840, 9841, 9842, 9843, 9844, 9845, 9846, 9847, 9848, 9849, 9850, 9851, 9852, 9853, 9854, 9855, 9856, 9857, 9858, 9859, 9860, 9861, 9862, 9863, 9864, 9865, 9866, 9867, 9868, 9869, 9870, 9871, 9872, 9873, 9874, 9875, 9876, 9877, 9878, 9879, 9880, 9881, 9882, 9883, 9884, 9885, 9886, 9887, 9888, 9889, 9890, 9891, 9892, 9893, 9894, 9895, 9896, 9897, 9898, 9899, 9900, 9901, 9902, 9903, 9904, 9905, 9906, 9907, 9908, 9909, 9910, 9911, 9912, 9913, 9914, 9915, 9916, 9917, 9918, 9919, 9920, 9921, 9922, 9923, 9924, 9925, 9926, 9927, 9928, 9929, 9930, 9931, 9932, 9933, 9934, 9935, 9936, 9937, 9938, 9939, 9940, 9941, 9942, 9943, 9944, 9945, 9946, 9947, 9948, 9949, 9950, 9951, 9952, 9953, 9954, 9955, 9956, 9957, 9958, 9959, 9960, 9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968, 9969, 9970, 9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978, 9979, 9980, 9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988, 9989, 9990, 9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998, 9999, foobar, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 
215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 2 ...") 172 | (parameters (( 173 | "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34...")))) 174 | ("Postgres Server Error (state=Parsing)" ((Code 42703)))))) 175 | |}]; 176 | return ()) 177 | in 178 | print_or_pgasync_error connection_result; 179 | [%expect {| (Ok ()) |}]; 180 | return () 181 | ;; 182 | -------------------------------------------------------------------------------- /test/test_error_code.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_notify.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open Async 3 | 4 | let harness = lazy (Harness.create ()) 5 | 6 | let query_exn postgres string = 7 | let%bind result = Postgres_async.query_expect_no_data postgres string in 8 | Or_error.ok_exn result; 9 | return () 10 | ;; 11 | 12 | let print_notifications ?saw_notification postgres ~channel = 13 | let%bind result = 14 | Postgres_async.listen_to_notifications postgres ~channel ~f:(fun ~pid:_ ~payload -> 15 | Option.iter saw_notification ~f:(fun bvar -> Bvar.broadcast bvar ()); 16 | print_s [%message "notification" ~channel ~payload]) 17 | in 18 | Or_error.ok_exn result; 19 | return () 20 | ;; 21 | 22 | let%expect_test "listen/notify" = 23 | let saw_notification = Bvar.create () in 24 | let print_notifications = print_notifications ~saw_notification in 25 | let harness = force harness in 26 | Harness.with_connection_exn harness ~database:"postgres" (fun postgres -> 27 | let%bind () = print_notifications postgres ~channel:"channel_one" in 28 | let%bind () = print_notifications postgres ~channel:"channel-2" in 29 | [%expect {| |}]; 30 | let sync1 = Bvar.wait saw_notification in 31 | let%bind () = query_exn postgres 
"NOTIFY channel_one" in 32 | let%bind () = sync1 in 33 | [%expect {| (notification (channel channel_one) (payload "")) |}]; 34 | let sync2 = Bvar.wait saw_notification in 35 | let%bind () = query_exn postgres "NOTIFY \"channel-2\", 'hello'" in 36 | let%bind () = sync2 in 37 | [%expect {| (notification (channel channel-2) (payload hello)) |}]; 38 | let sync3 = Bvar.wait saw_notification in 39 | let%bind () = 40 | Harness.with_connection_exn harness ~database:"postgres" (fun postgres2 -> 41 | query_exn postgres2 "NOTIFY channel_one, 'from-another-process'") 42 | in 43 | let%bind () = sync3 in 44 | [%expect {| (notification (channel channel_one) (payload from-another-process)) |}]; 45 | let%bind () = 46 | Harness.with_connection_exn harness ~database:"postgres" (fun postgres2 -> 47 | let sync4 = Bvar.wait saw_notification in 48 | let%bind () = query_exn postgres2 "NOTIFY \"channel-2\", 'm1'" in 49 | let%bind () = sync4 in 50 | let sync5 = Bvar.wait saw_notification in 51 | let%bind () = query_exn postgres2 "NOTIFY \"channel-2\", 'm2'" in 52 | let%bind () = sync5 in 53 | return ()) 54 | in 55 | [%expect 56 | {| 57 | (notification (channel channel-2) (payload m1)) 58 | (notification (channel channel-2) (payload m2)) 59 | |}]; 60 | return ()) 61 | ;; 62 | 63 | let with_other_log_outputs log outputs ~f = 64 | let before = Log.get_output log in 65 | Log.set_output log outputs; 66 | let finally () = 67 | Log.set_output log before; 68 | return () 69 | in 70 | Monitor.protect ~run:`Now ~rest:`Raise f ~finally 71 | ;; 72 | 73 | let%expect_test "multiple listeners" = 74 | Harness.with_connection_exn (force harness) ~database:"postgres" (fun postgres -> 75 | let b1 = Bvar.create () in 76 | let%bind () = print_notifications ~saw_notification:b1 postgres ~channel:"a" in 77 | let b2 = Bvar.create () in 78 | let%bind () = print_notifications ~saw_notification:b2 postgres ~channel:"a" in 79 | let b3 = Bvar.create () in 80 | let%bind () = print_notifications ~saw_notification:b3 
postgres ~channel:"a" in 81 | let i1 = Bvar.wait b1 in 82 | let i2 = Bvar.wait b2 in 83 | let i3 = Bvar.wait b3 in 84 | let%bind () = query_exn postgres "NOTIFY a" in 85 | let%bind () = i1 in 86 | let%bind () = i2 in 87 | let%bind () = i3 in 88 | [%expect 89 | {| 90 | (notification (channel a) (payload "")) 91 | (notification (channel a) (payload "")) 92 | (notification (channel a) (payload "")) 93 | |}]; 94 | return ()) 95 | ;; 96 | 97 | let%expect_test "notify with no listeners" = 98 | let saw_log_message = Bvar.create () in 99 | let log_output = 100 | let flush () = return () in 101 | Log.Output.create ~flush (fun messages -> 102 | Bvar.broadcast saw_log_message (); 103 | Queue.iter messages ~f:(fun msg -> 104 | match Log.Message.raw_message msg with 105 | | `Sexp sexp -> print_s [%message "LOG" ~_:(sexp : Sexp.t)] 106 | | `String str -> printf "LOG: %s str\n" str); 107 | return ()) 108 | in 109 | with_other_log_outputs (force Log.Global.log) [ log_output ] ~f:(fun () -> 110 | Harness.with_connection_exn (force harness) ~database:"postgres" (fun postgres -> 111 | let sync = Bvar.wait saw_log_message in 112 | let%bind () = query_exn postgres "LISTEN channel_three" in 113 | let%bind () = query_exn postgres "NOTIFY channel_three" in 114 | let%bind () = sync in 115 | [%expect 116 | {| 117 | (LOG 118 | ("Postgres NotificationResponse on channel that no callbacks are listening to" 119 | (channel channel_three))) 120 | |}]; 121 | return ())) 122 | ;; 123 | -------------------------------------------------------------------------------- /test/test_notify.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. 
*) 2 | -------------------------------------------------------------------------------- /test/test_protocol_round_trip.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open Async 3 | module Protocol = Postgres_async.Private.Protocol 4 | 5 | let roundtrip ~write ~read = 6 | let%bind `Reader reader_fd, `Writer writer_fd = 7 | Unix.pipe (Info.of_string "roundtrip") 8 | in 9 | let%map () = 10 | let writer = Writer.create writer_fd in 11 | Writer.with_close writer ~f:(fun () -> 12 | write writer; 13 | return ()) 14 | and contents = Reader.contents (Reader.create reader_fd) in 15 | read (Iobuf.of_string contents) 16 | ;; 17 | 18 | let read_message iobuf ~read_message_type_char ~read_payload = 19 | let message_char = 20 | match read_message_type_char with 21 | | true -> Iobuf.Consume.char iobuf |> Option.some 22 | | false -> None 23 | in 24 | print_s [%message (message_char : Char.t option)]; 25 | let length = Iobuf.Consume.int32_be iobuf in 26 | print_s [%message (length : int)]; 27 | (* Length includes itself*) 28 | let payload_length = length - 4 in 29 | read_payload ~payload_length iobuf 30 | ;; 31 | 32 | let test_startup ?replication ?options runtime_parameters = 33 | let startup_message = 34 | Protocol.Frontend.StartupMessage.create_exn 35 | () 36 | ~user:"test-user" 37 | ~database:"testdb" 38 | ?options 39 | ?replication 40 | ~runtime_parameters:(String.Map.of_alist_exn runtime_parameters) 41 | in 42 | let write writer = Protocol.Frontend.Writer.startup_message writer startup_message in 43 | let read_payload ~payload_length:_ iobuf = 44 | let read = Protocol.Frontend.StartupMessage.consume iobuf |> ok_exn in 45 | [%test_result: Protocol.Frontend.StartupMessage.t] read ~expect:startup_message; 46 | print_s [%sexp (read : Protocol.Frontend.StartupMessage.t)] 47 | in 48 | let read = read_message ~read_message_type_char:false ~read_payload in 49 | roundtrip ~write ~read 50 | ;; 51 | 52 | let%expect_test "Simple 
Startup" = 53 | let%bind () = test_startup [] in 54 | [%expect 55 | {| 56 | (message_char ()) 57 | (length 40) 58 | ((database testdb) (user test-user)) 59 | |}]; 60 | Deferred.unit 61 | ;; 62 | 63 | let%expect_test "Replication Startup" = 64 | let%bind () = test_startup ~replication:"database" [] in 65 | [%expect 66 | {| 67 | (message_char ()) 68 | (length 61) 69 | ((database testdb) (replication database) (user test-user)) 70 | |}]; 71 | Deferred.unit 72 | ;; 73 | 74 | let%expect_test "Startup with Params" = 75 | let%bind () = test_startup [ "client_encoding", "UTF8"; "DateStyle", "ISO,MDY" ] in 76 | [%expect 77 | {| 78 | (message_char ()) 79 | (length 79) 80 | ((DateStyle ISO,MDY) (client_encoding UTF8) (database testdb) 81 | (user test-user)) 82 | |}]; 83 | Deferred.unit 84 | ;; 85 | 86 | (* Note : options is deprecated in favor of runtime parameters according 87 | to postgres docs *) 88 | 89 | let%expect_test "Startup with Options" = 90 | let%bind () = 91 | test_startup [] ~options:[ "client_encoding=UTF8"; "statement_timeout=3000" ] 92 | in 93 | [%expect 94 | {| 95 | (message_char ()) 96 | (length 92) 97 | ((database testdb) (options "client_encoding=UTF8 statement_timeout=3000") 98 | (user test-user)) 99 | |}]; 100 | Deferred.unit 101 | ;; 102 | 103 | let%expect_test "Startup with Options with space" = 104 | let%bind () = 105 | test_startup 106 | [] 107 | ~options:[ "application_name='More Spaces Please'"; "statement_timeout=3000" ] 108 | in 109 | [%expect 110 | {| 111 | (message_char ()) 112 | (length 111) 113 | ((database testdb) 114 | (options "application_name='More\\ Spaces\\ Please' statement_timeout=3000") 115 | (user test-user)) 116 | |}]; 117 | Deferred.unit 118 | ;; 119 | 120 | let%expect_test "Startup with Options with multi space and backslash" = 121 | let%bind () = 122 | test_startup 123 | [] 124 | ~options: 125 | [ "--application_name='My \\very complicated a\\p\\p name'" 126 | ; "--statement_timeout=3000" 127 | ] 128 | in 129 | [%expect 130 | 
{| 131 | (message_char ()) 132 | (length 137) 133 | ((database testdb) 134 | (options 135 | "--application_name='My\\ \\\\very\\ complicated\\ \\ \\ a\\\\p\\\\p\\ name' --statement_timeout=3000") 136 | (user test-user)) 137 | |}]; 138 | Deferred.unit 139 | ;; 140 | 141 | let%expect_test "KRB password message " = 142 | let write writer = 143 | Protocol.Frontend.Writer.password_message writer (Gss_binary_blob "blob") 144 | in 145 | let read_payload ~payload_length iobuf = 146 | let msg = 147 | Protocol.Frontend.PasswordMessage.consume_krb iobuf ~length:payload_length 148 | |> Or_error.ok_exn 149 | in 150 | match (msg : Protocol.Frontend.PasswordMessage.t) with 151 | | Gss_binary_blob blob -> print_s [%message (blob : string)] 152 | | Cleartext_or_md5_hex _ -> raise_s [%message "Expected krb"] 153 | in 154 | let read = read_message ~read_message_type_char:true ~read_payload in 155 | let%bind () = roundtrip ~write ~read in 156 | [%expect 157 | {| 158 | (message_char (p)) 159 | (length 8) 160 | (blob blob) 161 | |}]; 162 | Deferred.unit 163 | ;; 164 | 165 | let%expect_test "Password password message " = 166 | let write writer = 167 | Protocol.Frontend.Writer.password_message writer (Cleartext_or_md5_hex "hex") 168 | in 169 | let read_payload ~payload_length:_ iobuf = 170 | let msg = 171 | Protocol.Frontend.PasswordMessage.consume_password iobuf |> Or_error.ok_exn 172 | in 173 | match (msg : Protocol.Frontend.PasswordMessage.t) with 174 | | Gss_binary_blob _ -> raise_s [%message "Got krb"] 175 | | Cleartext_or_md5_hex blob -> print_s [%message (blob : string)] 176 | in 177 | let read = read_message ~read_message_type_char:true ~read_payload in 178 | let%bind () = roundtrip ~write ~read in 179 | [%expect 180 | {| 181 | (message_char (p)) 182 | (length 8) 183 | (blob hex) 184 | |}]; 185 | Deferred.unit 186 | ;; 187 | 188 | let%expect_test "Authentication Request message " = 189 | let run sent = 190 | let write writer = Protocol.Backend.Writer.auth_message writer sent in 
191 | let read_payload ~payload_length:_ iobuf = 192 | let msg = Protocol.Backend.AuthenticationRequest.consume iobuf |> Or_error.ok_exn in 193 | print_s 194 | [%message 195 | (sent : Protocol.Backend.AuthenticationRequest.t) 196 | (msg : Protocol.Backend.AuthenticationRequest.t)] 197 | in 198 | let read = read_message ~read_message_type_char:true ~read_payload in 199 | roundtrip ~write ~read 200 | in 201 | let%bind () = 202 | Deferred.List.iter 203 | ~how:`Sequential 204 | ~f:run 205 | Protocol.Backend.AuthenticationRequest. 206 | [ Ok 207 | ; KerberosV5 208 | ; CleartextPassword 209 | ; MD5Password { salt = "salt" } 210 | ; SCMCredential 211 | ; GSS 212 | ; SSPI 213 | ; GSSContinue { data = "gsscont" } 214 | ] 215 | in 216 | [%expect 217 | {| 218 | (message_char (R)) 219 | (length 8) 220 | ((sent Ok) (msg Ok)) 221 | (message_char (R)) 222 | (length 8) 223 | ((sent KerberosV5) (msg KerberosV5)) 224 | (message_char (R)) 225 | (length 8) 226 | ((sent CleartextPassword) (msg CleartextPassword)) 227 | (message_char (R)) 228 | (length 12) 229 | ((sent (MD5Password (salt salt))) (msg (MD5Password (salt salt)))) 230 | (message_char (R)) 231 | (length 8) 232 | ((sent SCMCredential) (msg SCMCredential)) 233 | (message_char (R)) 234 | (length 8) 235 | ((sent GSS) (msg GSS)) 236 | (message_char (R)) 237 | (length 8) 238 | ((sent SSPI) (msg SSPI)) 239 | (message_char (R)) 240 | (length 15) 241 | ((sent (GSSContinue (data gsscont))) (msg (GSSContinue (data gsscont)))) 242 | |}]; 243 | Deferred.unit 244 | ;; 245 | 246 | let%expect_test "Authentication Request message " = 247 | let run sent = 248 | let write writer = Protocol.Backend.Writer.ready_for_query writer sent in 249 | let read_payload ~payload_length:_ iobuf = 250 | let msg = Protocol.Backend.ReadyForQuery.consume iobuf |> Or_error.ok_exn in 251 | print_s 252 | [%message 253 | (sent : Protocol.Backend.ReadyForQuery.t) 254 | (msg : Protocol.Backend.ReadyForQuery.t)] 255 | in 256 | let read = read_message 
~read_message_type_char:true ~read_payload in 257 | roundtrip ~write ~read 258 | in 259 | let%bind () = 260 | Deferred.List.iter 261 | ~how:`Sequential 262 | ~f:run 263 | Protocol.Backend.ReadyForQuery.[ Idle; In_transaction; In_failed_transaction ] 264 | in 265 | [%expect 266 | {| 267 | (message_char (Z)) 268 | (length 5) 269 | ((sent Idle) (msg Idle)) 270 | (message_char (Z)) 271 | (length 5) 272 | ((sent In_transaction) (msg In_transaction)) 273 | (message_char (Z)) 274 | (length 5) 275 | ((sent In_failed_transaction) (msg In_failed_transaction)) 276 | |}]; 277 | Deferred.unit 278 | ;; 279 | 280 | let%expect_test "BackendKeyData " = 281 | let write writer = 282 | Protocol.Backend.Writer.backend_key writer { pid = 1234; secret = 4444 } 283 | in 284 | let read_payload ~payload_length:_ iobuf = 285 | let ({ pid; secret } : Protocol.Backend.BackendKeyData.t) = 286 | Protocol.Backend.BackendKeyData.consume iobuf |> Or_error.ok_exn 287 | in 288 | print_s [%message (pid : int) (secret : int)] 289 | in 290 | let read = read_message ~read_message_type_char:true ~read_payload in 291 | let%bind () = roundtrip ~write ~read in 292 | [%expect 293 | {| 294 | (message_char (K)) 295 | (length 12) 296 | ((pid 1234) (secret 4444)) 297 | |}]; 298 | Deferred.unit 299 | ;; 300 | 301 | let%expect_test "Paramter Status " = 302 | let write writer = 303 | Protocol.Backend.Writer.parameter_status 304 | writer 305 | { key = "server_version"; data = "13.6" } 306 | in 307 | let read_payload ~payload_length:_ iobuf = 308 | let ({ key; data } : Protocol.Backend.ParameterStatus.t) = 309 | Protocol.Backend.ParameterStatus.consume iobuf |> Or_error.ok_exn 310 | in 311 | print_s [%message (key : string) (data : string)] 312 | in 313 | let read = read_message ~read_message_type_char:true ~read_payload in 314 | let%bind () = roundtrip ~write ~read in 315 | [%expect 316 | {| 317 | (message_char (S)) 318 | (length 24) 319 | ((key server_version) (data 13.6)) 320 | |}]; 321 | Deferred.unit 322 | ;; 323 
| 324 | let%expect_test "Error Response " = 325 | let write writer = 326 | Protocol.Backend.Writer.error_response 327 | writer 328 | { error_code = "28000"; all_fields = [ Severity, "FATAL"; Message, "auth-error" ] } 329 | in 330 | let read_payload ~payload_length:_ iobuf = 331 | let ({ error_code; all_fields } : Protocol.Backend.ErrorResponse.t) = 332 | Protocol.Backend.ErrorResponse.consume iobuf |> Or_error.ok_exn 333 | in 334 | print_s 335 | [%message 336 | (error_code : string) 337 | (all_fields : (Protocol.Backend.Error_or_notice_field.t * string) list)] 338 | in 339 | let read = read_message ~read_message_type_char:true ~read_payload in 340 | let%bind () = roundtrip ~write ~read in 341 | [%expect 342 | {| 343 | (message_char (E)) 344 | (length 31) 345 | ((error_code 28000) 346 | (all_fields ((Code 28000) (Severity FATAL) (Message auth-error)))) 347 | |}]; 348 | Deferred.unit 349 | ;; 350 | 351 | let%expect_test "Data Row" = 352 | let write writer = 353 | Protocol.Backend.Writer.data_row 354 | writer 355 | (Iarray.of_list [ Some "Col 1"; Some "Col 2"; None; Some "Col 4" ]) 356 | in 357 | let read_payload ~payload_length:_ iobuf = 358 | let (row : string option iarray) = 359 | Protocol.Backend.DataRow.consume iobuf |> Or_error.ok_exn 360 | in 361 | print_s [%message (row : string option iarray)] 362 | in 363 | let read = read_message ~read_message_type_char:true ~read_payload in 364 | let%bind () = roundtrip ~write ~read in 365 | [%expect 366 | {| 367 | (message_char (D)) 368 | (length 37) 369 | (row (("Col 1") ("Col 2") () ("Col 4"))) 370 | |}]; 371 | Deferred.unit 372 | ;; 373 | 374 | let%expect_test "Command Complete" = 375 | let write writer = Protocol.Backend.Writer.command_complete writer "DELETE 10" in 376 | let read_payload ~payload_length:_ iobuf = 377 | let (response : string) = 378 | Protocol.Backend.CommandComplete.consume iobuf |> Or_error.ok_exn 379 | in 380 | print_s [%message (response : string)] 381 | in 382 | let read = read_message 
~read_message_type_char:true ~read_payload in 383 | let%bind () = roundtrip ~write ~read in 384 | [%expect 385 | {| 386 | (message_char (C)) 387 | (length 14) 388 | (response "DELETE 10") 389 | |}]; 390 | Deferred.unit 391 | ;; 392 | 393 | let%expect_test "Notice Response" = 394 | let write writer = 395 | Protocol.Backend.Writer.notice_response 396 | writer 397 | { error_code = "00000"; all_fields = [ Severity, "LOG" ] } 398 | in 399 | let read_payload ~payload_length:_ iobuf = 400 | let (response : Protocol.Backend.NoticeResponse.t) = 401 | Protocol.Backend.NoticeResponse.consume iobuf |> Or_error.ok_exn 402 | in 403 | print_s [%message (response : Protocol.Backend.NoticeResponse.t)] 404 | in 405 | let read = read_message ~read_message_type_char:true ~read_payload in 406 | let%bind () = roundtrip ~write ~read in 407 | [%expect 408 | {| 409 | (message_char (N)) 410 | (length 17) 411 | (response ((error_code 00000) (all_fields ((Code 00000) (Severity LOG))))) 412 | |}]; 413 | Deferred.unit 414 | ;; 415 | 416 | let%expect_test "Notification Response" = 417 | let write writer = 418 | Protocol.Backend.Writer.notification_response 419 | writer 420 | { pid = Pid.of_int 10 421 | ; channel = 422 | Postgres_async.Private.Types.Notification_channel.of_string "Test Channel" 423 | ; payload = "Test Payload" 424 | } 425 | in 426 | let read_payload ~payload_length:_ iobuf = 427 | let (response : Protocol.Backend.NotificationResponse.t) = 428 | Protocol.Backend.NotificationResponse.consume iobuf |> Or_error.ok_exn 429 | in 430 | print_s [%message (response : Protocol.Backend.NotificationResponse.t)] 431 | in 432 | let read = read_message ~read_message_type_char:true ~read_payload in 433 | let%bind () = roundtrip ~write ~read in 434 | [%expect 435 | {| 436 | (message_char (A)) 437 | (length 34) 438 | (response ((pid 10) (channel "Test Channel") (payload "Test Payload"))) 439 | |}]; 440 | Deferred.unit 441 | ;; 442 | 443 | let%expect_test "Parameter Description" = 444 | let write 
writer = 445 | Protocol.Backend.Writer.parameter_description writer (Array.of_list [ 12; 23; 1; 5 ]) 446 | in 447 | let read_payload ~payload_length:_ iobuf = 448 | let (response : int array) = 449 | Protocol.Backend.ParameterDescription.consume iobuf |> Or_error.ok_exn 450 | in 451 | print_s [%message (response : int array)] 452 | in 453 | let read = read_message ~read_message_type_char:true ~read_payload in 454 | let%bind () = roundtrip ~write ~read in 455 | [%expect 456 | {| 457 | (message_char (t)) 458 | (length 22) 459 | (response (12 23 1 5)) 460 | |}]; 461 | Deferred.unit 462 | ;; 463 | 464 | let%expect_test "Parameter Status" = 465 | let write writer = 466 | Protocol.Backend.Writer.parameter_status writer { key = "user"; data = "root" } 467 | in 468 | let read_payload ~payload_length:_ iobuf = 469 | let (response : Protocol.Backend.ParameterStatus.t) = 470 | Protocol.Backend.ParameterStatus.consume iobuf |> Or_error.ok_exn 471 | in 472 | print_s [%message (response : Protocol.Backend.ParameterStatus.t)] 473 | in 474 | let read = read_message ~read_message_type_char:true ~read_payload in 475 | let%bind () = roundtrip ~write ~read in 476 | [%expect 477 | {| 478 | (message_char (S)) 479 | (length 14) 480 | (response ((key user) (data root))) 481 | |}]; 482 | Deferred.unit 483 | ;; 484 | 485 | let%expect_test "Query" = 486 | let write writer = 487 | Protocol.Frontend.Writer.query writer "SELECT * FROM a; SELECT * FROM a;" 488 | in 489 | let read_payload ~payload_length:_ iobuf = 490 | let (response : string) = Protocol.Frontend.Query.consume iobuf |> Or_error.ok_exn in 491 | print_s [%message response] 492 | in 493 | let read = read_message ~read_message_type_char:true ~read_payload in 494 | let%bind () = roundtrip ~write ~read in 495 | [%expect 496 | {| 497 | (message_char (Q)) 498 | (length 38) 499 | "SELECT * FROM a; SELECT * FROM a;" 500 | |}]; 501 | Deferred.unit 502 | ;; 503 | 
-------------------------------------------------------------------------------- /test/test_protocol_round_trip.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_query.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_runtime_parameters.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open Async 3 | 4 | let () = Dynamic.set_root Backtrace.elide true 5 | let harness = lazy (Harness.create ()) 6 | 7 | let startup_message ?application_name ?replication ?options () = 8 | Postgres_async.Private.Protocol.Frontend.StartupMessage.create_exn 9 | () 10 | ~user:"postgres" 11 | ~database:"postgres" 12 | ?replication 13 | ?options 14 | ~runtime_parameters: 15 | (List.filter_opt 16 | [ Option.map application_name ~f:(fun a -> "application_name", a) 17 | ; Some ("TimeZone", "UTC") (* this is required to stabilize the output *) 18 | ] 19 | |> String.Map.of_alist_exn) 20 | ;; 21 | 22 | let login_and_print_params startup_message = 23 | let harness = force harness in 24 | let conn = 25 | Postgres_async.Private.Without_background_asynchronous_message_handling 26 | .login_and_get_raw 27 | ~server:(Harness.where_to_connect harness) 28 | ~startup_message 29 | () 30 | in 31 | match%bind conn with 32 | | Error error -> 33 | print_s [%message (error : Postgres_async.Pgasync_error.t)]; 34 | Deferred.unit 35 | | Ok conn -> 36 | let params = 37 | Postgres_async.Private.Without_background_asynchronous_message_handling 38 | .runtime_parameters 39 | conn 40 | in 41 | print_s [%message (params : string String.Map.t)]; 42 | Postgres_async.Private.Without_background_asynchronous_message_handling.writer conn 43 | |> 
Writer.close 44 | ;; 45 | 46 | let%expect_test "get runtime parameters" = 47 | let%bind () = login_and_print_params (startup_message ()) in 48 | [%expect 49 | {| 50 | (params 51 | ((DateStyle "ISO, MDY") (IntervalStyle postgres) (TimeZone UTC) 52 | (application_name "") (client_encoding SQL_ASCII) (integer_datetimes on) 53 | (is_superuser on) (server_encoding SQL_ASCII) (server_version 12.10) 54 | (session_authorization postgres) (standard_conforming_strings on))) 55 | |}]; 56 | return () 57 | ;; 58 | 59 | let%expect_test "set runtime parameters" = 60 | let%bind () = 61 | startup_message ~application_name:"simple_app" () |> login_and_print_params 62 | in 63 | [%expect 64 | {| 65 | (params 66 | ((DateStyle "ISO, MDY") (IntervalStyle postgres) (TimeZone UTC) 67 | (application_name simple_app) (client_encoding SQL_ASCII) 68 | (integer_datetimes on) (is_superuser on) (server_encoding SQL_ASCII) 69 | (server_version 12.10) (session_authorization postgres) 70 | (standard_conforming_strings on))) 71 | |}]; 72 | return () 73 | ;; 74 | 75 | let%expect_test "set options" = 76 | let%bind () = 77 | startup_message 78 | ~options:[ "--application_name=My \\very complicated a\\p\\p name'" ] 79 | () 80 | |> login_and_print_params 81 | in 82 | [%expect 83 | {| 84 | (params 85 | ((DateStyle "ISO, MDY") (IntervalStyle postgres) (TimeZone UTC) 86 | (application_name "My \\very complicated a\\p\\p name'") 87 | (client_encoding SQL_ASCII) (integer_datetimes on) (is_superuser on) 88 | (server_encoding SQL_ASCII) (server_version 12.10) 89 | (session_authorization postgres) (standard_conforming_strings on))) 90 | |}]; 91 | return () 92 | ;; 93 | -------------------------------------------------------------------------------- /test/test_runtime_parameters.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. 
*) 2 | -------------------------------------------------------------------------------- /test/test_server_failure.ml: -------------------------------------------------------------------------------- 1 | open! Core 2 | open Async 3 | open Expect_test_helpers_core 4 | 5 | let harness = lazy (Harness.create ()) 6 | 7 | let print_or_error or_error = 8 | let sexp = 9 | [%sexp (or_error : unit Or_error.t)] |> Utils.delete_unstable_bits_of_error 10 | in 11 | print_s sexp 12 | ;; 13 | 14 | let with_connection ?(show_errors = true) ?(database = "postgres") func = 15 | let sexp_of_error error = 16 | match show_errors with 17 | | false -> [%sexp ""] 18 | | true -> Utils.delete_unstable_bits_of_error [%sexp (error : Error.t)] 19 | in 20 | (* [Postgres_async.with_connection] does not report failure to close the connection 21 | because we think that users would prefer to get the result of their function instead. 22 | But it's interesting to print it. *) 23 | match%bind 24 | Postgres_async.connect 25 | () 26 | ~server:(Harness.where_to_connect (force harness)) 27 | ~user:"postgres" 28 | ~database 29 | with 30 | | Error error -> 31 | print_s [%message "failed to connect" (error : error)]; 32 | return () 33 | | Ok postgres -> 34 | let%bind () = func postgres in 35 | (match%bind Postgres_async.close postgres with 36 | | Ok () -> 37 | print_s [%message "closed cleanly"]; 38 | return () 39 | | Error error -> 40 | print_s [%message "failed to close" (error : error)]; 41 | return ()) 42 | ;; 43 | 44 | let%expect_test "terminate backend" = 45 | (* The easiest way to write this test is to have the backend kill itself. Sadly, if the 46 | backend is killed while it is executing a query, there's a race: either the write of 47 | the [Sync] message fails ("writer failed asynchronously") or the write succeeds 48 | (getting out the door before the connection is closed) and instead we get unexpected 49 | EOF on the read that follows (seems to in practice be far more likely).
50 | 51 | Therefore, we can't show the contents of the error messages, sadly. But we can assert 52 | that we do get errors. *) 53 | let%bind () = 54 | with_connection ~show_errors:false (fun postgres -> 55 | let%bind result = 56 | Postgres_async.query 57 | postgres 58 | "SELECT pg_terminate_backend(pg_backend_pid())" 59 | ~handle_row:(fun ~column_names:_ ~values:_ -> ()) 60 | in 61 | print_or_error result; 62 | [%expect 63 | {| 64 | (Error ( 65 | (query "SELECT pg_terminate_backend(pg_backend_pid())") 66 | ("Error during query execution (despite parsing ok)" 67 | ((Severity FATAL) 68 | (Code 57P01))))) 69 | |}]; 70 | let%bind result = Postgres_async.query_expect_no_data postgres "" in 71 | [%test_pred: unit Or_error.t] 72 | (fun r -> 73 | String.is_substring 74 | (Sexp.to_string [%sexp (r : unit Or_error.t)]) 75 | ~substring:"query issued against previously-failed connection") 76 | result; 77 | return ()) 78 | in 79 | [%expect {| ("failed to close" (error )) |}]; 80 | (* Here's a test where our backend is externally terminated while we're idle; a little 81 | tricky to get right, but this one is actually deterministic (because we're not trying 82 | to write at the same time as the connection is closed on us). 
*) 83 | let%bind () = 84 | with_connection (fun postgres -> 85 | let%bind backend_pid = Utils.pg_backend_pid postgres in 86 | let%bind () = 87 | Harness.with_connection_exn (force harness) ~database:"postgres" (fun postgres2 -> 88 | let%bind result = 89 | Postgres_async.query 90 | postgres2 91 | "SELECT pg_terminate_backend($1)" 92 | ~parameters:[| Some backend_pid |] 93 | ~handle_row:(fun ~column_names:_ ~values:_ -> ()) 94 | in 95 | Or_error.ok_exn result; 96 | return ()) 97 | in 98 | let%bind result = Postgres_async.close_finished postgres in 99 | print_or_error result; 100 | [%expect 101 | {| 102 | (Error ( 103 | "ErrorResponse received asynchronously, assuming connection is dead" 104 | ((Severity FATAL) 105 | (Code 57P01)))) 106 | |}]; 107 | let%bind result = Postgres_async.query_expect_no_data postgres "" in 108 | print_or_error result; 109 | [%expect 110 | {| 111 | (Error ( 112 | (query "") 113 | ("query issued against previously-failed connection" 114 | (original_error ( 115 | "ErrorResponse received asynchronously, assuming connection is dead" 116 | ((Severity FATAL) 117 | (Code 57P01))))))) 118 | |}]; 119 | return ()) 120 | in 121 | [%expect 122 | {| 123 | ("failed to close" ( 124 | error ( 125 | "ErrorResponse received asynchronously, assuming connection is dead" 126 | ((Severity FATAL) 127 | (Code 57P01))))) 128 | |}]; 129 | return () 130 | ;; 131 | 132 | let accept_login = "R\x00\x00\x00\x08\x00\x00\x00\x00Z\x00\x00\x00\x05I" 133 | 134 | module Socket_id = Unique_id.Int () 135 | 136 | let with_manual_server ~handle_client ~f:callback = 137 | let socket_name = 138 | sprintf 139 | !"/tmp/.postgres-async-tests-%{Pid}-%{Socket_id}" 140 | (Unix.getpid ()) 141 | (Socket_id.create ()) 142 | in 143 | let%bind tcp_server = 144 | Tcp.Server.create 145 | ~on_handler_error:`Raise 146 | (Tcp.Where_to_listen.of_file socket_name) 147 | (fun _sock reader writer -> handle_client reader writer) 148 | in 149 | let where_to_connect = Tcp.Where_to_connect.of_file 
socket_name in 150 | (* cleanup: stop the server, then remove the socket file *) let finally () = 151 | let%bind () = Tcp.Server.close tcp_server in 152 | let%bind () = Unix.unlink socket_name in 153 | return () 154 | in 155 | Monitor.protect ~run:`Now ~rest:`Raise ~finally (fun () -> callback where_to_connect) 156 | ;; 157 | (* [send_eof writer] waits until everything written so far has been flushed, then shuts down only the sending direction of the socket, so the peer sees EOF after the data. *) 158 | let send_eof writer = 159 | let%bind () = Writer.flushed writer in 160 | Fd.syscall_exn (Writer.fd writer) (Core_unix.shutdown ~mode:SHUTDOWN_SEND); 161 | return () 162 | ;; 163 | (* Each case starts a fake server that answers the first byte of the startup message with a fixed byte string, and prints what connecting to that server returns. NOTE(review): the test name is misspelled (invaild); it is a runtime string, so it is left unchanged here. *) 164 | let%expect_test "invaild messages during login" = 165 | (* [response] is written verbatim; unless [send_eof_after_response] is false the server then half-closes via [send_eof]. *) let try_connect ?(send_eof_after_response = true) response = 166 | with_manual_server 167 | ~handle_client:(fun reader writer -> 168 | (* wait for the startup message. *) 169 | let%bind _ = Reader.read_char reader in 170 | (* pre-prepared binary response *) 171 | Writer.write writer response; 172 | let%bind () = 173 | match send_eof_after_response with 174 | | true -> send_eof writer 175 | | false -> return () 176 | in 177 | (* never finish the handler; presumably this keeps the server from closing the connection on its own - TODO confirm *) Deferred.never ()) 178 | ~f:(fun where_to_connect -> 179 | let%bind result = 180 | Postgres_async.with_connection 181 | ~server:where_to_connect 182 | ~user:"postgres" 183 | ~database:"postgres" 184 | ~on_handler_exception:`Raise 185 | (fun _ -> return ()) 186 | in 187 | print_s ~hide_positions:true [%sexp (result : _ Or_error.t)]; 188 | return ()) 189 | in 190 | (* demonstrate that our TCP server works: *) 191 | let%bind () = try_connect accept_login in 192 | [%expect {| (Ok _) |}]; 193 | (* nonsense message length.
*) (* the header says 3, which is less than the 4 bytes of the length field itself; the error reports 4, apparently the header value plus the type byte *) 194 | let%bind () = try_connect "R\x00\x00\x00\x03" in 195 | [%expect {| (Error ("Nonsense message length in header" 4)) |}]; 196 | (* authentication method we don't recognise *) 197 | let%bind () = try_connect "R\x00\x00\x00\x08\x00\x00\x00\xFF" in 198 | [%expect 199 | {| 200 | (Error ( 201 | "Failed to parse AuthenticationRequest" ( 202 | exn ("AuthenticationRequest unrecognised type" (other 255))))) 203 | |}]; 204 | (* message type we don't recognise *) 205 | let%bind () = try_connect "x\x00\x00\x00\x04" in 206 | [%expect {| (Error ("Unrecognised message type character" (other x))) |}]; 207 | (* message type we recognise but don't expect *) 208 | let%bind () = try_connect "s\x00\x00\x00\x04" in 209 | [%expect 210 | {| 211 | (Error ( 212 | "Unexpected message type" 213 | (msg_type PortalSuspended) 214 | (state "logging in") 215 | (here lib/postgres_async/src/postgres_async.ml:LINE:COL))) 216 | |}]; 217 | (* very long message: the header claims 0x40000004 bytes; the reported 1073741829 is that value plus one *) 218 | let%bind () = 219 | try_connect ~send_eof_after_response:false "R\x40\x00\x00\x04 blah blah blah" 220 | in 221 | [%expect {| (Error ("Message too long" (message_length 1073741829))) |}]; 222 | (* message truncated by EOF: the header promises one body byte that never arrives *) 223 | let%bind () = try_connect "R\x00\x00\x00\x05" in 224 | [%expect {| (Error ("Unexpected EOF" (unconsumed_bytes 5))) |}]; 225 | return () 226 | ;; 227 | (* Same idea as the login test, one step later: the fake server accepts the login, waits for the first byte of the first query, and answers it with a fixed byte string. *) 228 | let%expect_test "invalid messages during query_expect_no_data" = 229 | (* [query_response] is written verbatim after the query starts. [send_eof_after_response] (default true) makes the server half-close afterwards; [show_second_result] (default false) also prints the error of the follow-up query. *) let try_query 230 | ?(show_second_result = false) 231 | ?(send_eof_after_response = true) 232 | query_response 233 | = 234 | let handle_client reader writer = 235 | (* wait for the startup message.
*) (* 41 is the size of the startup message this client sends; test_ssl.ml names the same number startup_message_length *) 236 | let scratch = Bytes.create 41 in 237 | let%bind () = 238 | match%bind Reader.really_read reader scratch with 239 | | `Ok -> return () 240 | | `Eof _ -> assert false 241 | in 242 | let scratch = Bytes.to_string scratch in 243 | (* startup messages end with two nulls *) 244 | assert (Char.equal scratch.[39] '\x00' && Char.equal scratch.[40] '\x00'); 245 | Writer.write writer accept_login; 246 | (* wait for the query to start executing. *) 247 | let%bind parse = 248 | match%bind Reader.read_char reader with 249 | | `Ok c -> return c 250 | | `Eof -> assert false 251 | in 252 | (* the first message should be a parse. *) 253 | assert (Char.equal parse 'P'); 254 | Writer.write writer query_response; 255 | let%bind () = 256 | match send_eof_after_response with 257 | | true -> send_eof writer 258 | | false -> return () 259 | in 260 | Deferred.never () 261 | in 262 | (* Client side: run one query and print its result. If it failed, wait for close_finished (which must be an error), then check that a second query on the same connection fails as well. *) let send_queries postgres = 263 | let handle_row ~column_names ~values = 264 | (* converter used for the [sopt] annotation in the [%message] below *) let sexp_of_sopt = function 265 | | None -> [%sexp ""] 266 | | Some s -> [%sexp (s : string)] 267 | in 268 | print_s [%message "row" (column_names : string iarray) (values : sopt iarray)] 269 | in 270 | let%bind r1 = Postgres_async.query postgres "" ~handle_row in 271 | print_s ~hide_positions:true [%message (r1 : _ Or_error.t)]; 272 | let%bind () = 273 | match r1 with 274 | | Ok () -> return () 275 | | Error _ -> 276 | (match%bind Postgres_async.close_finished postgres with 277 | | Ok _ -> failwith "close_finished returned Ok _ ?" 278 | | Error _ -> 279 | printf "close_finished is determined with an error\n"; 280 | (match%bind Postgres_async.query postgres "" ~handle_row with 281 | | Ok _ -> failwith "second query succeeded?"
282 | | Error _ as r2 -> 283 | (match show_second_result with 284 | | false -> () 285 | | true -> print_s ~hide_positions:true [%message (r2 : _ Or_error.t)]); 286 | return ())) 287 | in 288 | return () 289 | in 290 | with_manual_server ~handle_client ~f:(fun where_to_connect -> 291 | let%bind outer_result = 292 | Postgres_async.with_connection 293 | ~server:where_to_connect 294 | ~user:"postgres" 295 | ~database:"postgres" 296 | ~on_handler_exception:`Raise 297 | send_queries 298 | in 299 | print_s [%message (outer_result : _ Or_error.t)]; 300 | return ()) 301 | in 302 | (* Hand-encoded server messages used by the cases below: one type byte, then a 4-byte big-endian length that counts itself and the body but not the type byte, then the body. The rowdescription here has the two-byte body 0, i.e. no columns. *) (* demonstrate that our TCP server works: *) 303 | let parsecomplete = "1\x00\x00\x00\x04" in 304 | let bindcomplete = "2\x00\x00\x00\x04" in 305 | let nodata = "n\x00\x00\x00\x04" in 306 | let emptyqueryresponse = "I\x00\x00\x00\x04" in 307 | let rowdescription = "T\x00\x00\x00\x06\x00\x00" in 308 | let commandcomplete = "C\x00\x00\x00\x06_\x00" in 309 | let readyforquery = "Z\x00\x00\x00\x05I" in 310 | let%bind () = 311 | try_query 312 | ~show_second_result:true 313 | (parsecomplete ^ bindcomplete ^ nodata ^ emptyqueryresponse ^ readyforquery) 314 | in 315 | [%expect 316 | {| 317 | (r1 (Ok _)) 318 | (outer_result (Ok _)) 319 | |}]; 320 | (* unexpected or unrecognised message types at various stages.
*) (* unrecognised type byte as the very first reply *) 321 | let%bind () = try_query "x\x00\x00\x00\x04" in 322 | [%expect 323 | {| 324 | (r1 ( 325 | Error ((query ) ("Unrecognised message type character" (other x))))) 326 | close_finished is determined with an error 327 | (outer_result (Ok _)) 328 | |}]; 329 | (* a known but unexpected message (s, reported as PortalSuspended) right after parsecomplete *) let%bind () = try_query (parsecomplete ^ "s\x00\x00\x00\x04") in 330 | [%expect 331 | {| 332 | (r1 ( 333 | Error ( 334 | (query ) 335 | ("Unexpected message type" 336 | (msg_type PortalSuspended) 337 | (state Binding) 338 | (here lib/postgres_async/src/postgres_async.ml:LINE:COL))))) 339 | close_finished is determined with an error 340 | (outer_result (Ok _)) 341 | |}]; 342 | (* unrecognised type byte after parsecomplete and bindcomplete *) let%bind () = try_query (parsecomplete ^ bindcomplete ^ "x\x00\x00\x00\x04") in 343 | [%expect 344 | {| 345 | (r1 ( 346 | Error ((query ) ("Unrecognised message type character" (other x))))) 347 | close_finished is determined with an error 348 | (outer_result (Ok _)) 349 | |}]; 350 | (* the same unexpected message, this time after an (empty) rowdescription *) let%bind () = 351 | try_query (parsecomplete ^ bindcomplete ^ rowdescription ^ "s\x00\x00\x00\x04") 352 | in 353 | [%expect 354 | {| 355 | (r1 ( 356 | Error ( 357 | (query ) 358 | ("Unexpected message type" 359 | (msg_type PortalSuspended) 360 | (state "reading DataRows") 361 | (here lib/postgres_async/src/postgres_async.ml:LINE:COL))))) 362 | close_finished is determined with an error 363 | (outer_result (Ok _)) 364 | |}]; 365 | (* second query fails cleanly too: *) 366 | let%bind () = try_query "x\x00\x00\x00\x04" ~show_second_result:true in 367 | [%expect 368 | {| 369 | (r1 ( 370 | Error ((query ) ("Unrecognised message type character" (other x))))) 371 | close_finished is determined with an error 372 | (r2 ( 373 | Error ( 374 | (query ) 375 | ("query issued against previously-failed connection" 376 | (original_error ("Unrecognised message type character" (other x))))))) 377 | (outer_result (Ok _)) 378 | |}]; 379 | (* rowdescription with junk in: the two body bytes 0x0fff stand where the valid messages carry the column count, and nothing follows them *) 380 | let%bind () = try_query (parsecomplete ^ bindcomplete ^ "T\x00\x00\x00\x06\x0f\xff") in 381 | [%expect
382 | {| 383 | (r1 ( 384 | Error ( 385 | (query ) 386 | ("Failed to parse RowDescription" ( 387 | exn ( 388 | "Iobuf got invalid range" ( 389 | ((pos 0) 390 | (len 1)) 391 | ((buf ) 392 | (lo_min 0) 393 | (lo 17) 394 | (hi 17) 395 | (hi_max 17))))))))) 396 | close_finished is determined with an error 397 | (outer_result (Ok _)) 398 | |}]; 399 | (* datarow with mismatched number of fields vs. row desc. *) 400 | let%bind () = 401 | (* 18 zero bytes follow each null-terminated column name. The rowdescription length 0x2e = 46 is 4 (length) + 2 (column count 2) + 2 * (2 + 18); the datarow length 0x0b = 11 is 4 + 2 (field count 1) + 4 (field length 1) + 1 (the byte a). *) let z18 = String.init 18 ~f:(fun _ -> '\x00') in 402 | try_query 403 | (parsecomplete 404 | ^ bindcomplete 405 | ^ "T\x00\x00\x00\x2e\x00\x02A\x00" 406 | ^ z18 407 | ^ "B\x00" 408 | ^ z18 409 | ^ "D\x00\x00\x00\x0b\x00\x01\x00\x00\x00\x01a" 410 | ^ commandcomplete 411 | ^ readyforquery) 412 | in 413 | [%expect 414 | {| 415 | (r1 ( 416 | Error ( 417 | (query ) 418 | ("number of columns in DataRow message did not match RowDescription" 419 | (column_names (A B)) 420 | (values ((a))))))) 421 | close_finished is determined with an error 422 | (outer_result (Ok _)) 423 | |}]; 424 | (* very long datarow: the header claims 0x40000004 bytes *) 425 | let%bind () = 426 | try_query 427 | ~send_eof_after_response:false 428 | (parsecomplete ^ bindcomplete ^ rowdescription ^ "D\x40\x00\x00\x04 blah blah blah") 429 | in 430 | [%expect 431 | {| 432 | (r1 (Error ((query ) ("Message too long" (message_length 1073741829))))) 433 | close_finished is determined with an error 434 | (outer_result (Ok _)) 435 | |}]; 436 | (* message truncated by EOF: the datarow header promises 0x10 bytes but only the 5 header bytes arrive *) 437 | let%bind () = 438 | try_query (parsecomplete ^ bindcomplete ^ rowdescription ^ "D\x00\x00\x00\x10") 439 | in 440 | [%expect 441 | {| 442 | (r1 (Error ((query ) ("Unexpected EOF" (unconsumed_bytes 5))))) 443 | close_finished is determined with an error 444 | (outer_result (Ok _)) 445 | |}]; 446 | return () 447 | ;; 448 | (* The fake server accepts the login, stays idle until [send_eof_now] is filled, and then half-closes. The test checks that close_finished is still undetermined before that and becomes an error afterwards, with no query issued. *) 449 | let%expect_test "asynchronous EOF" = 450 | let send_eof_now = Ivar.create () in 451 | let handle_client reader writer = 452 | (* wait for the startup message.
*) 453 | let%bind _ = Reader.read_char reader in 454 | (* accept login *) 455 | Writer.write writer accept_login; 456 | (* send EOF, but only once the client side of the test asks for it *) 457 | let%bind () = Ivar.read send_eof_now in 458 | let%bind () = send_eof writer in 459 | Deferred.never () 460 | in 461 | with_manual_server ~handle_client ~f:(fun where_to_connect -> 462 | let handle_connection postgres = 463 | let close_finished_deferred = Postgres_async.close_finished postgres in 464 | (* let pending jobs run first, so that the Empty printed below reflects a settled connection rather than one still being set up *) let%bind () = Scheduler.yield_until_no_jobs_remain () in 465 | printf "in connection handler\n"; 466 | print_s [%message (close_finished_deferred : unit Or_error.t Deferred.t)]; 467 | printf "now sending EOF\n"; 468 | Ivar.fill_exn send_eof_now (); 469 | (* postgres_async should notice asynchronously, without issuing a query. *) 470 | let%bind close_finished = close_finished_deferred in 471 | print_s [%message (close_finished : unit Or_error.t)]; 472 | return () 473 | in 474 | let%bind outer_result = 475 | Postgres_async.with_connection 476 | ~server:where_to_connect 477 | ~user:"postgres" 478 | ~database:"postgres" 479 | ~on_handler_exception:`Raise 480 | handle_connection 481 | in 482 | print_s [%message (outer_result : _ Or_error.t)]; 483 | [%expect 484 | {| 485 | in connection handler 486 | (close_finished_deferred Empty) 487 | now sending EOF 488 | (close_finished (Error "Unexpected EOF (no unconsumed messages)")) 489 | (outer_result (Ok _)) 490 | |}]; 491 | return ()) 492 | ;; 493 | (* The fake server shuts down its receiving side after the first byte of the startup message and then asks for a password; the expected outcome is that connect returns a Writer failed asynchronously error caused by a broken pipe. *) 494 | let%expect_test "asynchronous writer failure during login" = 495 | (* note that SHUTDOWN_RECEIVE doesn't work on TCP; fortunately we have unix sockets. *) 496 | with_manual_server 497 | ~handle_client:(fun reader writer -> 498 | (* wait for the startup message. *) 499 | let%bind _ = Reader.read_char reader in 500 | (* shutdown read. *) 501 | Fd.syscall_exn (Reader.fd reader) (Core_unix.shutdown ~mode:SHUTDOWN_RECEIVE); 502 | (* ask for a password.
*) (* an R message of length 0x0c = 4 + 4 + 4: request type 5 followed by the four bytes salt *) 503 | Writer.write writer "R\x00\x00\x00\x0c\x00\x00\x00\x05salt"; 504 | Deferred.never ()) 505 | ~f:(fun where_to_connect -> 506 | let%bind result = 507 | Postgres_async.connect 508 | () 509 | ~server:where_to_connect 510 | ~user:"postgres" 511 | ~database:"postgres" 512 | ~password:"postgres" 513 | in 514 | (* normalise the error with the shared helper from utils.ml before printing, so the expectation only contains its stable parts *) let result = Utils.delete_unstable_bits_of_error [%sexp (result : _ Or_error.t)] in 515 | print_s result; 516 | [%expect 517 | {| 518 | (Error ( 519 | "Writer failed asynchronously" ( 520 | exn ( 521 | monitor.ml.Error 522 | ("Writer error from inner_monitor" 523 | (Unix.Unix_error "Broken pipe" writev_assume_fd_is_nonblocking "") 524 | ) 525 | ("Caught by monitor Writer.monitor"))))) 526 | |}]; 527 | return ()) 528 | ;; 529 | (* The fake server accepts the login and then shuts down its receiving side before any query arrives; the query is expected to fail with the same Writer failed asynchronously error as in the login case. *) 530 | let%expect_test "asynchronous writer failure during query" = 531 | with_manual_server 532 | ~handle_client:(fun reader writer -> 533 | (* login. *) 534 | let%bind _ = Reader.read_char reader in 535 | Writer.write writer accept_login; 536 | (* shutdown read before the query is sent.
*) 537 | Fd.syscall_exn (Reader.fd reader) (Core_unix.shutdown ~mode:SHUTDOWN_RECEIVE); 538 | Deferred.never ()) 539 | ~f:(fun where_to_connect -> 540 | let handle_connection postgres = 541 | printf "connected\n"; 542 | let%bind result = Postgres_async.query_expect_no_data postgres "" in 543 | let result = [%message (result : _ Or_error.t)] in 544 | print_s (Utils.delete_unstable_bits_of_error result); 545 | return () 546 | in 547 | let%bind outer_result = 548 | Postgres_async.with_connection 549 | ~server:where_to_connect 550 | ~user:"postgres" 551 | ~database:"postgres" 552 | ~password:"postgres" 553 | ~on_handler_exception:`Raise 554 | handle_connection 555 | in 556 | print_s [%message (outer_result : unit Or_error.t)]; 557 | (* the failure is reported by the query itself; with_connection as a whole still returns Ok *) [%expect 558 | {| 559 | connected 560 | (result ( 561 | Error ( 562 | (query ) 563 | ("Writer failed asynchronously" ( 564 | exn ( 565 | monitor.ml.Error 566 | ("Writer error from inner_monitor" 567 | (Unix.Unix_error "Broken pipe" writev_assume_fd_is_nonblocking "") 568 | ) 569 | ("Caught by monitor Writer.monitor"))))))) 570 | (outer_result (Ok ())) 571 | |}]; 572 | return ()) 573 | ;; 574 | -------------------------------------------------------------------------------- /test/test_server_failure.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_simple_query.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_smoke.ml: -------------------------------------------------------------------------------- 1 | open!
Core 2 | open Async 3 | (* Test harness, created on first [force] and shared by both tests in this file. *) 4 | let harness = lazy (Harness.create ()) 5 | (* Sends one text parameter through a cast to int and prints the single row that comes back. *) 6 | let%expect_test "check that basic query functionality works" = 7 | Harness.with_connection_exn (force harness) ~database:"postgres" (fun postgres -> 8 | let%bind result = 9 | Postgres_async.query 10 | postgres 11 | "SELECT $1::int" 12 | ~parameters:[| Some "1234" |] 13 | ~handle_row:(fun ~column_names:_ ~values -> 14 | print_s [%message (values : string option iarray)]) 15 | in 16 | Or_error.ok_exn result; 17 | [%expect {| (values ((1234))) |}]; 18 | return ()) 19 | ;; 20 | (* Creates a temporary table, loads ten rows into it with copy_in_rows, and reads them back in key order. *) 21 | let%expect_test "check that fundamental copy-in features work" = 22 | Harness.with_connection_exn (force harness) ~database:"postgres" (fun postgres -> 23 | let%bind result = 24 | Postgres_async.query_expect_no_data 25 | postgres 26 | "CREATE TEMPORARY TABLE x ( y integer PRIMARY KEY, z text )" 27 | in 28 | Or_error.ok_exn result; 29 | [%expect {| |}]; 30 | let%bind result = 31 | (* feed_data is called repeatedly and yields rows for y = 10 down to 1, then Finished; z is asdf-y for even y and None for odd y *) let countdown = ref 10 in 32 | Postgres_async.copy_in_rows 33 | postgres 34 | ~table_name:"x" 35 | ~column_names:[ "y"; "z" ] 36 | ~feed_data:(fun () -> 37 | match !countdown with 38 | | 0 -> Finished 39 | | i -> 40 | decr countdown; 41 | Data 42 | [| Some (Int.to_string i) 43 | ; Option.some_if (i % 2 = 0) (sprintf "asdf-%i" i) 44 | |]) 45 | in 46 | Or_error.ok_exn result; 47 | [%expect {| |}]; 48 | let%bind result = 49 | Postgres_async.query 50 | postgres 51 | "SELECT * FROM x ORDER BY y" 52 | ~handle_row:(fun ~column_names ~values -> 53 | (* pair every column name with its value; zip_exn raises if the two lengths differ *) print_s 54 | [%sexp (Iarray.zip_exn column_names values : (string * string option) iarray)]) 55 | in 56 | Or_error.ok_exn result; 57 | [%expect 58 | {| 59 | ((y (1)) (z ())) 60 | ((y (2)) (z (asdf-2))) 61 | ((y (3)) (z ())) 62 | ((y (4)) (z (asdf-4))) 63 | ((y (5)) (z ())) 64 | ((y (6)) (z (asdf-6))) 65 | ((y (7)) (z ())) 66 | ((y (8)) (z (asdf-8))) 67 | ((y (9)) (z ())) 68 | ((y (10)) (z (asdf-10))) 69 | |}]; 70 | return ()) 71 | ;; 72 |
-------------------------------------------------------------------------------- /test/test_smoke.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty. *) 2 | -------------------------------------------------------------------------------- /test/test_ssl.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open Async 3 | module Socket_id = Unique_id.Int () 4 | (* elide backtraces globally; presumably so that the errors printed by the expectations below stay stable - TODO confirm *) 5 | let () = Dynamic.set_root Backtrace.elide true (* canned reply that accepts a login; same bytes as accept_login in test_server_failure.ml *) 6 | let accept_login = "R\x00\x00\x00\x08\x00\x00\x00\x00Z\x00\x00\x00\x05I" 7 | (* [with_manual_server ~handle_client ~f] listens on a freshly named unix-domain socket under /tmp, runs [handle_client reader writer] for connections made to it, and calls [f] with the address of that socket; the server is closed and the socket file unlinked once [f] finishes, whether it returns or raises. Same code as the helper of that name in test_server_failure.ml. *) 8 | let with_manual_server ~handle_client ~f:callback = 9 | let socket_name = 10 | sprintf 11 | !"/tmp/.postgres-async-tests-%{Pid}-%{Socket_id}" 12 | (Unix.getpid ()) 13 | (Socket_id.create ()) 14 | in 15 | let%bind tcp_server = 16 | Tcp.Server.create 17 | ~on_handler_error:`Raise 18 | (Tcp.Where_to_listen.of_file socket_name) 19 | (fun _sock reader writer -> handle_client reader writer) 20 | in 21 | let where_to_connect = Tcp.Where_to_connect.of_file socket_name in 22 | let finally () = 23 | let%bind () = Tcp.Server.close tcp_server in 24 | let%bind () = Unix.unlink socket_name in 25 | return () 26 | in 27 | Monitor.protect ~run:`Now ~rest:`Raise ~finally (fun () -> callback where_to_connect) 28 | ;; 29 | (* number of bytes the fake server reads as the startup message of the client; also the buffered byte count [handle_client] treats as a client that skipped the SSLRequest *) 30 | let startup_message_length = 41 31 | (* Fake server after SSL negotiation: reads [startup_message_length] bytes, replies with [accept_login], then reads one more byte and closes the writer whatever that byte is. If EOF arrives before a full startup message, it only closes the writer. *) 32 | let handle_startup_login_and_close reader writer = 33 | let process_close_message reader = 34 | (* Close message is 'X' *) 35 | match%bind Reader.read_char reader with 36 | | `Ok 'X' -> Writer.close writer 37 | | `Ok _ | `Eof -> 38 | (* Even if we don't get a close message here, we still close the writer. 39 | This can happen in the "SSL negotiation failure" test below. *) 40 | Writer.close writer 41 | in 42 | (* Client may decide to send startup message or close its connection based on the ssl 43 | response.
*) 44 | let scratch = Bytes.create startup_message_length in 45 | match%bind Reader.really_read reader scratch with 46 | | `Eof _ -> Writer.close writer 47 | | `Ok -> 48 | Writer.write writer accept_login; 49 | process_close_message reader 50 | ;; 51 | (* Certificate and private key file names, relative to the working directory; used both by the fake TLS server and by the live postgres harness below. *) 52 | let crt_file = "server-leaf_certificate.pem" 53 | let key_file = "server-leaf_key.key" 54 | (* TLS server configuration for the fake server. The CA file is the same file as [crt_file]. *) 55 | let ssl_server_conf () = 56 | (* Took these files & config from [Async_ssl] (test/lib/ subfolder). *) 57 | let ca_file = Some "server-leaf_certificate.pem" in 58 | let ca_path = None in 59 | Async_ssl.Config.Server.create ~crt_file ~key_file ~ca_file ~ca_path () 60 | ;; 61 | (* Fake server side of SSL negotiation. It peeks 8 bytes and then decides by how many bytes are buffered. Exactly 8: an SSLRequest, which is consumed and answered with the single byte [ssl_char]; for S the connection is then wrapped in TLS before the login is handled (or, with [force_non_ssl], handled in plaintext anyway), for N the login is handled in plaintext, and for anything else the writer is closed. Exactly 41: the client skipped SSL and sent its startup message directly. Any other count raises. *) 62 | let handle_client ?(force_non_ssl = false) ~ssl_char reader writer = 63 | (* SSLRequest is always length 8 *) 64 | let%bind () = 65 | match%map Reader.peek reader ~len:8 with 66 | | `Ok _ -> () 67 | | `Eof -> failwith "unexpected eof in init message" 68 | in 69 | match Reader.bytes_available reader with 70 | | 8 -> 71 | let scratch = Bytes.create 8 in 72 | let%bind () = 73 | match%bind Reader.really_read reader scratch with 74 | | `Ok -> return () 75 | | `Eof _ -> assert false 76 | in 77 | Writer.write_char writer ssl_char; 78 | (match ssl_char with 79 | | 'S' -> 80 | (match force_non_ssl with 81 | | true -> handle_startup_login_and_close reader writer 82 | | false -> 83 | Async_ssl.Tls.wrap_server_connection 84 | (ssl_server_conf ()) 85 | reader 86 | writer 87 | ~f:(fun _conn -> handle_startup_login_and_close)) 88 | | 'N' -> handle_startup_login_and_close reader writer 89 | | _ -> 90 | print_endline "closing writer"; 91 | Writer.close writer) 92 | (* 41 is the value of startup_message_length *) | 41 -> 93 | print_endline "skipped ssl"; 94 | handle_startup_login_and_close reader writer 95 | | unexpected_bytes -> 96 | [%string "Unexpected initial message length %{unexpected_bytes#Int}"] |> failwith 97 | ;; 98 | (* Creates a postgres test harness with ssl=on, pointing the server at [crt_file] and [key_file] in the current working directory after restricting both files to mode 0o600. *) 99 | let ssl_harness () = 100 | let%bind () = 101 | (* Postgres refuses to start if ssl files are too broadly permissioned *) 102 | Deferred.List.iter ~how:`Sequential [ crt_file;
key_file ] ~f:(fun file -> 103 | Unix.chmod file ~perm:0o600) 104 | in 105 | let%bind dir = Unix.getcwd () in 106 | let harness = 107 | (* this is a blocking operation *) 108 | Harness.create 109 | ~extra_server_args: 110 | [ "-c" 111 | ; [%string "ssl_cert_file=%{dir}/%{crt_file}"] 112 | ; "-c" 113 | ; [%string "ssl_key_file=%{dir}/%{key_file}"] 114 | ; "-c" 115 | ; "ssl=on" 116 | ] 117 | () 118 | in 119 | return harness 120 | ;; 121 | (* [connect_and_close ~ssl_mode where_to_connect] connects as user postgres to database postgres. On failure it prints the error with source positions hidden and returns. On success it prints Connected, closes the connection, and raises if closing reports an error. *) 122 | let connect_and_close ~ssl_mode where_to_connect = 123 | let%bind conn = 124 | Postgres_async.connect 125 | ~ssl_mode 126 | ~server:where_to_connect 127 | ~user:"postgres" 128 | ~database:"postgres" 129 | () 130 | in 131 | match conn with 132 | | Error error -> 133 | Expect_test_helpers_core.print_s ~hide_positions:true [%message (error : Error.t)]; 134 | Deferred.unit 135 | | Ok conn -> 136 | printf "Connected\n"; 137 | let%bind close_finished = Postgres_async.close conn in 138 | Or_error.ok_exn close_finished; 139 | return () 140 | ;; 141 | (* The server answers S but then keeps talking plaintext ([force_non_ssl]); the client is expected to return the TLS error as a value rather than raise. NOTE(review): the test name is misspelled (negotation); it is a runtime string, so it is left unchanged here. *) 142 | let%expect_test "SSL negotation failure does not raise" = 143 | let%bind () = 144 | with_manual_server 145 | ~handle_client:(handle_client ~force_non_ssl:true ~ssl_char:'S') 146 | ~f:(connect_and_close ~ssl_mode:Prefer) 147 | in 148 | [%expect 149 | {| 150 | (error ( 151 | monitor.ml.Error 152 | (Ssl_error 153 | ("error:1408F10B:SSL routines:ssl3_get_record:wrong version number") 154 | lib/async_ssl/src/ssl.ml:LINE:COL) 155 | ("" "Caught by monitor ssl_pipe"))) 156 | |}]; 157 | Deferred.unit 158 | ;; 159 | (* The server answers N; with ssl_mode Prefer the client carries on without SSL and connects. *) 160 | let%expect_test "SSL negotiation: Do not use SSL" = 161 | let%bind () = 162 | with_manual_server 163 | ~handle_client:(handle_client ~ssl_char:'N') 164 | ~f:(connect_and_close ~ssl_mode:Prefer) 165 | in 166 | [%expect {| Connected |}]; 167 | Deferred.unit 168 | ;; 169 | (* The server answers E, which is neither S nor N; the client is expected to give up with an error even though ssl_mode is only Prefer. *) 170 | let%expect_test "Do not use SSL or connect if server returns unknown char response" = 171 | let%bind () = 172 | with_manual_server 173 | ~handle_client:(handle_client ~ssl_char:'E') 174 |
~f:(connect_and_close ~ssl_mode:Prefer) 175 | in 176 | [%expect 177 | {| 178 | closing writer 179 | (error ( 180 | "Postgres Server indicated it does not understand the SSLRequest message. This may mean that the server is running a very outdated version of postgres, or some other problem may be occurring. You can try to run with ssl_mode = Disable to skip the SSLRequest and use plain TCP." 181 | (response_char E))) 182 | |}]; 183 | Deferred.unit 184 | ;; 185 | (* The server answers S and really speaks TLS; both Prefer and Require are expected to connect. *) 186 | let%expect_test "Use SSL (demonstrates startup,login,and close messages over SSL)" = 187 | let%bind () = 188 | with_manual_server 189 | ~handle_client:(handle_client ~ssl_char:'S') 190 | ~f:(connect_and_close ~ssl_mode:Prefer) 191 | in 192 | [%expect {| Connected |}]; 193 | let%bind () = 194 | with_manual_server 195 | ~handle_client:(handle_client ~ssl_char:'S') 196 | ~f:(connect_and_close ~ssl_mode:Require) 197 | in 198 | [%expect {| Connected |}]; 199 | Deferred.unit 200 | ;; 201 | (* With ssl_mode Require, a server answering N and a server answering the unknown byte E are both expected to produce an error instead of a connection. *) 202 | let%expect_test "SSL negotiation: Error if SSL is required but not available" = 203 | let%bind () = 204 | with_manual_server 205 | ~handle_client:(handle_client ~ssl_char:'N') 206 | ~f:(connect_and_close ~ssl_mode:Require) 207 | in 208 | [%expect 209 | {| 210 | (error 211 | "Server indicated it cannot use SSL connections, but ssl_mode is set to Require") 212 | |}]; 213 | let%bind () = 214 | with_manual_server 215 | ~handle_client:(handle_client ~ssl_char:'E') 216 | ~f:(connect_and_close ~ssl_mode:Require) 217 | in 218 | [%expect 219 | {| 220 | closing writer 221 | (error ( 222 | "Postgres Server indicated it does not understand the SSLRequest message. This may mean that the server is running a very outdated version of postgres, or some other problem may be occurring. You can try to run with ssl_mode = Disable to skip the SSLRequest and use plain TCP."
223 | (response_char E))) 224 | |}]; 225 | Deferred.unit 226 | ;; 227 | (* With ssl_mode Disable the client sends its startup message straight away; the fake server reports that through its skipped ssl line. *) 228 | let%expect_test "sslmode = Disable skips sslrequest" = 229 | let%bind () = 230 | with_manual_server 231 | ~handle_client:(handle_client ~ssl_char:'N') 232 | ~f:(connect_and_close ~ssl_mode:Disable) 233 | in 234 | [%expect 235 | {| 236 | skipped ssl 237 | Connected 238 | |}]; 239 | Deferred.unit 240 | ;; 241 | (* Runs every ssl_mode against two real servers: one started by [ssl_harness] with ssl=on, and one started by a plain [Harness.create]. *) 242 | let%expect_test "Connect to live postgres" = 243 | let%bind ssl_harness = ssl_harness () in 244 | let where_to_connect harness = 245 | let port = Harness.port harness in 246 | (* SSL connections have to be made over TCP, cannot be made using the Unix Socket *) 247 | Host_and_port.create ~host:"localhost" ~port |> Tcp.Where_to_connect.of_host_and_port 248 | in 249 | let ssl = where_to_connect ssl_harness in 250 | (* Our pg_hba is set up to still allow TCP connections, so Disable works *) 251 | let%bind () = connect_and_close ssl ~ssl_mode:Disable in 252 | [%expect {| Connected |}]; 253 | let%bind () = connect_and_close ssl ~ssl_mode:Prefer in 254 | [%expect {| Connected |}]; 255 | let%bind () = connect_and_close ssl ~ssl_mode:Require in 256 | [%expect {| Connected |}]; 257 | (* second server, started without the ssl arguments: Disable and Prefer connect, Require is refused *) let tcp_harness = Harness.create () in 258 | let tcp = where_to_connect tcp_harness in 259 | let%bind () = connect_and_close tcp ~ssl_mode:Disable in 260 | [%expect {| Connected |}]; 261 | let%bind () = connect_and_close tcp ~ssl_mode:Prefer in 262 | [%expect {| Connected |}]; 263 | let%bind () = connect_and_close tcp ~ssl_mode:Require in 264 | [%expect 265 | {| 266 | (error 267 | "Server indicated it cannot use SSL connections, but ssl_mode is set to Require") 268 | |}]; 269 | return () 270 | ;; 271 | -------------------------------------------------------------------------------- /test/test_ssl.mli: -------------------------------------------------------------------------------- 1 | (*_ Intentionally empty.
*) 2 | -------------------------------------------------------------------------------- /test/utils.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open! Async 3 | 4 | (* Most of the error message from postgres is not stable wrt. server version. 5 | This is probably good enough. Rules, applied top-down over the sexp: atoms are kept; a list that directly contains a (Code atom) pair is cut down to just its (Severity atom) and (Code atom) pairs; a three-element list whose first element is the atom Writer error from inner_monitor keeps its first two elements and gets a fixed atom in place of the third; any other list is processed element by element. The two special cases do not recurse into what they keep. *) 6 | let rec delete_unstable_bits_of_error : Sexp.t -> Sexp.t = 7 | let is_code_pair : Sexp.t -> bool = function 8 | | List [ Atom "Code"; Atom _ ] -> true 9 | | _ -> false 10 | in 11 | let is_severity_or_code_pair : Sexp.t -> bool = function 12 | | List [ Atom ("Severity" | "Code"); Atom _ ] -> true 13 | | _ -> false 14 | in 15 | function 16 | | Atom _ as x -> x 17 | | List tags when List.exists tags ~f:is_code_pair -> 18 | List (List.filter tags ~f:is_severity_or_code_pair) 19 | | List [ (Atom "Writer error from inner_monitor" as e1); e2; _ ] -> 20 | List [ e1; e2; Atom "" ] 21 | | List list -> List (List.map list ~f:delete_unstable_bits_of_error) 22 | ;; 23 | (* Forcing this creates a pipe wrapped in an Async reader and writer and returns a function. Each call of that function lets pending jobs run, writes one byte into the pipe and waits to read it back, then lets pending jobs run again. Going by the name, the round trip is presumably there to push the scheduler through a file-descriptor wait - TODO confirm. *) 24 | let do_an_epoll = 25 | lazy 26 | (let pipe_r, pipe_w = Core_unix.pipe () in 27 | let pipe_r = Fd.create Char pipe_r (Info.of_string "do-an-epoll-pipe-r") in 28 | let pipe_w = Fd.create Char pipe_w (Info.of_string "do-an-epoll-pipe-w") in 29 | let reader = Reader.create pipe_r in 30 | let writer = Writer.create pipe_w in 31 | fun () -> 32 | let%bind () = Scheduler.yield_until_no_jobs_remain () in 33 | Writer.write_char writer 'x'; 34 | let%bind () = 35 | match%bind Reader.read_char reader with 36 | | `Ok 'x' -> return () 37 | | _ -> assert false 38 | in 39 | let%bind () = Scheduler.yield_until_no_jobs_remain () in 40 | return ()) 41 | ;; 42 | (* Runs SELECT pg_backend_pid() on [postgres] and returns the value of its single column as a string. Raises if the query fails, if a row does not consist of exactly one non-null value, or if the value is set zero times or more than once. *) 43 | let pg_backend_pid postgres = 44 | let backend_pid = Set_once.create () in 45 | let%bind result = 46 | Postgres_async.query 47 | postgres 48 | "SELECT pg_backend_pid()" 49 | ~handle_row:(fun ~column_names:_ ~values -> 50 | match Iarray.to_array values with 51 | | [| Some p |] -> Set_once.set_exn backend_pid p 52 | | _ -> assert
false) 53 | in 54 | (* raise if the query itself failed, then hand back the pid recorded by handle_row *) Or_error.ok_exn result; 55 | return (Set_once.get_exn backend_pid) 56 | ;; 57 | -------------------------------------------------------------------------------- /test/utils.mli: -------------------------------------------------------------------------------- 1 | open! Core 2 | open! Async 3 | 4 | (** Reduces an error sexp to the parts the tests treat as stable across server versions; see utils.ml for the exact rules. *) val delete_unstable_bits_of_error : Sexp.t -> Sexp.t 5 | (** Forced once to set up a pipe; each call of the resulting function does one write and read round trip through it, yielding to pending jobs before and after. *) val do_an_epoll : (unit -> unit Deferred.t) lazy_t 6 | (** Result of SELECT pg_backend_pid() on the given connection, as the string the server returned; raises if the query fails. *) val pg_backend_pid : Postgres_async.t -> string Deferred.t 7 | --------------------------------------------------------------------------------