├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .gitlab-ci.yml ├── Cargo.lock ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── deny.toml ├── examples └── base │ ├── Cargo.toml │ └── src │ └── main.rs ├── watermelon-mini ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md └── src │ ├── lib.rs │ ├── non_standard_zstd.rs │ ├── proto │ ├── authenticator.rs │ ├── connection │ │ ├── compression.rs │ │ ├── mod.rs │ │ └── security.rs │ ├── connector.rs │ └── mod.rs │ └── util.rs ├── watermelon-net ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md └── src │ ├── connection │ ├── mod.rs │ ├── streaming.rs │ └── websocket.rs │ ├── future.rs │ ├── happy_eyeballs.rs │ └── lib.rs ├── watermelon-nkeys ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md └── src │ ├── crc.rs │ ├── lib.rs │ └── seed.rs ├── watermelon-proto ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md └── src │ ├── connect.rs │ ├── headers │ ├── map.rs │ ├── mod.rs │ ├── name.rs │ └── value.rs │ ├── lib.rs │ ├── message.rs │ ├── proto │ ├── client.rs │ ├── decoder │ │ ├── framed.rs │ │ ├── mod.rs │ │ └── stream.rs │ ├── encoder │ │ ├── framed.rs │ │ ├── mod.rs │ │ └── stream.rs │ ├── mod.rs │ └── server.rs │ ├── queue_group.rs │ ├── server_addr.rs │ ├── server_error.rs │ ├── server_info.rs │ ├── status_code.rs │ ├── subject.rs │ ├── subscription_id.rs │ ├── tests.rs │ └── util │ ├── buf_list.rs │ ├── crlf.rs │ ├── lines_iter.rs │ ├── mod.rs │ ├── split_spaces.rs │ └── uint.rs └── watermelon ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md └── src ├── client ├── builder.rs ├── commands │ ├── mod.rs │ ├── publish.rs │ └── request.rs ├── from_env.rs ├── jetstream │ ├── commands │ │ ├── consumer_batch.rs │ │ ├── consumer_list.rs │ │ ├── consumer_stream.rs │ │ ├── mod.rs │ │ └── stream_list.rs │ ├── mod.rs │ └── resources │ │ ├── consumer.rs │ │ ├── mod.rs │ │ └── stream.rs ├── mod.rs ├── quick_info.rs └── tests.rs ├── 
handler ├── delayed.rs ├── mod.rs └── pinger.rs ├── lib.rs ├── multiplexed_subscription.rs ├── subscription.rs ├── tests.rs └── util ├── atomic.rs ├── future.rs └── mod.rs /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | cargo-deny: 10 | name: cargo-deny 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - uses: EmbarkStudios/cargo-deny-action@v2 16 | 17 | fmt: 18 | name: rustfmt / 1.90.0 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | 24 | - uses: dtolnay/rust-toolchain@1.90.0 25 | with: 26 | components: rustfmt 27 | 28 | - name: Rust rustfmt 29 | run: cargo fmt --all -- --check 30 | 31 | clippy: 32 | name: clippy / 1.90.0 33 | runs-on: ubuntu-latest 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | 38 | - uses: dtolnay/rust-toolchain@1.90.0 39 | with: 40 | components: clippy 41 | 42 | - name: Run clippy 43 | run: cargo clippy --all-features -- -D warnings 44 | 45 | cargo-hack: 46 | name: cargo-hack / 1.90.0 47 | runs-on: ubuntu-latest 48 | steps: 49 | - uses: actions/checkout@v4 50 | 51 | - uses: dtolnay/rust-toolchain@1.90.0 52 | 53 | - uses: taiki-e/install-action@v2 54 | with: 55 | tool: cargo-hack@0.6.37 56 | 57 | - name: Run cargo-hack 58 | run: cargo hack check --feature-powerset --no-dev-deps --at-least-one-of aws-lc-rs,ring --at-least-one-of rand,getrandom 59 | 60 | test: 61 | name: test / ${{ matrix.name }} 62 | runs-on: ubuntu-latest 63 | 64 | strategy: 65 | matrix: 66 | include: 67 | - name: stable 68 | rust: stable 69 | - name: beta 70 | rust: beta 71 | - name: nightly 72 | rust: nightly 73 | - name: 1.85.0 74 | rust: 1.85.0 75 | 76 | steps: 77 | - uses: actions/checkout@v4 78 | 79 | - uses: dtolnay/rust-toolchain@master 80 | with: 81 | toolchain: ${{ matrix.rust }} 82 | 83 | - name: Run tests 84 | run: cargo test 85 | 
86 | - name: Run tests (--features websocket,portable-atomic) 87 | run: cargo test --features websocket,portable-atomic 88 | 89 | - name: Run tests (--no-default-features --features ring,getrandom) 90 | run: cargo test --no-default-features --features ring,getrandom 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | stages: 2 | - test 3 | 4 | rust:deny: 5 | stage: test 6 | image: rust:1.90-alpine3.22 7 | before_script: 8 | - apk add cargo-deny 9 | script: 10 | - cargo deny check 11 | 12 | rust:fmt: 13 | stage: test 14 | image: rust:1.90-alpine3.22 15 | before_script: 16 | - rustup component add rustfmt 17 | script: 18 | - cargo fmt -- --check 19 | 20 | rust:clippy: 21 | stage: test 22 | image: rust:1.90-alpine3.20 23 | before_script: 24 | - apk add build-base musl-dev linux-headers cmake perl go 25 | - rustup component add clippy 26 | script: 27 | - cargo clippy --all-features -- -D warnings 28 | 29 | rust:hack: 30 | stage: test 31 | image: rust:1.90-alpine3.20 32 | before_script: 33 | - apk add build-base musl-dev linux-headers cmake perl go cargo-hack 34 | script: 35 | - cargo hack check --feature-powerset --no-dev-deps --at-least-one-of aws-lc-rs,ring --at-least-one-of rand,getrandom 36 | 37 | rust:test: 38 | stage: test 39 | image: rust:1.90-alpine3.22 40 | before_script: 41 | - apk add musl-dev cmake perl go 42 | script: 43 | - cargo test 44 | - cargo test --features websocket,portable-atomic 45 | - cargo test --no-default-features --features ring,getrandom 46 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | 
[workspace] 2 | members = [ 3 | "watermelon", 4 | "watermelon-mini", 5 | "watermelon-net", 6 | "watermelon-proto", 7 | "watermelon-nkeys", 8 | "examples/base", 9 | ] 10 | resolver = "2" 11 | 12 | [workspace.package] 13 | edition = "2024" 14 | license = "MIT OR Apache-2.0" 15 | repository = "https://github.com/M4SS-Code/watermelon" 16 | rust-version = "1.85" 17 | 18 | [workspace.lints.rust] 19 | unsafe_code = "deny" 20 | unreachable_pub = "deny" 21 | 22 | [workspace.lints.clippy] 23 | pedantic = { level = "warn", priority = -1 } 24 | module_name_repetitions = "allow" 25 | await_holding_refcell_ref = "deny" 26 | map_unwrap_or = "warn" 27 | needless_lifetimes = "warn" 28 | needless_raw_string_hashes = "warn" 29 | redundant_closure_for_method_calls = "warn" 30 | semicolon_if_nothing_returned = "warn" 31 | str_to_string = "warn" 32 | clone_on_ref_ptr = "warn" 33 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024-2025 M4SS Srl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

watermelon

2 |
3 | 4 | Pure Rust NATS client implementation 5 | 6 |
7 | 8 | `watermelon` is an independent and opinionated implementation of the NATS 9 | client protocol and the NATS client API for Rust. The goal of the project 10 | is to produce a more secure, composable and idiomatic implementation compared 11 | to the official one. 12 | 13 | Most users of this project will depend on the `watermelon` crate directly and on 14 | `watermelon-proto` and `watermelon-nkeys` via the re-exports in `watermelon`. 15 | 16 | Watermelon is divided into multiple crates, all hosted in the same monorepo. 17 | 18 | | Crate name | Crates.io release | Docs | Description | 19 | | ------------------ | --------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------- | ------------------------------------------------------------------------- | 20 | | `watermelon` | [![crates.io](https://img.shields.io/crates/v/watermelon.svg)](https://crates.io/crates/watermelon) | [![Docs](https://docs.rs/watermelon/badge.svg)](https://docs.rs/watermelon) | High-level actor-based NATS Core and NATS JetStream client implementation | 21 | | `watermelon-mini` | [![crates.io](https://img.shields.io/crates/v/watermelon-mini.svg)](https://crates.io/crates/watermelon-mini) | [![Docs](https://docs.rs/watermelon-mini/badge.svg)](https://docs.rs/watermelon-mini) | Bare-bones NATS Core client implementation | 22 | | `watermelon-net` | [![crates.io](https://img.shields.io/crates/v/watermelon-net.svg)](https://crates.io/crates/watermelon-net) | [![Docs](https://docs.rs/watermelon-net/badge.svg)](https://docs.rs/watermelon-net) | Low-level NATS Core network implementation | 23 | | `watermelon-proto` | [![crates.io](https://img.shields.io/crates/v/watermelon-proto.svg)](https://crates.io/crates/watermelon-proto) | [![Docs](https://docs.rs/watermelon-proto/badge.svg)](https://docs.rs/watermelon-proto) | `#![no_std]` NATS Core Sans-IO protocol
implementation | 24 | | `watermelon-nkeys` | [![crates.io](https://img.shields.io/crates/v/watermelon-nkeys.svg)](https://crates.io/crates/watermelon-nkeys) | [![Docs](https://docs.rs/watermelon-nkeys/badge.svg)](https://docs.rs/watermelon-nkeys) | Minimal NKeys implementation for NATS client authentication | 25 | 26 | # Philosophy and Design 27 | 28 | 1. **Security by design**: this library uses type-safe and checked APIs, such as `Subject`, to prevent entire classes of errors and security vulnerabilities. 29 | 2. **Layering and composability**: the library is split into layers. You can get a high-level, batteries-included implementation via `watermelon`, or depend directly on the lower-level crates for maximum flexibility. 30 | 3. **Opinionated, Rusty take**: we adapt the Go-style API of nats-server and apply different trade-offs to make NATS feel more Rusty. We sacrifice a bit of performance by enabling the server's verbose mode, and get better errors in return. 31 | 4. **Legacy is in the past**: we only support `nats-server >= 2.10` and avoid compatibility code for legacy versions, such as the STARTTLS-style TLS upgrade path or fallbacks for older JetStream APIs. We also prefer pull consumers over push consumers, given their robust flow control, easier compatibility with multi-account environments and stronger permissions handling. 32 | 5. **Permissive licensing**: dual-licensed under MIT and Apache-2.0. 33 | 34 | ## License 35 | 36 | Licensed under either of 37 | 38 | - Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or <https://www.apache.org/licenses/LICENSE-2.0>) 39 | - MIT license ([LICENSE-MIT](LICENSE-MIT) or <https://opensource.org/licenses/MIT>) 40 | 41 | at your option. 42 | 43 | ### Contribution 44 | 45 | Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
46 | -------------------------------------------------------------------------------- /deny.toml: -------------------------------------------------------------------------------- 1 | [advisories] 2 | ignore = [ 3 | ] 4 | 5 | [licenses] 6 | allow = [ 7 | "MIT", 8 | "Apache-2.0", 9 | "BSD-3-Clause", 10 | "ISC", 11 | "Unicode-3.0", 12 | "0BSD", 13 | "OpenSSL", 14 | "CDLA-Permissive-2.0", 15 | ] 16 | 17 | [licenses.private] 18 | ignore = true 19 | 20 | [bans] 21 | multiple-versions = "warn" 22 | wildcards = "deny" 23 | deny = [ 24 | ] 25 | 26 | [sources] 27 | unknown-registry = "deny" 28 | unknown-git = "deny" 29 | 30 | [sources.allow-org] 31 | #github = ["M4SS-Code"] 32 | -------------------------------------------------------------------------------- /examples/base/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "watermelon-base-example" 3 | version = "0.1.0" 4 | edition.workspace = true 5 | license.workspace = true 6 | repository.workspace = true 7 | rust-version.workspace = true 8 | publish = false 9 | 10 | [dependencies] 11 | tokio = { version = "1.44", features = ["macros", "rt-multi-thread", "time", "signal"] } 12 | futures-util = { version = "0.3.31", default-features = false } 13 | watermelon = { path = "../../watermelon", version = "0.4" } 14 | bytes = "1.10.1" 15 | jiff = { version = "0.2.1", default-features = false, features = ["std", "tz-system", "tzdb-zoneinfo"] } 16 | 17 | [lints] 18 | workspace = true 19 | -------------------------------------------------------------------------------- /examples/base/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | hash::{DefaultHasher, Hasher as _}, 3 | time::Duration, 4 | }; 5 | 6 | use bytes::Bytes; 7 | use futures_util::TryStreamExt as _; 8 | use jiff::Zoned; 9 | use tokio::{ 10 | signal::ctrl_c, 11 | task::JoinSet, 12 | time::{MissedTickBehavior, interval}, 13 | }; 14 | use 
watermelon::{ 15 | core::Client, 16 | proto::{ 17 | Subject, 18 | headers::{HeaderName, HeaderValue}, 19 | }, 20 | }; 21 | 22 | type BoxError = Box; 23 | 24 | #[tokio::main] 25 | async fn main() -> Result<(), BoxError> { 26 | let client = Client::builder() 27 | .connect("nats://demo.nats.io".parse()?) 28 | .await?; 29 | println!("Quick Info: {:?}", client.quick_info()); 30 | 31 | let mut set = JoinSet::new(); 32 | 33 | // Subscribe to `watermelon.>`, print every message we get and reply if possible 34 | set.spawn({ 35 | let client = client.clone(); 36 | 37 | async move { 38 | let mut subscription = client 39 | .subscribe(Subject::from_static("watermelon.>"), None) 40 | .await?; 41 | while let Some(msg) = subscription.try_next().await? { 42 | println!( 43 | "Received new message subject={:?} headers={:?} payload={:?}", 44 | msg.base.subject, msg.base.headers, msg.base.payload 45 | ); 46 | 47 | if let Some(reply_subject) = msg.base.reply_subject { 48 | client 49 | .publish(reply_subject) 50 | .header(HeaderName::from_static("Local-Time"), local_time()) 51 | .payload(Bytes::from_static("Welcome from Watermelon!".as_bytes())) 52 | .await?; 53 | } 54 | } 55 | 56 | Ok::<_, BoxError>(()) 57 | } 58 | }); 59 | 60 | // Publish to `watermelon.[random number]` every 20 seconds and await the response 61 | set.spawn({ 62 | let client = client.clone(); 63 | 64 | async move { 65 | let mut interval = interval(Duration::from_secs(20)); 66 | interval.set_missed_tick_behavior(MissedTickBehavior::Delay); 67 | 68 | loop { 69 | interval.tick().await; 70 | 71 | let subject = format!("watermelon.{}", rng()).try_into()?; 72 | println!("Sending new request..."); 73 | let response_fut = client 74 | .request(subject) 75 | .header(HeaderName::from_static("Local-Time"), local_time()) 76 | .payload(Bytes::from_static(b"Hello from Watermelon!")) 77 | .await?; 78 | println!("Awaiting response..."); 79 | match response_fut.await { 80 | Ok(resp) => { 81 | println!( 82 | "Received response 
subject={:?} headers={:?} payload={:?}", 83 | resp.base.subject, resp.base.headers, resp.base.payload 84 | ); 85 | } 86 | Err(err) => { 87 | eprintln!("Received error err={err:?}"); 88 | } 89 | } 90 | } 91 | } 92 | }); 93 | 94 | // Wait for the user to CTRL+C the program and gracefully shutdown the client 95 | set.spawn(async move { 96 | ctrl_c().await?; 97 | println!("Starting graceful shutdown..."); 98 | client.close().await; 99 | Ok::<_, BoxError>(()) 100 | }); 101 | 102 | while let Some(next) = set.join_next().await { 103 | println!("Task exited with: {next:?}"); 104 | } 105 | 106 | Ok(()) 107 | } 108 | 109 | /// Get the local time and timezone 110 | fn local_time() -> HeaderValue { 111 | Zoned::now() 112 | .to_string() 113 | .try_into() 114 | .expect("local DateTime can always be encoded into `HeaderValue`") 115 | } 116 | 117 | /// A poor man's RNG 118 | fn rng() -> u64 { 119 | DefaultHasher::new().finish() 120 | } 121 | -------------------------------------------------------------------------------- /watermelon-mini/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "watermelon-mini" 3 | version = "0.3.2" 4 | description = "Minimal NATS Core client implementation" 5 | categories = ["api-bindings", "network-programming"] 6 | keywords = ["nats", "client"] 7 | edition.workspace = true 8 | license.workspace = true 9 | repository.workspace = true 10 | rust-version.workspace = true 11 | 12 | [package.metadata.docs.rs] 13 | features = ["websocket", "non-standard-zstd"] 14 | 15 | [dependencies] 16 | tokio = { version = "1.36", features = ["net"] } 17 | tokio-rustls = { version = "0.26", default-features = false } 18 | rustls-platform-verifier = "0.6" 19 | 20 | watermelon-net = { version = "0.2", path = "../watermelon-net", default-features = false } 21 | watermelon-proto = { version = "0.1.3", path = "../watermelon-proto" } 22 | watermelon-nkeys = { version = "0.1", path = "../watermelon-nkeys", 
default-features = false } 23 | 24 | thiserror = "2" 25 | 26 | # non-standard-zstd 27 | async-compression = { version = "0.4", features = ["tokio"], optional = true } 28 | 29 | [features] 30 | default = ["aws-lc-rs", "rand"] 31 | websocket = ["watermelon-net/websocket"] 32 | aws-lc-rs = ["tokio-rustls/aws-lc-rs", "watermelon-net/aws-lc-rs", "watermelon-nkeys/aws-lc-rs"] 33 | ring = ["tokio-rustls/ring", "watermelon-net/ring", "watermelon-nkeys/ring"] 34 | fips = ["tokio-rustls/fips", "watermelon-net/fips", "watermelon-nkeys/fips"] 35 | rand = ["watermelon-net/rand"] 36 | getrandom = ["watermelon-net/getrandom"] 37 | non-standard-zstd = ["watermelon-net/non-standard-zstd", "watermelon-proto/non-standard-zstd", "dep:async-compression", "async-compression/zstd"] 38 | 39 | [lints] 40 | workspace = true 41 | -------------------------------------------------------------------------------- /watermelon-mini/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /watermelon-mini/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /watermelon-mini/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /watermelon-mini/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![forbid(unsafe_code)] 2 | 3 | use std::sync::Arc; 4 | 5 | use rustls_platform_verifier::Verifier; 6 | use tokio::net::TcpStream; 7 | use tokio_rustls::{ 8 | TlsConnector, 9 | rustls::{self, ClientConfig, crypto::CryptoProvider, version::TLS13}, 10 | }; 11 | use watermelon_net::Connection; 12 | use watermelon_proto::{ServerAddr, ServerInfo}; 13 | 14 | 
#[cfg(feature = "non-standard-zstd")] 15 | pub use self::non_standard_zstd::ZstdStream; 16 | use self::proto::connect; 17 | pub use self::proto::{ 18 | AuthenticationMethod, ConnectError, ConnectionCompression, ConnectionSecurity, 19 | }; 20 | 21 | #[cfg(feature = "non-standard-zstd")] 22 | pub(crate) mod non_standard_zstd; 23 | mod proto; 24 | mod util; 25 | 26 | #[derive(Debug, Clone)] 27 | #[non_exhaustive] 28 | pub struct ConnectFlags { 29 | pub tcp_nodelay: bool, 30 | pub echo: bool, 31 | #[cfg(feature = "non-standard-zstd")] 32 | pub zstd_compression_level: Option, 33 | } 34 | 35 | impl Default for ConnectFlags { 36 | fn default() -> Self { 37 | Self { 38 | tcp_nodelay: true, 39 | echo: false, 40 | #[cfg(feature = "non-standard-zstd")] 41 | zstd_compression_level: Some(3), 42 | } 43 | } 44 | } 45 | 46 | /// Connect to a given address with some reasonable presets. 47 | /// 48 | /// The function is going to establish a TLS 1.3 connection, without the support of the client 49 | /// authorization. 50 | /// 51 | /// # Errors 52 | /// 53 | /// This returns an error in case the connection fails. 
54 | #[expect( 55 | clippy::missing_panics_doc, 56 | reason = "the crypto_provider function always returns a provider that supports TLS 1.3" 57 | )] 58 | pub async fn easy_connect( 59 | addr: &ServerAddr, 60 | auth: Option<&AuthenticationMethod>, 61 | flags: ConnectFlags, 62 | ) -> Result< 63 | ( 64 | Connection< 65 | ConnectionCompression>, 66 | ConnectionSecurity, 67 | >, 68 | Box, 69 | ), 70 | ConnectError, 71 | > { 72 | let provider = Arc::new(crypto_provider()); 73 | let connector = TlsConnector::from(Arc::new( 74 | ClientConfig::builder_with_provider(Arc::clone(&provider)) 75 | .with_protocol_versions(&[&TLS13]) 76 | .unwrap() 77 | .dangerous() 78 | .with_custom_certificate_verifier(Arc::new( 79 | Verifier::new(provider).map_err(ConnectError::Tls)?, 80 | )) 81 | .with_no_client_auth(), 82 | )); 83 | 84 | let (conn, info) = connect(&connector, addr, "watermelon".to_owned(), auth, flags).await?; 85 | Ok((conn, info)) 86 | } 87 | 88 | fn crypto_provider() -> CryptoProvider { 89 | #[cfg(feature = "aws-lc-rs")] 90 | return rustls::crypto::aws_lc_rs::default_provider(); 91 | #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] 92 | return rustls::crypto::ring::default_provider(); 93 | #[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))] 94 | compile_error!("Please enable the `aws-lc-rs` or the `ring` feature") 95 | } 96 | -------------------------------------------------------------------------------- /watermelon-mini/src/non_standard_zstd.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt::{self, Debug, Formatter}, 3 | io, 4 | pin::Pin, 5 | task::{Context, Poll}, 6 | }; 7 | 8 | use async_compression::{ 9 | Level, 10 | tokio::{bufread::ZstdDecoder, write::ZstdEncoder}, 11 | }; 12 | use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf}; 13 | 14 | use crate::util::MaybeConnection; 15 | 16 | pub struct ZstdStream { 17 | decoder: ZstdDecoder>>, 18 | encoder: ZstdEncoder>, 19 | } 20 | 21 | impl 
ZstdStream 22 | where 23 | S: AsyncRead + AsyncWrite + Unpin, 24 | { 25 | #[must_use] 26 | pub fn new(stream: S, compression_level: u8) -> Self { 27 | Self { 28 | decoder: ZstdDecoder::new(BufReader::new(MaybeConnection(Some(stream)))), 29 | encoder: ZstdEncoder::with_quality( 30 | MaybeConnection(None), 31 | Level::Precise(compression_level.into()), 32 | ), 33 | } 34 | } 35 | } 36 | 37 | impl AsyncRead for ZstdStream 38 | where 39 | S: AsyncRead + AsyncWrite + Unpin, 40 | { 41 | fn poll_read( 42 | mut self: Pin<&mut Self>, 43 | cx: &mut Context<'_>, 44 | buf: &mut ReadBuf<'_>, 45 | ) -> Poll> { 46 | if let Some(stream) = self.encoder.get_mut().0.take() { 47 | self.decoder.get_mut().get_mut().0 = Some(stream); 48 | } 49 | 50 | Pin::new(&mut self.decoder).poll_read(cx, buf) 51 | } 52 | } 53 | 54 | impl AsyncWrite for ZstdStream 55 | where 56 | S: AsyncRead + AsyncWrite + Unpin, 57 | { 58 | fn poll_write( 59 | mut self: Pin<&mut Self>, 60 | cx: &mut Context<'_>, 61 | buf: &[u8], 62 | ) -> Poll> { 63 | if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { 64 | self.encoder.get_mut().0 = Some(stream); 65 | } 66 | 67 | Pin::new(&mut self.encoder).poll_write(cx, buf) 68 | } 69 | 70 | fn poll_write_vectored( 71 | mut self: Pin<&mut Self>, 72 | cx: &mut Context<'_>, 73 | bufs: &[io::IoSlice<'_>], 74 | ) -> Poll> { 75 | if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { 76 | self.encoder.get_mut().0 = Some(stream); 77 | } 78 | 79 | Pin::new(&mut self.encoder).poll_write_vectored(cx, bufs) 80 | } 81 | 82 | fn is_write_vectored(&self) -> bool { 83 | if let Some(stream) = &self.encoder.get_ref().0 { 84 | stream.is_write_vectored() 85 | } else if let Some(stream) = &self.decoder.get_ref().get_ref().0 { 86 | stream.is_write_vectored() 87 | } else { 88 | unreachable!() 89 | } 90 | } 91 | 92 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 93 | if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { 94 | 
self.encoder.get_mut().0 = Some(stream); 95 | } 96 | 97 | Pin::new(&mut self.encoder).poll_flush(cx) 98 | } 99 | 100 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 101 | if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { 102 | self.encoder.get_mut().0 = Some(stream); 103 | } 104 | 105 | Pin::new(&mut self.encoder).poll_shutdown(cx) 106 | } 107 | } 108 | 109 | impl Debug for ZstdStream { 110 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 111 | f.debug_struct("ZstdStream").finish_non_exhaustive() 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /watermelon-mini/src/proto/authenticator.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{self, Debug, Formatter}; 2 | 3 | use watermelon_nkeys::{KeyPair, KeyPairFromSeedError}; 4 | use watermelon_proto::{Connect, ServerAddr, ServerInfo}; 5 | 6 | pub enum AuthenticationMethod { 7 | UserAndPassword { username: String, password: String }, 8 | Creds { jwt: String, nkey: KeyPair }, 9 | } 10 | 11 | #[derive(Debug, thiserror::Error)] 12 | pub enum AuthenticationError { 13 | #[error("missing nonce")] 14 | MissingNonce, 15 | } 16 | 17 | #[derive(Debug, thiserror::Error)] 18 | pub enum CredsParseError { 19 | #[error("contents are truncated")] 20 | Truncated, 21 | #[error("missing closing for JWT")] 22 | MissingJwtClosing, 23 | #[error("missing closing for nkey")] 24 | MissingNkeyClosing, 25 | #[error("missing JWT")] 26 | MissingJwt, 27 | #[error("missing nkey")] 28 | MissingNkey, 29 | #[error("invalid nkey")] 30 | InvalidKey(#[source] KeyPairFromSeedError), 31 | } 32 | 33 | impl AuthenticationMethod { 34 | pub(crate) fn try_from_addr(addr: &ServerAddr) -> Option { 35 | if let (Some(username), Some(password)) = (addr.username(), addr.password()) { 36 | Some(Self::UserAndPassword { 37 | username: username.to_owned(), 38 | password: password.to_owned(), 39 | }) 40 | } else { 41 
| None 42 | } 43 | } 44 | 45 | pub(crate) fn prepare_for_auth( 46 | &self, 47 | info: &ServerInfo, 48 | connect: &mut Connect, 49 | ) -> Result<(), AuthenticationError> { 50 | match self { 51 | Self::UserAndPassword { username, password } => { 52 | connect.username = Some(username.clone()); 53 | connect.password = Some(password.clone()); 54 | } 55 | Self::Creds { jwt, nkey } => { 56 | let nonce = info 57 | .nonce 58 | .as_deref() 59 | .ok_or(AuthenticationError::MissingNonce)?; 60 | let signature = nkey.sign(nonce.as_bytes()).to_string(); 61 | 62 | connect.jwt = Some(jwt.clone()); 63 | connect.nkey = Some(nkey.public_key().to_string()); 64 | connect.signature = Some(signature); 65 | } 66 | } 67 | 68 | Ok(()) 69 | } 70 | 71 | /// Creates an `AuthenticationMethod` from the content of a credentials file (leading and trailing whitespace on each line is ignored). 72 | /// 73 | /// # Errors 74 | /// 75 | /// Returns a [`CredsParseError`] if the JWT or the nkey seed section is missing, truncated or not properly closed, or if the seed is not a valid nkey seed. 76 | pub fn from_creds(contents: &str) -> Result<Self, CredsParseError> { 77 | let mut jwt = None; 78 | let mut secret = None; 79 | 80 | let mut lines = contents.lines().map(str::trim); // tolerate indented or padded lines 81 | while let Some(line) = lines.next() { 82 | if line == "-----BEGIN NATS USER JWT-----" { 83 | jwt = Some(lines.next().ok_or(CredsParseError::Truncated)?); 84 | 85 | let line = lines.next().ok_or(CredsParseError::Truncated)?; 86 | if line != "------END NATS USER JWT------" { 87 | return Err(CredsParseError::MissingJwtClosing); 88 | } 89 | } else if line == "-----BEGIN USER NKEY SEED-----" { 90 | secret = Some(lines.next().ok_or(CredsParseError::Truncated)?); 91 | 92 | let line = lines.next().ok_or(CredsParseError::Truncated)?; 93 | if line != "------END USER NKEY SEED------" { 94 | return Err(CredsParseError::MissingNkeyClosing); 95 | } 96 | } 97 | } 98 | 99 | let 
jwt = jwt.ok_or(CredsParseError::MissingJwt)?; 100 | let nkey = secret.ok_or(CredsParseError::MissingNkey)?; 101 | let nkey = KeyPair::from_encoded_seed(nkey).map_err(CredsParseError::InvalidKey)?; 102 | 103 | Ok(Self::Creds { 104 | jwt: jwt.to_owned(), 105 | nkey,
106 | }) 107 | } 108 | } 109 | 110 | impl Debug for AuthenticationMethod { 111 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 112 | f.debug_struct("AuthenticationMethod") 113 | .finish_non_exhaustive() 114 | } 115 | } 116 | 117 | #[cfg(test)] 118 | mod tests { 119 | use super::AuthenticationMethod; 120 | 121 | #[test] 122 | fn parse_creds() { 123 | let creds = r"-----BEGIN NATS USER JWT----- 124 | eyJ0eXAiOiJqd3QiLCJhbGciOiJlZDI1NTE5In0.eyJqdGkiOiJUVlNNTEtTWkJBN01VWDNYQUxNUVQzTjRISUw1UkZGQU9YNUtaUFhEU0oyWlAzNkVMNVJBIiwiaWF0IjoxNTU4MDQ1NTYyLCJpc3MiOiJBQlZTQk0zVTQ1REdZRVVFQ0tYUVM3QkVOSFdHN0tGUVVEUlRFSEFKQVNPUlBWV0JaNEhPSUtDSCIsIm5hbWUiOiJvbWVnYSIsInN1YiI6IlVEWEIyVk1MWFBBU0FKN1pEVEtZTlE3UU9DRldTR0I0Rk9NWVFRMjVIUVdTQUY3WlFKRUJTUVNXIiwidHlwZSI6InVzZXIiLCJuYXRzIjp7InB1YiI6e30sInN1YiI6e319fQ.6TQ2ilCDb6m2ZDiJuj_D_OePGXFyN3Ap2DEm3ipcU5AhrWrNvneJryWrpgi_yuVWKo1UoD5s8bxlmwypWVGFAA 125 | ------END NATS USER JWT------ 126 | 127 | ************************* IMPORTANT ************************* 128 | NKEY Seed printed below can be used to sign and prove identity. 129 | NKEYs are sensitive and should be treated as secrets. 
130 | 131 | -----BEGIN USER NKEY SEED----- 132 | SUAOY5JZ2WJKVR4UO2KJ2P3SW6FZFNWEOIMAXF4WZEUNVQXXUOKGM55CYE 133 | ------END USER NKEY SEED------ 134 | 135 | *************************************************************"; 136 | 137 | let AuthenticationMethod::Creds { jwt, nkey } = 138 | AuthenticationMethod::from_creds(creds).unwrap() 139 | else { 140 | panic!("invalid auth method"); 141 | }; 142 | assert_eq!( 143 | jwt, 144 | "eyJ0eXAiOiJqd3QiLCJhbGciOiJlZDI1NTE5In0.eyJqdGkiOiJUVlNNTEtTWkJBN01VWDNYQUxNUVQzTjRISUw1UkZGQU9YNUtaUFhEU0oyWlAzNkVMNVJBIiwiaWF0IjoxNTU4MDQ1NTYyLCJpc3MiOiJBQlZTQk0zVTQ1REdZRVVFQ0tYUVM3QkVOSFdHN0tGUVVEUlRFSEFKQVNPUlBWV0JaNEhPSUtDSCIsIm5hbWUiOiJvbWVnYSIsInN1YiI6IlVEWEIyVk1MWFBBU0FKN1pEVEtZTlE3UU9DRldTR0I0Rk9NWVFRMjVIUVdTQUY3WlFKRUJTUVNXIiwidHlwZSI6InVzZXIiLCJuYXRzIjp7InB1YiI6e30sInN1YiI6e319fQ.6TQ2ilCDb6m2ZDiJuj_D_OePGXFyN3Ap2DEm3ipcU5AhrWrNvneJryWrpgi_yuVWKo1UoD5s8bxlmwypWVGFAA" 145 | ); 146 | assert_eq!( 147 | nkey.public_key().to_string(), 148 | "SAAO4HKVRO54CIBH7EONLBWD6BYIW2IYHQVZTCCDLU6C2IAX7GBEQGJDYE" 149 | ); 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /watermelon-mini/src/proto/connection/compression.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io, 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 | 9 | #[cfg(feature = "non-standard-zstd")] 10 | use crate::non_standard_zstd::ZstdStream; 11 | 12 | #[derive(Debug)] 13 | pub enum ConnectionCompression { 14 | Plain(S), 15 | #[cfg(feature = "non-standard-zstd")] 16 | Zstd(ZstdStream), 17 | } 18 | 19 | impl ConnectionCompression 20 | where 21 | S: AsyncRead + AsyncWrite + Unpin, 22 | { 23 | #[cfg(feature = "non-standard-zstd")] 24 | pub(crate) fn upgrade_zstd(self, compression_level: u8) -> Self { 25 | let Self::Plain(socket) = self else { 26 | unreachable!() 27 | }; 28 | 29 | Self::Zstd(ZstdStream::new(socket, 
compression_level)) 30 | } 31 | 32 | #[cfg(feature = "non-standard-zstd")] 33 | pub fn is_zstd_compressed(&self) -> bool { 34 | matches!(self, Self::Zstd(_)) 35 | } 36 | } 37 | 38 | impl AsyncRead for ConnectionCompression 39 | where 40 | S: AsyncRead + AsyncWrite + Unpin, 41 | { 42 | fn poll_read( 43 | self: Pin<&mut Self>, 44 | cx: &mut Context<'_>, 45 | buf: &mut ReadBuf<'_>, 46 | ) -> Poll> { 47 | match self.get_mut() { 48 | Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf), 49 | #[cfg(feature = "non-standard-zstd")] 50 | Self::Zstd(conn) => Pin::new(conn).poll_read(cx, buf), 51 | } 52 | } 53 | } 54 | 55 | impl AsyncWrite for ConnectionCompression 56 | where 57 | S: AsyncRead + AsyncWrite + Unpin, 58 | { 59 | fn poll_write( 60 | self: Pin<&mut Self>, 61 | cx: &mut Context<'_>, 62 | buf: &[u8], 63 | ) -> Poll> { 64 | match self.get_mut() { 65 | Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf), 66 | #[cfg(feature = "non-standard-zstd")] 67 | Self::Zstd(conn) => Pin::new(conn).poll_write(cx, buf), 68 | } 69 | } 70 | 71 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 72 | match self.get_mut() { 73 | Self::Plain(conn) => Pin::new(conn).poll_flush(cx), 74 | #[cfg(feature = "non-standard-zstd")] 75 | Self::Zstd(conn) => Pin::new(conn).poll_flush(cx), 76 | } 77 | } 78 | 79 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 80 | match self.get_mut() { 81 | Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx), 82 | #[cfg(feature = "non-standard-zstd")] 83 | Self::Zstd(conn) => Pin::new(conn).poll_shutdown(cx), 84 | } 85 | } 86 | 87 | fn poll_write_vectored( 88 | self: Pin<&mut Self>, 89 | cx: &mut Context<'_>, 90 | bufs: &[io::IoSlice<'_>], 91 | ) -> Poll> { 92 | match self.get_mut() { 93 | Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), 94 | #[cfg(feature = "non-standard-zstd")] 95 | Self::Zstd(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), 96 | } 97 | } 98 | 99 | fn 
is_write_vectored(&self) -> bool { 100 | match self { 101 | Self::Plain(conn) => conn.is_write_vectored(), 102 | #[cfg(feature = "non-standard-zstd")] 103 | Self::Zstd(conn) => conn.is_write_vectored(), 104 | } 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /watermelon-mini/src/proto/connection/mod.rs: -------------------------------------------------------------------------------- 1 | pub use self::compression::ConnectionCompression; 2 | pub use self::security::ConnectionSecurity; 3 | 4 | mod compression; 5 | mod security; 6 | -------------------------------------------------------------------------------- /watermelon-mini/src/proto/connection/security.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io, 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 | use tokio_rustls::{TlsConnector, client::TlsStream, rustls::pki_types::ServerName}; 9 | 10 | #[derive(Debug)] 11 | #[expect( 12 | clippy::large_enum_variant, 13 | reason = "using TLS is the recommended thing, we do not want to affect it" 14 | )] 15 | pub enum ConnectionSecurity { 16 | Plain(S), 17 | Tls(TlsStream), 18 | } 19 | 20 | impl ConnectionSecurity 21 | where 22 | S: AsyncRead + AsyncWrite + Unpin, 23 | { 24 | pub(crate) async fn upgrade_tls( 25 | self, 26 | connector: &TlsConnector, 27 | domain: ServerName<'static>, 28 | ) -> io::Result { 29 | let conn = match self { 30 | Self::Plain(conn) => conn, 31 | Self::Tls(_) => unreachable!("trying to upgrade to Tls a Tls connection"), 32 | }; 33 | 34 | let conn = connector.connect(domain, conn).await?; 35 | Ok(Self::Tls(conn)) 36 | } 37 | } 38 | 39 | impl AsyncRead for ConnectionSecurity 40 | where 41 | S: AsyncRead + AsyncWrite + Unpin, 42 | { 43 | fn poll_read( 44 | self: Pin<&mut Self>, 45 | cx: &mut Context<'_>, 46 | buf: &mut ReadBuf<'_>, 47 | ) -> Poll> { 48 | match self.get_mut() { 49 | 
Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf), 50 | Self::Tls(conn) => Pin::new(conn).poll_read(cx, buf), 51 | } 52 | } 53 | } 54 | 55 | impl AsyncWrite for ConnectionSecurity 56 | where 57 | S: AsyncRead + AsyncWrite + Unpin, 58 | { 59 | fn poll_write( 60 | self: Pin<&mut Self>, 61 | cx: &mut Context<'_>, 62 | buf: &[u8], 63 | ) -> Poll> { 64 | match self.get_mut() { 65 | Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf), 66 | Self::Tls(conn) => Pin::new(conn).poll_write(cx, buf), 67 | } 68 | } 69 | 70 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 71 | match self.get_mut() { 72 | Self::Plain(conn) => Pin::new(conn).poll_flush(cx), 73 | Self::Tls(conn) => Pin::new(conn).poll_flush(cx), 74 | } 75 | } 76 | 77 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 78 | match self.get_mut() { 79 | Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx), 80 | Self::Tls(conn) => Pin::new(conn).poll_shutdown(cx), 81 | } 82 | } 83 | 84 | fn poll_write_vectored( 85 | self: Pin<&mut Self>, 86 | cx: &mut Context<'_>, 87 | bufs: &[io::IoSlice<'_>], 88 | ) -> Poll> { 89 | match self.get_mut() { 90 | Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), 91 | Self::Tls(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), 92 | } 93 | } 94 | 95 | fn is_write_vectored(&self) -> bool { 96 | match self { 97 | Self::Plain(conn) => conn.is_write_vectored(), 98 | Self::Tls(conn) => conn.is_write_vectored(), 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /watermelon-mini/src/proto/connector.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use tokio::net::TcpStream; 4 | use tokio_rustls::{ 5 | TlsConnector, 6 | rustls::{ 7 | self, 8 | pki_types::{InvalidDnsNameError, ServerName}, 9 | }, 10 | }; 11 | use watermelon_net::{ 12 | Connection, StreamingConnection, connect_tcp, 13 | 
error::{ConnectionReadError, StreamingReadError}, 14 | proto_connect, 15 | }; 16 | #[cfg(feature = "websocket")] 17 | use watermelon_net::{WebsocketConnection, error::WebsocketReadError}; 18 | #[cfg(feature = "websocket")] 19 | use watermelon_proto::proto::error::FrameDecoderError; 20 | use watermelon_proto::{ 21 | Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport, 22 | proto::{ServerOp, error::DecoderError}, 23 | }; 24 | 25 | use crate::{ConnectFlags, ConnectionCompression, util::MaybeConnection}; 26 | 27 | use super::{ 28 | authenticator::{AuthenticationError, AuthenticationMethod}, 29 | connection::ConnectionSecurity, 30 | }; 31 | 32 | #[derive(Debug, thiserror::Error)] 33 | pub enum ConnectError { 34 | #[error("io error")] 35 | Io(#[source] io::Error), 36 | #[error("TLS error")] 37 | Tls(rustls::Error), 38 | #[error("invalid DNS name")] 39 | InvalidDnsName(#[source] InvalidDnsNameError), 40 | #[error("websocket not supported")] 41 | WebsocketUnsupported, 42 | #[error("unexpected ServerOp")] 43 | UnexpectedServerOp, 44 | #[error("decoder error")] 45 | Decoder(#[source] DecoderError), 46 | #[error("authentication error")] 47 | Authentication(#[source] AuthenticationError), 48 | #[error("connect")] 49 | Connect(#[source] watermelon_net::error::ConnectError), 50 | } 51 | 52 | #[expect(clippy::too_many_lines)] 53 | pub(crate) async fn connect( 54 | connector: &TlsConnector, 55 | addr: &ServerAddr, 56 | client_name: String, 57 | auth_method: Option<&AuthenticationMethod>, 58 | flags: ConnectFlags, 59 | ) -> Result< 60 | ( 61 | Connection< 62 | ConnectionCompression>, 63 | ConnectionSecurity, 64 | >, 65 | Box, 66 | ), 67 | ConnectError, 68 | > { 69 | let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?; 70 | conn.set_nodelay(flags.tcp_nodelay) 71 | .map_err(ConnectError::Io)?; 72 | let mut conn = ConnectionSecurity::Plain(conn); 73 | 74 | if matches!(addr.protocol(), Protocol::TLS) { 75 | let domain = 
rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?; 76 | conn = conn 77 | .upgrade_tls(connector, domain.to_owned()) 78 | .await 79 | .map_err(ConnectError::Io)?; 80 | } 81 | 82 | let mut conn = match addr.transport() { 83 | Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)), 84 | #[cfg(feature = "websocket")] 85 | Transport::Websocket => { 86 | let uri = addr.to_string().parse().unwrap(); 87 | Connection::Websocket( 88 | WebsocketConnection::new(uri, conn) 89 | .await 90 | .map_err(ConnectError::Io)?, 91 | ) 92 | } 93 | #[cfg(not(feature = "websocket"))] 94 | Transport::Websocket => return Err(ConnectError::WebsocketUnsupported), 95 | }; 96 | let info = match conn.read_next().await { 97 | Ok(ServerOp::Info { info }) => info, 98 | Ok(_) => return Err(ConnectError::UnexpectedServerOp), 99 | Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => { 100 | return Err(ConnectError::Io(err)); 101 | } 102 | Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => { 103 | return Err(ConnectError::Decoder(err)); 104 | } 105 | #[cfg(feature = "websocket")] 106 | Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => { 107 | return Err(ConnectError::Io(err)); 108 | } 109 | #[cfg(feature = "websocket")] 110 | Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder( 111 | FrameDecoderError::Decoder(err), 112 | ))) => return Err(ConnectError::Decoder(err)), 113 | #[cfg(feature = "websocket")] 114 | Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder( 115 | FrameDecoderError::IncompleteFrame, 116 | ))) => todo!(), 117 | #[cfg(feature = "websocket")] 118 | Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(), 119 | }; 120 | 121 | let conn = match conn { 122 | Connection::Streaming(streaming) => Connection::Streaming( 123 | if matches!( 124 | (addr.protocol(), info.tls_required), 125 | (Protocol::PossiblyPlain, true) 126 | ) { 127 | let domain = 128 | 
rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?; 129 | StreamingConnection::new( 130 | streaming 131 | .into_inner() 132 | .upgrade_tls(connector, domain.to_owned()) 133 | .await 134 | .map_err(ConnectError::Io)?, 135 | ) 136 | } else { 137 | streaming 138 | }, 139 | ), 140 | Connection::Websocket(websocket) => Connection::Websocket(websocket), 141 | }; 142 | 143 | let auth; 144 | let auth_method = if let Some(auth_method) = auth_method { 145 | Some(auth_method) 146 | } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) { 147 | auth = auth_method; 148 | Some(&auth) 149 | } else { 150 | None 151 | }; 152 | 153 | #[allow(unused_mut)] 154 | let mut non_standard = NonStandardConnect::default(); 155 | #[cfg(feature = "non-standard-zstd")] 156 | if matches!(conn, Connection::Streaming(_)) { 157 | non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd; 158 | } 159 | 160 | let mut connect = Connect { 161 | verbose: true, 162 | pedantic: false, 163 | require_tls: false, 164 | auth_token: None, 165 | username: None, 166 | password: None, 167 | client_name: Some(client_name), 168 | client_lang: "rust-watermelon", 169 | client_version: env!("CARGO_PKG_VERSION"), 170 | protocol: 1, 171 | echo: flags.echo, 172 | signature: None, 173 | jwt: None, 174 | supports_no_responders: true, 175 | supports_headers: true, 176 | nkey: None, 177 | non_standard, 178 | }; 179 | if let Some(auth_method) = auth_method { 180 | auth_method 181 | .prepare_for_auth(&info, &mut connect) 182 | .map_err(ConnectError::Authentication)?; 183 | } 184 | 185 | let mut conn = match conn { 186 | Connection::Streaming(streaming) => { 187 | Connection::Streaming(streaming.replace_socket(|stream| { 188 | MaybeConnection(Some(ConnectionCompression::Plain(stream))) 189 | })) 190 | } 191 | Connection::Websocket(websocket) => Connection::Websocket(websocket), 192 | }; 193 | 194 | #[cfg(feature = "non-standard-zstd")] 195 | let zstd = 
connect.non_standard.zstd; 196 | 197 | proto_connect(&mut conn, connect, |conn| { 198 | #[cfg(feature = "non-standard-zstd")] 199 | match conn { 200 | Connection::Streaming(streaming) => { 201 | if zstd { 202 | if let Some(zstd_compression_level) = flags.zstd_compression_level { 203 | let stream = streaming.socket_mut().0.take().unwrap(); 204 | streaming.socket_mut().0 = 205 | Some(stream.upgrade_zstd(zstd_compression_level)); 206 | } 207 | } 208 | } 209 | Connection::Websocket(_websocket) => {} 210 | } 211 | 212 | let _ = conn; 213 | }) 214 | .await 215 | .map_err(ConnectError::Connect)?; 216 | 217 | let conn = match conn { 218 | Connection::Streaming(streaming) => { 219 | Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap())) 220 | } 221 | Connection::Websocket(websocket) => Connection::Websocket(websocket), 222 | }; 223 | 224 | Ok((conn, info)) 225 | } 226 | 227 | fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result, InvalidDnsNameError> { 228 | match addr.host() { 229 | Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())), 230 | Host::Dns(name) => <_ as AsRef>::as_ref(name).try_into(), 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /watermelon-mini/src/proto/mod.rs: -------------------------------------------------------------------------------- 1 | pub use self::authenticator::AuthenticationMethod; 2 | pub use self::connection::{ConnectionCompression, ConnectionSecurity}; 3 | pub use self::connector::ConnectError; 4 | pub(crate) use self::connector::connect; 5 | 6 | mod authenticator; 7 | mod connection; 8 | mod connector; 9 | -------------------------------------------------------------------------------- /watermelon-mini/src/util.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io, 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 | 9 | #[derive(Debug)] 10 | 
pub(crate) struct MaybeConnection(pub(crate) Option); 11 | 12 | impl AsyncRead for MaybeConnection 13 | where 14 | S: AsyncRead + Unpin, 15 | { 16 | fn poll_read( 17 | mut self: Pin<&mut Self>, 18 | cx: &mut Context<'_>, 19 | buf: &mut ReadBuf<'_>, 20 | ) -> Poll> { 21 | Pin::new(self.0.as_mut().unwrap()).poll_read(cx, buf) 22 | } 23 | } 24 | 25 | impl AsyncWrite for MaybeConnection 26 | where 27 | S: AsyncWrite + Unpin, 28 | { 29 | fn poll_write( 30 | mut self: Pin<&mut Self>, 31 | cx: &mut Context<'_>, 32 | buf: &[u8], 33 | ) -> Poll> { 34 | Pin::new(self.0.as_mut().unwrap()).poll_write(cx, buf) 35 | } 36 | 37 | fn poll_write_vectored( 38 | mut self: Pin<&mut Self>, 39 | cx: &mut Context<'_>, 40 | bufs: &[io::IoSlice<'_>], 41 | ) -> Poll> { 42 | Pin::new(self.0.as_mut().unwrap()).poll_write_vectored(cx, bufs) 43 | } 44 | 45 | fn is_write_vectored(&self) -> bool { 46 | self.0.as_ref().unwrap().is_write_vectored() 47 | } 48 | 49 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 50 | Pin::new(self.0.as_mut().unwrap()).poll_flush(cx) 51 | } 52 | 53 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 54 | Pin::new(self.0.as_mut().unwrap()).poll_shutdown(cx) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /watermelon-net/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "watermelon-net" 3 | version = "0.2.3" 4 | description = "Low-level NATS Core network implementation" 5 | categories = ["api-bindings", "network-programming"] 6 | keywords = ["nats", "client"] 7 | edition.workspace = true 8 | license.workspace = true 9 | repository.workspace = true 10 | rust-version.workspace = true 11 | 12 | [package.metadata.docs.rs] 13 | features = ["websocket", "non-standard-zstd"] 14 | 15 | [dependencies] 16 | tokio = { version = "1.36", features = ["rt", "net", "time", "io-util"] } 17 | futures-core = "0.3.14" 18 
| bytes = "1" 19 | 20 | tokio-websockets = { version = "0.12", features = ["client"], optional = true } 21 | futures-sink = { version = "0.3.14", default-features = false, optional = true } 22 | http = { version = "1", optional = true } 23 | 24 | watermelon-proto = { version = "0.1.3", path = "../watermelon-proto" } 25 | 26 | thiserror = "2" 27 | pin-project-lite = "0.2.15" 28 | 29 | [dev-dependencies] 30 | tokio = { version = "1", features = ["macros"] } 31 | futures-util = { version = "0.3.14", default-features = false } 32 | claims = "0.8" 33 | 34 | [features] 35 | default = ["aws-lc-rs", "rand"] 36 | websocket = ["dep:tokio-websockets", "dep:futures-sink", "dep:http"] 37 | ring = ["tokio-websockets?/ring"] 38 | aws-lc-rs = ["tokio-websockets?/aws_lc_rs"] 39 | fips = [] 40 | rand = ["tokio-websockets?/rand"] 41 | getrandom = ["tokio-websockets?/getrandom"] 42 | non-standard-zstd = ["watermelon-proto/non-standard-zstd"] 43 | 44 | [lints] 45 | workspace = true 46 | -------------------------------------------------------------------------------- /watermelon-net/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /watermelon-net/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /watermelon-net/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /watermelon-net/src/connection/streaming.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::{self, Future}, 3 | io, 4 | pin::{Pin, pin}, 5 | task::{Context, Poll}, 6 | }; 7 | 8 | use bytes::Buf; 9 | use tokio::io::{AsyncRead, AsyncReadExt, 
AsyncWrite}; 10 | use watermelon_proto::proto::{ 11 | ClientOp, ServerOp, StreamDecoder, StreamEncoder, error::DecoderError, 12 | }; 13 | 14 | #[derive(Debug)] 15 | pub struct StreamingConnection { 16 | socket: S, 17 | encoder: StreamEncoder, 18 | decoder: StreamDecoder, 19 | may_flush: bool, 20 | } 21 | 22 | impl StreamingConnection 23 | where 24 | S: AsyncRead + AsyncWrite + Unpin, 25 | { 26 | #[must_use] 27 | pub fn new(socket: S) -> Self { 28 | Self { 29 | socket, 30 | encoder: StreamEncoder::new(), 31 | decoder: StreamDecoder::new(), 32 | may_flush: false, 33 | } 34 | } 35 | 36 | pub fn poll_read_next( 37 | &mut self, 38 | cx: &mut Context<'_>, 39 | ) -> Poll> { 40 | loop { 41 | match self.decoder.decode() { 42 | Ok(Some(server_op)) => return Poll::Ready(Ok(server_op)), 43 | Ok(None) => {} 44 | Err(err) => return Poll::Ready(Err(StreamingReadError::Decoder(err))), 45 | } 46 | 47 | let read_buf_fut = pin!(self.socket.read_buf(self.decoder.read_buf())); 48 | match read_buf_fut.poll(cx) { 49 | Poll::Pending => return Poll::Pending, 50 | Poll::Ready(Ok(1..)) => {} 51 | Poll::Ready(Ok(0)) => { 52 | return Poll::Ready(Err(StreamingReadError::Io( 53 | io::ErrorKind::UnexpectedEof.into(), 54 | ))); 55 | } 56 | Poll::Ready(Err(err)) => return Poll::Ready(Err(StreamingReadError::Io(err))), 57 | } 58 | } 59 | } 60 | 61 | /// Reads the next [`ServerOp`]. 62 | /// 63 | /// # Errors 64 | /// 65 | /// It returns an error if the content cannot be decoded or if an I/O error occurs. 
66 | pub async fn read_next(&mut self) -> Result { 67 | future::poll_fn(|cx| self.poll_read_next(cx)).await 68 | } 69 | 70 | pub fn may_write(&self) -> bool { 71 | self.encoder.has_remaining() 72 | } 73 | 74 | pub fn may_flush(&self) -> bool { 75 | self.may_flush 76 | } 77 | 78 | pub fn may_enqueue_more_ops(&self) -> bool { 79 | self.encoder.remaining() < 8_290_304 80 | } 81 | 82 | pub fn enqueue_write_op(&mut self, item: &ClientOp) { 83 | self.encoder.enqueue_write_op(item); 84 | } 85 | 86 | pub fn poll_write_next(&mut self, cx: &mut Context<'_>) -> Poll> { 87 | let remaining = self.encoder.remaining(); 88 | if remaining == 0 { 89 | return Poll::Ready(Ok(0)); 90 | } 91 | 92 | let chunk = self.encoder.chunk(); 93 | let write_outcome = if chunk.len() < remaining && self.socket.is_write_vectored() { 94 | let mut bufs = [io::IoSlice::new(&[]); 64]; 95 | let n = self.encoder.chunks_vectored(&mut bufs); 96 | debug_assert!( 97 | n >= 2, 98 | "perf: chunks_vectored yielded less than 2 chunks despite the apparently fragmented internal encoder representation" 99 | ); 100 | 101 | Pin::new(&mut self.socket).poll_write_vectored(cx, &bufs[..n]) 102 | } else { 103 | debug_assert!( 104 | !chunk.is_empty(), 105 | "perf: chunk shouldn't be empty given that `remaining > 0`" 106 | ); 107 | Pin::new(&mut self.socket).poll_write(cx, chunk) 108 | }; 109 | 110 | match write_outcome { 111 | Poll::Pending => { 112 | self.may_flush = false; 113 | Poll::Pending 114 | } 115 | Poll::Ready(Ok(n)) => { 116 | self.encoder.advance(n); 117 | self.may_flush = true; 118 | Poll::Ready(Ok(n)) 119 | } 120 | Poll::Ready(Err(err)) => Poll::Ready(Err(err)), 121 | } 122 | } 123 | 124 | /// Writes the next chunk of data to the socket. 125 | /// 126 | /// It returns the number of bytes that have been written. 127 | /// 128 | /// # Errors 129 | /// 130 | /// An I/O error is returned if it is not possible to write to the socket. 
131 | pub async fn write_next(&mut self) -> io::Result { 132 | future::poll_fn(|cx| self.poll_write_next(cx)).await 133 | } 134 | 135 | pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { 136 | match Pin::new(&mut self.socket).poll_flush(cx) { 137 | Poll::Pending => Poll::Pending, 138 | Poll::Ready(Ok(())) => { 139 | self.may_flush = false; 140 | Poll::Ready(Ok(())) 141 | } 142 | Poll::Ready(Err(err)) => Poll::Ready(Err(err)), 143 | } 144 | } 145 | 146 | /// Flush any buffered writes to the connection 147 | /// 148 | /// # Errors 149 | /// 150 | /// Returns an error if flushing fails 151 | pub async fn flush(&mut self) -> io::Result<()> { 152 | future::poll_fn(|cx| self.poll_flush(cx)).await 153 | } 154 | 155 | /// Shutdown the connection 156 | /// 157 | /// # Errors 158 | /// 159 | /// Returns an error if shutting down the connection fails. 160 | /// Implementations usually ignore this error. 161 | pub async fn shutdown(&mut self) -> io::Result<()> { 162 | future::poll_fn(|cx| Pin::new(&mut self.socket).poll_shutdown(cx)).await 163 | } 164 | 165 | pub fn socket(&self) -> &S { 166 | &self.socket 167 | } 168 | 169 | pub fn socket_mut(&mut self) -> &mut S { 170 | &mut self.socket 171 | } 172 | 173 | pub fn replace_socket(self, replacer: F) -> StreamingConnection 174 | where 175 | F: FnOnce(S) -> S2, 176 | { 177 | StreamingConnection { 178 | socket: replacer(self.socket), 179 | encoder: self.encoder, 180 | decoder: self.decoder, 181 | may_flush: self.may_flush, 182 | } 183 | } 184 | 185 | pub fn into_inner(self) -> S { 186 | self.socket 187 | } 188 | } 189 | 190 | #[derive(Debug, thiserror::Error)] 191 | pub enum StreamingReadError { 192 | #[error("decoder")] 193 | Decoder(#[source] DecoderError), 194 | #[error("io")] 195 | Io(#[source] io::Error), 196 | } 197 | 198 | #[cfg(test)] 199 | mod tests { 200 | use std::{ 201 | pin::Pin, 202 | task::{Context, Poll}, 203 | }; 204 | 205 | use claims::assert_matches; 206 | use futures_util::task; 207 | use 
tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; 208 | use watermelon_proto::proto::{ClientOp, ServerOp}; 209 | 210 | use super::StreamingConnection; 211 | 212 | #[test] 213 | fn ping_pong() { 214 | let waker = task::noop_waker(); 215 | let mut cx = Context::from_waker(&waker); 216 | 217 | let (socket, mut conn) = io::duplex(1024); 218 | 219 | let mut client = StreamingConnection::new(socket); 220 | 221 | // Initial state is ok 222 | assert!(client.poll_read_next(&mut cx).is_pending()); 223 | assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(0))); 224 | 225 | let mut buf = [0; 1024]; 226 | let mut read_buf = ReadBuf::new(&mut buf); 227 | assert!( 228 | Pin::new(&mut conn) 229 | .poll_read(&mut cx, &mut read_buf) 230 | .is_pending() 231 | ); 232 | 233 | // Write PING and verify it was received 234 | client.enqueue_write_op(&ClientOp::Ping); 235 | assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(6))); 236 | assert_matches!( 237 | Pin::new(&mut conn).poll_read(&mut cx, &mut read_buf), 238 | Poll::Ready(Ok(())) 239 | ); 240 | assert_eq!(read_buf.filled(), b"PING\r\n"); 241 | 242 | // Receive PONG 243 | assert_matches!( 244 | Pin::new(&mut conn).poll_write(&mut cx, b"PONG\r\n"), 245 | Poll::Ready(Ok(6)) 246 | ); 247 | assert_matches!( 248 | client.poll_read_next(&mut cx), 249 | Poll::Ready(Ok(ServerOp::Pong)) 250 | ); 251 | assert!(client.poll_read_next(&mut cx).is_pending()); 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /watermelon-net/src/connection/websocket.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future, io, 3 | pin::Pin, 4 | task::{Context, Poll, Waker}, 5 | }; 6 | 7 | use bytes::Bytes; 8 | use futures_core::Stream as _; 9 | use futures_sink::Sink; 10 | use http::Uri; 11 | use tokio::io::{AsyncRead, AsyncWrite}; 12 | use tokio_websockets::{ClientBuilder, Message, WebSocketStream}; 13 | use 
watermelon_proto::proto::{ 14 | ClientOp, FramedEncoder, ServerOp, decode_frame, error::FrameDecoderError, 15 | }; 16 | 17 | #[derive(Debug)] 18 | pub struct WebsocketConnection { 19 | socket: WebSocketStream, 20 | encoder: FramedEncoder, 21 | residual_frame: Bytes, 22 | should_flush: bool, 23 | } 24 | 25 | impl WebsocketConnection 26 | where 27 | S: AsyncRead + AsyncWrite + Unpin, 28 | { 29 | /// Construct a websocket stream to a pre-established connection `socket`. 30 | /// 31 | /// # Errors 32 | /// 33 | /// Returns an error if the websocket handshake fails. 34 | pub async fn new(uri: Uri, socket: S) -> io::Result { 35 | let (socket, _resp) = ClientBuilder::from_uri(uri) 36 | .connect_on(socket) 37 | .await 38 | .map_err(websockets_error_to_io)?; 39 | Ok(Self { 40 | socket, 41 | encoder: FramedEncoder::new(), 42 | residual_frame: Bytes::new(), 43 | should_flush: false, 44 | }) 45 | } 46 | 47 | pub fn poll_read_next( 48 | &mut self, 49 | cx: &mut Context<'_>, 50 | ) -> Poll> { 51 | loop { 52 | if !self.residual_frame.is_empty() { 53 | return Poll::Ready( 54 | decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder), 55 | ); 56 | } 57 | 58 | match Pin::new(&mut self.socket).poll_next(cx) { 59 | Poll::Pending => return Poll::Pending, 60 | Poll::Ready(Some(Ok(message))) if message.is_binary() => { 61 | self.residual_frame = message.into_payload().into(); 62 | } 63 | Poll::Ready(Some(Ok(_message))) => {} 64 | Poll::Ready(Some(Err(err))) => { 65 | return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err)))); 66 | } 67 | Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)), 68 | } 69 | } 70 | } 71 | 72 | /// Reads the next [`ServerOp`]. 73 | /// 74 | /// # Errors 75 | /// 76 | /// It returns an error if the content cannot be decoded or if an I/O error occurs. 
77 | pub async fn read_next(&mut self) -> Result { 78 | future::poll_fn(|cx| self.poll_read_next(cx)).await 79 | } 80 | 81 | pub fn should_flush(&self) -> bool { 82 | self.should_flush 83 | } 84 | 85 | pub fn may_enqueue_more_ops(&mut self) -> bool { 86 | let mut cx = Context::from_waker(Waker::noop()); 87 | Pin::new(&mut self.socket).poll_ready(&mut cx).is_ready() 88 | } 89 | 90 | /// Enqueue `item` to be written. 91 | #[expect(clippy::missing_panics_doc)] 92 | pub fn enqueue_write_op(&mut self, item: &ClientOp) { 93 | let payload = self.encoder.encode(item); 94 | Pin::new(&mut self.socket) 95 | .start_send(Message::binary(payload)) 96 | .unwrap(); 97 | self.should_flush = true; 98 | } 99 | 100 | pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { 101 | Pin::new(&mut self.socket) 102 | .poll_flush(cx) 103 | .map_err(websockets_error_to_io) 104 | } 105 | 106 | /// Flush any buffered writes to the connection 107 | /// 108 | /// # Errors 109 | /// 110 | /// Returns an error if flushing fails 111 | pub async fn flush(&mut self) -> io::Result<()> { 112 | future::poll_fn(|cx| self.poll_flush(cx)).await 113 | } 114 | 115 | /// Shutdown the connection 116 | /// 117 | /// # Errors 118 | /// 119 | /// Returns an error if shutting down the connection fails. 120 | /// Implementations usually ignore this error. 
121 | pub async fn shutdown(&mut self) -> io::Result<()> { 122 | future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx)) 123 | .await 124 | .map_err(websockets_error_to_io) 125 | } 126 | } 127 | 128 | #[derive(Debug, thiserror::Error)] 129 | pub enum WebsocketReadError { 130 | #[error("decoder")] 131 | Decoder(#[source] FrameDecoderError), 132 | #[error("io")] 133 | Io(#[source] io::Error), 134 | #[error("closed")] 135 | Closed, 136 | } 137 | 138 | fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error { 139 | match err { 140 | tokio_websockets::Error::Io(err) => err, 141 | err => io::Error::other(err), 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /watermelon-net/src/future.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | pin::Pin, 3 | task::{Context, Poll}, 4 | }; 5 | 6 | use futures_core::Stream; 7 | 8 | #[derive(Debug, Clone)] 9 | pub(crate) struct IterToStream { 10 | pub(crate) iter: I, 11 | } 12 | 13 | impl Unpin for IterToStream {} 14 | 15 | impl Stream for IterToStream { 16 | type Item = I::Item; 17 | 18 | fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 19 | Poll::Ready(self.iter.next()) 20 | } 21 | 22 | fn size_hint(&self) -> (usize, Option) { 23 | self.iter.size_hint() 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /watermelon-net/src/happy_eyeballs.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::{self, Future}, 3 | io, 4 | net::SocketAddr, 5 | pin::{Pin, pin}, 6 | task::{Context, Poll}, 7 | time::Duration, 8 | }; 9 | 10 | use futures_core::{Stream, stream::FusedStream}; 11 | use pin_project_lite::pin_project; 12 | use tokio::{ 13 | net::{self, TcpStream}, 14 | task::JoinSet, 15 | time::{self, Sleep}, 16 | }; 17 | use watermelon_proto::{Host, ServerAddr}; 18 | 19 | use 
crate::future::IterToStream;

/// Delay between starting two consecutive connection attempts while earlier ones are still
/// in flight (the "Connection Attempt Delay" of Happy Eyeballs).
const CONN_ATTEMPT_DELAY: Duration = Duration::from_millis(250);

/// Connects to an address and returns a [`TcpStream`].
///
/// If the given address is an IP address, this just uses [`TcpStream::connect`]. Otherwise, if a
/// host is given, the [Happy Eyeballs] protocol is used.
///
/// [Happy Eyeballs]: https://en.wikipedia.org/wiki/Happy_Eyeballs
///
/// # Errors
///
/// It returns an error if it is not possible to connect to any host. When every attempt failed,
/// the error of the last failed attempt is returned; when the host resolved to no address at
/// all, an [`io::ErrorKind::InvalidInput`] error is returned.
pub async fn connect(addr: &ServerAddr) -> io::Result<TcpStream> {
    match addr.host() {
        Host::Ip(ip) => TcpStream::connect(SocketAddr::new(*ip, addr.port())).await,
        Host::Dns(host) => {
            let addrs = net::lookup_host((&**host, addr.port())).await?;

            let mut happy_eyeballs = pin!(HappyEyeballs::new(IterToStream { iter: addrs }));
            let mut last_err = None;
            loop {
                match future::poll_fn(|cx| happy_eyeballs.as_mut().poll_next(cx)).await {
                    // The first successful connection wins; dropping `happy_eyeballs` drops
                    // its `JoinSet` together with the remaining attempts.
                    Some(Ok(conn)) => return Ok(conn),
                    Some(Err(err)) => last_err = Some(err),
                    None => {
                        return Err(last_err.unwrap_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "could not resolve to any address",
                            )
                        }));
                    }
                }
            }
        }
    }
}

pin_project! {
    #[project = HappyEyeballsProj]
    struct HappyEyeballs<D> {
        // Source of resolved addresses; set to `None` once it is exhausted.
        #[pin]
        dns: Option<D>,
        // Addresses received from `dns` that have not been attempted yet.
        dns_received: Vec<SocketAddr>,
        // Connection attempts currently in flight.
        connecting: JoinSet<io::Result<TcpStream>>,
        // Address family of the most recently started attempt.
        last_attempted: Option<LastAttempted>,
        // Armed after starting an attempt; the next attempt waits for it while others run.
        #[pin]
        next_attempt_delay: Option<Sleep>,
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum LastAttempted {
    Ipv4,
    Ipv6,
}

impl<D> HappyEyeballs<D> {
    /// Creates a stream that will try every address yielded by `dns`.
    fn new(dns: D) -> Self {
        Self {
            dns: Some(dns),
            dns_received: Vec::new(),
            connecting: JoinSet::new(),
            last_attempted: None,
            next_attempt_delay: None,
        }
    }
}

impl<D> HappyEyeballsProj<'_, D> {
    /// Removes and returns the next queued address to try, or `None` if none is queued.
    ///
    /// The first attempt prefers IPv6, and every following attempt prefers the family opposite
    /// to the previous one. When no queued address has the preferred family, the oldest queued
    /// address is used instead.
    fn next_dns_record(&mut self) -> Option<SocketAddr> {
        if self.dns_received.is_empty() {
            return None;
        }

        let preferred = self
            .last_attempted
            .map_or(LastAttempted::Ipv6, LastAttempted::opposite);
        let index = self
            .dns_received
            .iter()
            .position(|&addr| LastAttempted::from_addr(addr) == preferred)
            .unwrap_or(0);

        let record = self.dns_received.remove(index);
        *self.last_attempted = Some(LastAttempted::from_addr(record));
        Some(record)
    }
}

impl<D> Stream for HappyEyeballs<D>
where
    D: Stream<Item = SocketAddr>,
{
    /// The outcome of one connection attempt (success or failure).
    type Item = io::Result<TcpStream>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // Queue every address the resolver can hand out right now.
        while let Some(dns) = this.dns.as_mut().as_pin_mut() {
            match dns.poll_next(cx) {
                Poll::Pending => break,
                Poll::Ready(Some(record)) => this.dns_received.push(record),
                Poll::Ready(None) => this.dns.set(None),
            }
        }

        loop {
            match this.connecting.poll_join_next(cx) {
                Poll::Pending => {
                    // Attempts are in flight: if the delay is armed and still running, wait
                    // for it; otherwise fall through and start the next attempt.
                    if let Some(next_attempt_delay) = this.next_attempt_delay.as_mut().as_pin_mut()
                    {
                        match next_attempt_delay.poll(cx) {
                            Poll::Pending => break,
                            Poll::Ready(()) => this.next_attempt_delay.set(None),
                        }
                    }
                }
                Poll::Ready(Some(maybe_conn)) => {
                    return Poll::Ready(Some(maybe_conn.expect("connect panicked")));
                }
                // Nothing is in flight, so the next attempt can start immediately.
                Poll::Ready(None) => {}
            }

            let Some(record) = this.next_dns_record() else {
                this.next_attempt_delay.set(None);
                break;
            };
            let conn_fut = TcpStream::connect(record);
            this.connecting.spawn(conn_fut);
            this.next_attempt_delay
                .set(Some(time::sleep(CONN_ATTEMPT_DELAY)));
        }

        if this.dns.is_none() && this.connecting.is_empty() && this.next_attempt_delay.is_none() {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Every queued or in-flight address yields exactly one item.
        let pending = self.dns_received.len() + self.connecting.len();
        let (lower, upper) = self.dns.as_ref().map_or((0, Some(0)), Stream::size_hint);
        // An upper bound that overflows is unknown (`None`), not `usize::MAX`.
        (
            lower.saturating_add(pending),
            upper.and_then(|upper| upper.checked_add(pending)),
        )
    }
}

impl<D> FusedStream for HappyEyeballs<D>
where
    D: Stream<Item = SocketAddr>,
{
    fn is_terminated(&self) -> bool {
        self.dns.is_none() && self.connecting.is_empty() && self.next_attempt_delay.is_none()
    }
}

impl LastAttempted {
    /// Returns the address family of `addr`.
    fn from_addr(addr: SocketAddr) -> Self {
        match addr {
            SocketAddr::V4(_) => Self::Ipv4,
            SocketAddr::V6(_) => Self::Ipv6,
        }
    }

    /// Returns the other address family.
    fn opposite(self) -> Self {
        match self {
            Self::Ipv4 => Self::Ipv6,
            Self::Ipv6 => Self::Ipv4,
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        pin::pin,
    };

    use futures_util::{StreamExt as _, stream};
    use tokio::net::TcpListener;

    use super::HappyEyeballs;

    #[tokio::test]
    async fn happy_eyeballs_prefer_v6() {
        let ipv4_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let ipv6_listener = TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await.unwrap();

        let ipv4_addr = ipv4_listener.local_addr().unwrap();
        let ipv6_addr = ipv6_listener.local_addr().unwrap();

        tokio::spawn(async move { while ipv6_listener.accept().await.is_ok() {} });

        let addrs = stream::iter([ipv4_addr, ipv6_addr]);
        let mut happy_eyeballs = pin!(HappyEyeballs::new(addrs));
        let conn = happy_eyeballs.next().await.unwrap().unwrap();
        assert!(conn.peer_addr().unwrap().is_ipv6());
    }

    #[tokio::test]
    async fn happy_eyeballs_fallback_v4() {
        let ipv4_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let ipv6_listener = TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await.unwrap();

        let ipv4_addr = ipv4_listener.local_addr().unwrap();
        let ipv6_addr = ipv6_listener.local_addr().unwrap();

        drop(ipv6_listener);
        tokio::spawn(async move { while ipv4_listener.accept().await.is_ok() {} });

        let addrs = stream::iter([ipv4_addr, ipv6_addr]);
        let mut happy_eyeballs = pin!(HappyEyeballs::new(addrs));
        let _v6_failure = happy_eyeballs.next().await;
        let conn = happy_eyeballs.next().await.unwrap().unwrap();
        assert!(conn.peer_addr().unwrap().is_ipv4());
    }

    #[tokio::test]
    async fn happy_eyeballs_only_v4_available() {
        let ipv4_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();

        let ipv4_addr = ipv4_listener.local_addr().unwrap();

        tokio::spawn(async move { while ipv4_listener.accept().await.is_ok() {} });

        let addrs = stream::iter([ipv4_addr]);
        let mut happy_eyeballs = pin!(HappyEyeballs::new(addrs));
        let conn = happy_eyeballs.next().await.unwrap().unwrap();
        assert!(conn.peer_addr().unwrap().is_ipv4());
    }
}
--------------------------------------------------------------------------------
/watermelon-net/src/lib.rs:
--------------------------------------------------------------------------------
//! Networking layer of watermelon: TCP connections (using Happy Eyeballs when the address is a
//! host name) and, with the `websocket` feature, WebSocket connections.
#![forbid(unsafe_code)]

#[cfg(feature = "websocket")]
pub use self::connection::WebsocketConnection;
pub use self::connection::{Connection, StreamingConnection, connect as proto_connect};
pub use self::happy_eyeballs::connect as connect_tcp;

mod connection;
mod future;
mod happy_eyeballs;

/// Error types returned by the connection APIs of this crate.
pub mod error {
    #[cfg(feature = "websocket")]
    pub use super::connection::WebsocketReadError;
    pub use super::connection::{ConnectError, ConnectionReadError, StreamingReadError};
}
--------------------------------------------------------------------------------
/watermelon-nkeys/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "watermelon-nkeys"
version = "0.1.4"
description = "Minimal NKeys implementation for NATS client authentication"
categories = ["parser-implementations", "cryptography"]
keywords = ["nats", "nkey"]
edition.workspace = true
license.workspace = true
repository.workspace = true
rust-version.workspace = true

[dependencies]
aws-lc-rs = { version = "1.12.2", default-features = false, features = ["aws-lc-sys", "prebuilt-nasm"], optional = true }
ring = { version = "0.17", optional = true }
crc = "3.2.1"
thiserror = "2"
data-encoding = { version = "2.7.0", default-features = false }

[dev-dependencies]
claims = "0.8"

[features]
default = ["aws-lc-rs"]
aws-lc-rs = ["dep:aws-lc-rs"]
ring = ["dep:ring"]
fips = ["aws-lc-rs", "aws-lc-rs/fips"]

[lints]
workspace = true
--------------------------------------------------------------------------------
/watermelon-nkeys/LICENSE-APACHE:
--------------------------------------------------------------------------------
../LICENSE-APACHE
--------------------------------------------------------------------------------
/watermelon-nkeys/LICENSE-MIT:
--------------------------------------------------------------------------------
../LICENSE-MIT
--------------------------------------------------------------------------------
/watermelon-nkeys/README.md:
--------------------------------------------------------------------------------
../README.md
--------------------------------------------------------------------------------
/watermelon-nkeys/src/crc.rs:
--------------------------------------------------------------------------------
/// CRC-16/XMODEM engine shared by every [`Crc16::compute`] call.
///
/// `Crc::new` is a `const fn`, so the lookup table is built once at compile time instead of
/// being rebuilt on every checksum.
static CRC16_XMODEM: crc::Crc<u16> = crc::Crc::<u16>::new(&crc::CRC_16_XMODEM);

/// A CRC-16/XMODEM checksum, encoded little-endian when serialized.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct Crc16(u16);

impl Crc16 {
    /// Computes the CRC-16/XMODEM checksum of `buf`.
    pub(crate) fn compute(buf: &[u8]) -> Self {
        Self(CRC16_XMODEM.checksum(buf))
    }

    /// Builds a checksum from its 2-byte little-endian encoding.
    pub(crate) fn from_raw_encoded(val: [u8; 2]) -> Self {
        Self::from_raw(u16::from_le_bytes(val))
    }

    /// Builds a checksum from its numeric value.
    pub(crate) fn from_raw(val: u16) -> Self {
        Self(val)
    }

    /// Returns the numeric value of the checksum.
    #[cfg_attr(not(test), expect(dead_code))]
    pub(crate) fn to_raw(&self) -> u16 {
        self.0
    }

    /// Returns the 2-byte little-endian encoding of the checksum.
    pub(crate) fn to_raw_encoded(&self) -> [u8; 2] {
        self.0.to_le_bytes()
    }
}

#[cfg(test)]
mod tests {
    use super::Crc16;

    #[test]
    fn compute() {
        let input = [
            127, 237, 118, 35, 51, 69, 160, 148, 48, 70, 89, 182, 167, 81, 102, 237, 1, 143, 113,
            171, 162, 163, 101, 161, 49, 2, 57, 163, 167, 13, 106, 97, 249, 213,
        ];
        let crc = Crc16::compute(&input);
        assert_eq!(14592, crc.to_raw());
        assert_eq!(crc, Crc16::from_raw(crc.to_raw()));
        assert_eq!(crc, Crc16::from_raw_encoded(crc.to_raw_encoded()));
    }
}
--------------------------------------------------------------------------------
/watermelon-nkeys/src/lib.rs:
-------------------------------------------------------------------------------- 1 | #![forbid(unsafe_code)] 2 | 3 | pub use self::seed::{KeyPair, KeyPairFromSeedError, PublicKey, Signature}; 4 | 5 | mod crc; 6 | mod seed; 7 | -------------------------------------------------------------------------------- /watermelon-nkeys/src/seed.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{self, Debug, Display}; 2 | 3 | #[cfg(feature = "aws-lc-rs")] 4 | use aws_lc_rs::signature::{Ed25519KeyPair, KeyPair as _, Signature as LlSignature}; 5 | use data_encoding::{BASE32_NOPAD, BASE64URL_NOPAD}; 6 | #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] 7 | use ring::signature::{Ed25519KeyPair, KeyPair as _, Signature as LlSignature}; 8 | 9 | #[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))] 10 | compile_error!("Please enable the `aws-lc-rs` or the `ring` feature"); 11 | 12 | use crate::crc::Crc16; 13 | 14 | const SEED_PREFIX_BYTE: u8 = 18 << 3; 15 | 16 | /// A `NKey` private/public key pair. 17 | pub struct KeyPair { 18 | kind: u8, 19 | key: Ed25519KeyPair, 20 | } 21 | 22 | /// The public key within an `NKey` private/public key pair. 23 | #[derive(Debug)] 24 | pub struct PublicKey<'a>(&'a KeyPair); 25 | 26 | /// An error encountered while decoding an `NKey`. 27 | #[derive(Debug, thiserror::Error)] 28 | pub enum KeyPairFromSeedError { 29 | /// The string rapresentation of the seed has an invalid length. 30 | #[error("invalid length of the seed's string the string rapresentation")] 31 | InvalidSeedLength, 32 | /// The string rapresentation of the seed contains characters that are not part of the base32 dictionary. 33 | #[error("the seed contains non-base32 characters")] 34 | InvalidBase32, 35 | /// The decoded base32 rapresentation of the seed has an invalid length. 36 | #[error("invalid base32 decoded seed length")] 37 | InvalidRawSeedLength, 38 | /// The CRC does not match the crc calculated for the seed payload. 
39 | #[error("invalid CRC")] 40 | BadCrc, 41 | /// The prefix for the seed is invalid 42 | #[error("invalid seed prefix")] 43 | InvalidPrefix, 44 | /// the seed could not be decoded by the crypto backend 45 | #[error("decode error")] 46 | DecodeError, 47 | } 48 | 49 | /// A payload signed via a [`KeyPair`]. 50 | /// 51 | /// Obtained from [`KeyPair::sign`]. 52 | pub struct Signature(LlSignature); 53 | 54 | impl KeyPair { 55 | /// Decode a key from an `NKey` seed. 56 | /// 57 | /// # Errors 58 | /// 59 | /// Returns an error if `seed` is invalid. 60 | #[expect( 61 | clippy::missing_panics_doc, 62 | reason = "the array `TryInto` calls cannot panic" 63 | )] 64 | pub fn from_encoded_seed(seed: &str) -> Result { 65 | if seed.len() != 58 { 66 | return Err(KeyPairFromSeedError::InvalidSeedLength); 67 | } 68 | 69 | let mut full_raw_seed = [0; 36]; 70 | let len = BASE32_NOPAD 71 | .decode_mut(seed.as_bytes(), &mut full_raw_seed) 72 | .map_err(|_| KeyPairFromSeedError::InvalidBase32)?; 73 | if len != full_raw_seed.len() { 74 | return Err(KeyPairFromSeedError::InvalidRawSeedLength); 75 | } 76 | 77 | let (raw_seed, crc) = full_raw_seed.split_at(full_raw_seed.len() - 2); 78 | let raw_seed_crc = Crc16::compute(raw_seed); 79 | let expected_crc = Crc16::from_raw_encoded(crc.try_into().unwrap()); 80 | if raw_seed_crc != expected_crc { 81 | return Err(KeyPairFromSeedError::BadCrc); 82 | } 83 | 84 | Self::from_raw_seed(raw_seed.try_into().unwrap()) 85 | } 86 | 87 | fn from_raw_seed(raw_seed: [u8; 34]) -> Result { 88 | if raw_seed[0] & 248 != SEED_PREFIX_BYTE { 89 | println!("{:x}", raw_seed[0]); 90 | return Err(KeyPairFromSeedError::InvalidPrefix); 91 | } 92 | 93 | let kind = raw_seed[1]; 94 | 95 | let key = Ed25519KeyPair::from_seed_unchecked(&raw_seed[2..]) 96 | .map_err(|_| KeyPairFromSeedError::DecodeError)?; 97 | Ok(Self { kind, key }) 98 | } 99 | 100 | #[must_use] 101 | pub fn public_key(&self) -> PublicKey<'_> { 102 | PublicKey(self) 103 | } 104 | 105 | #[must_use] 106 | pub 
fn sign(&self, buf: &[u8]) -> Signature { 107 | Signature(self.key.sign(buf)) 108 | } 109 | } 110 | 111 | impl Debug for KeyPair { 112 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 113 | f.debug_struct("KeyPair") 114 | .field("kind", &self.kind) 115 | .finish_non_exhaustive() 116 | } 117 | } 118 | 119 | impl Display for Signature { 120 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 121 | Display::fmt(&BASE64URL_NOPAD.encode_display(self.0.as_ref()), f) 122 | } 123 | } 124 | 125 | impl Display for PublicKey<'_> { 126 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 127 | let mut full_raw_seed = [0; 36]; 128 | full_raw_seed[0] = SEED_PREFIX_BYTE; 129 | full_raw_seed[1] = self.0.kind; 130 | full_raw_seed[2..34].copy_from_slice(self.0.key.public_key().as_ref()); 131 | let crc = Crc16::compute(&full_raw_seed[..34]); 132 | full_raw_seed[34..36].copy_from_slice(&crc.to_raw_encoded()); 133 | Display::fmt(&BASE32_NOPAD.encode_display(&full_raw_seed), f) 134 | } 135 | } 136 | 137 | #[cfg(test)] 138 | mod tests { 139 | use claims::assert_matches; 140 | 141 | use super::{KeyPair, KeyPairFromSeedError}; 142 | 143 | #[test] 144 | fn sign() { 145 | let key = KeyPair::from_encoded_seed( 146 | "SAAPN4W3EG6KCJGUQTKTJ5GSB5NHK5CHAJL4DBGFUM3HHROI4XUEP4OBK4", 147 | ) 148 | .unwrap(); 149 | assert_eq!( 150 | "HuHkn4SHFW1ibjQzmqyNw8KUZDWB0bKciDbK7YmNyqyyvC3k4s0AqimAz6jMt0xhLqGAOyj30UaUol2xMVpsBQ", 151 | key.sign(b"fwD9iyDvqxpcj3ii").to_string() 152 | ); 153 | } 154 | 155 | #[test] 156 | fn gen_public_key() { 157 | let key = KeyPair::from_encoded_seed( 158 | "SAAPN4W3EG6KCJGUQTKTJ5GSB5NHK5CHAJL4DBGFUM3HHROI4XUEP4OBK4", 159 | ) 160 | .unwrap(); 161 | assert_eq!( 162 | "SAAJYMSGSUUUC3GAOKL2IFAAKQDV32K4X45HPCPC4EBM7F7N76HQGR4C2I", 163 | key.public_key().to_string() 164 | ); 165 | } 166 | 167 | #[test] 168 | fn invalid_len() { 169 | assert_matches!( 170 | KeyPair::from_encoded_seed(""), 171 | Err(KeyPairFromSeedError::InvalidSeedLength) 172 | ); 173 | } 
174 | 175 | #[test] 176 | fn invalid_base32() { 177 | assert_matches!( 178 | KeyPair::from_encoded_seed( 179 | "SAAPN4W3EG6KCJGUQTKTJ5!#B5NHK5CHAJL4DBGFUM3HHROI4XUEP4OBK4" 180 | ), 181 | Err(KeyPairFromSeedError::InvalidBase32) 182 | ); 183 | } 184 | 185 | #[test] 186 | fn invalid_crc() { 187 | assert_matches!( 188 | KeyPair::from_encoded_seed( 189 | "FAAPN4W3EG6KCJGUQTKTJ5GSB5NHK5CHAJL4DBGFUM3HHROI4XUEP4OBK4" 190 | ), 191 | Err(KeyPairFromSeedError::BadCrc) 192 | ); 193 | } 194 | 195 | #[test] 196 | fn invalid_prefix() { 197 | assert_matches!( 198 | KeyPair::from_encoded_seed( 199 | "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" 200 | ), 201 | Err(KeyPairFromSeedError::InvalidPrefix) 202 | ); 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /watermelon-proto/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "watermelon-proto" 3 | version = "0.1.8" 4 | description = "#[no_std] NATS Core Sans-IO protocol implementation" 5 | categories = ["network-programming", "parser-implementations", "no-std"] 6 | keywords = ["nats", "client"] 7 | edition.workspace = true 8 | license.workspace = true 9 | repository.workspace = true 10 | rust-version.workspace = true 11 | 12 | [package.metadata.docs.rs] 13 | features = ["non-standard-zstd"] 14 | 15 | [dependencies] 16 | bytes = { version = "1", default-features = false } 17 | bytestring = { version = "1", default-features = false, features = ["serde"] } 18 | url = { version = "2.5.3", default-features = false, features = ["serde"] } 19 | percent-encoding = { version = "2", default-features = false, features = ["alloc"] } 20 | memchr = { version = "2.4", default-features = false } 21 | unicase = "2.7" 22 | 23 | serde = { version = "1.0.113", default-features = false, features = ["alloc", "derive"] } 24 | serde_json = { version = "1", default-features = false, features = ["alloc"] } 25 | itoa = 
"1.0.13" 26 | 27 | thiserror = { version = "2", default-features = false } 28 | 29 | [dev-dependencies] 30 | claims = "0.8" 31 | 32 | [features] 33 | default = ["std"] 34 | std = ["bytes/std", "url/std", "percent-encoding/std", "memchr/std", "serde/std", "serde_json/std", "thiserror/std"] 35 | non-standard-zstd = [] 36 | 37 | [lints] 38 | workspace = true 39 | -------------------------------------------------------------------------------- /watermelon-proto/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /watermelon-proto/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /watermelon-proto/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /watermelon-proto/src/connect.rs: -------------------------------------------------------------------------------- 1 | use alloc::string::String; 2 | 3 | use serde::Serialize; 4 | 5 | #[derive(Debug, Serialize)] 6 | #[allow(clippy::struct_excessive_bools)] 7 | pub struct Connect { 8 | pub verbose: bool, 9 | pub pedantic: bool, 10 | #[serde(rename = "tls_required")] 11 | pub require_tls: bool, 12 | pub auth_token: Option, 13 | #[serde(rename = "user")] 14 | pub username: Option, 15 | #[serde(rename = "pass")] 16 | pub password: Option, 17 | #[serde(rename = "name")] 18 | pub client_name: Option, 19 | #[serde(rename = "lang")] 20 | pub client_lang: &'static str, 21 | #[serde(rename = "version")] 22 | pub client_version: &'static str, 23 | pub protocol: u8, 24 | pub echo: bool, 25 | #[serde(rename = "sig")] 26 | pub signature: Option, 27 | pub jwt: Option, 28 | #[serde(rename = "no_responders")] 29 | pub 
supports_no_responders: bool, 30 | #[serde(rename = "headers")] 31 | pub supports_headers: bool, 32 | pub nkey: Option, 33 | 34 | #[serde(flatten)] 35 | pub non_standard: NonStandardConnect, 36 | } 37 | 38 | #[derive(Debug, Serialize)] 39 | #[non_exhaustive] 40 | pub struct NonStandardConnect { 41 | #[cfg(feature = "non-standard-zstd")] 42 | #[serde( 43 | rename = "m4ss_zstd", 44 | skip_serializing_if = "skip_serializing_if_false" 45 | )] 46 | pub zstd: bool, 47 | } 48 | 49 | #[allow(clippy::derivable_impls)] 50 | impl Default for NonStandardConnect { 51 | fn default() -> Self { 52 | Self { 53 | #[cfg(feature = "non-standard-zstd")] 54 | zstd: false, 55 | } 56 | } 57 | } 58 | 59 | #[cfg(feature = "non-standard-zstd")] 60 | #[allow(clippy::trivially_copy_pass_by_ref)] 61 | fn skip_serializing_if_false(val: &bool) -> bool { 62 | !*val 63 | } 64 | -------------------------------------------------------------------------------- /watermelon-proto/src/headers/map.rs: -------------------------------------------------------------------------------- 1 | use alloc::{ 2 | collections::{BTreeMap, btree_map::Entry}, 3 | vec, 4 | vec::Vec, 5 | }; 6 | use core::{ 7 | fmt::{self, Debug}, 8 | mem, 9 | }; 10 | 11 | use super::{HeaderName, HeaderValue}; 12 | 13 | static EMPTY_HEADERS: OneOrMany = OneOrMany::Many(Vec::new()); 14 | 15 | /// A set of NATS headers 16 | /// 17 | /// [`HeaderMap`] is a multimap of [`HeaderName`]. 18 | #[derive(Clone, PartialEq, Eq)] 19 | pub struct HeaderMap { 20 | headers: BTreeMap, 21 | len: usize, 22 | } 23 | 24 | #[derive(Clone, PartialEq, Eq)] 25 | enum OneOrMany { 26 | One(HeaderValue), 27 | Many(Vec), 28 | } 29 | 30 | impl HeaderMap { 31 | /// Create an empty `HeaderMap` 32 | /// 33 | /// The map will be created without any capacity. This function will not allocate. 34 | /// 35 | /// Consider using the [`FromIterator`], [`Extend`] implementations if the final 36 | /// length is known upfront. 
37 | #[must_use] 38 | pub const fn new() -> Self { 39 | Self { 40 | headers: BTreeMap::new(), 41 | len: 0, 42 | } 43 | } 44 | 45 | pub fn get(&self, name: &HeaderName) -> Option<&HeaderValue> { 46 | self.get_all(name).next() 47 | } 48 | 49 | pub fn get_all<'a>( 50 | &'a self, 51 | name: &HeaderName, 52 | ) -> impl DoubleEndedIterator + use<'a> { 53 | self.headers.get(name).unwrap_or(&EMPTY_HEADERS).iter() 54 | } 55 | 56 | pub fn insert(&mut self, name: HeaderName, value: HeaderValue) { 57 | if let Some(prev) = self.headers.insert(name, OneOrMany::One(value)) { 58 | self.len -= prev.len(); 59 | } 60 | self.len += 1; 61 | } 62 | 63 | pub fn append(&mut self, name: HeaderName, value: HeaderValue) { 64 | match self.headers.entry(name) { 65 | Entry::Vacant(vacant) => { 66 | vacant.insert(OneOrMany::One(value)); 67 | } 68 | Entry::Occupied(mut occupied) => { 69 | occupied.get_mut().push(value); 70 | } 71 | } 72 | self.len += 1; 73 | } 74 | 75 | pub fn remove(&mut self, name: &HeaderName) { 76 | if let Some(prev) = self.headers.remove(name) { 77 | self.len -= prev.len(); 78 | } 79 | } 80 | 81 | /// Returns the number of keys stored in the map 82 | /// 83 | /// This number will be less than or equal to [`HeaderMap::len`]. 84 | #[must_use] 85 | pub fn keys_len(&self) -> usize { 86 | self.headers.len() 87 | } 88 | 89 | /// Returns the number of headers stored in the map 90 | /// 91 | /// This number represents the total number of **values** stored in the map. 92 | /// This number can be greater than or equal to the number of **keys** stored. 93 | #[must_use] 94 | pub fn len(&self) -> usize { 95 | self.len 96 | } 97 | 98 | /// Returns true if the map contains no elements 99 | #[must_use] 100 | pub fn is_empty(&self) -> bool { 101 | self.headers.is_empty() 102 | } 103 | 104 | /// Clear the map, removing all key-value pairs. 
Keeps the allocated memory for reuse 105 | pub fn clear(&mut self) { 106 | self.headers.clear(); 107 | self.len = 0; 108 | } 109 | 110 | #[cfg(test)] 111 | fn keys(&self) -> impl Iterator { 112 | self.headers.keys() 113 | } 114 | 115 | pub(crate) fn iter( 116 | &self, 117 | ) -> impl DoubleEndedIterator)> 118 | { 119 | self.headers 120 | .iter() 121 | .map(|(name, value)| (name, value.iter())) 122 | } 123 | } 124 | 125 | impl Debug for HeaderMap { 126 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 127 | f.debug_tuple("HeaderMap") 128 | .field(&self.headers) 129 | // FIXME: switch to `finish_non_exhaustive` 130 | .finish() 131 | } 132 | } 133 | 134 | impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap { 135 | fn from_iter>(iter: I) -> Self { 136 | let mut this = Self::new(); 137 | this.extend(iter); 138 | this 139 | } 140 | } 141 | 142 | impl Extend<(HeaderName, HeaderValue)> for HeaderMap { 143 | fn extend>(&mut self, iter: T) { 144 | iter.into_iter().for_each(|(name, value)| { 145 | self.append(name, value); 146 | }); 147 | } 148 | } 149 | 150 | impl Default for HeaderMap { 151 | fn default() -> Self { 152 | Self::new() 153 | } 154 | } 155 | 156 | impl Debug for OneOrMany { 157 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 158 | f.debug_set().entries(self.iter()).finish() 159 | } 160 | } 161 | 162 | impl OneOrMany { 163 | fn len(&self) -> usize { 164 | match self { 165 | Self::One(_) => 1, 166 | Self::Many(vec) => vec.len(), 167 | } 168 | } 169 | 170 | fn push(&mut self, item: HeaderValue) { 171 | match self { 172 | Self::One(current_item) => { 173 | let current_item = 174 | mem::replace(current_item, HeaderValue::from_static("replacing")); 175 | *self = Self::Many(vec![current_item, item]); 176 | } 177 | Self::Many(vec) => { 178 | debug_assert!(!vec.is_empty(), "OneOrMany can't be empty"); 179 | vec.push(item); 180 | } 181 | } 182 | } 183 | 184 | fn iter(&self) -> impl DoubleEndedIterator { 185 | // This implementation may look 
odd, but it implements `TrustedLen`, 186 | // so the Iterator is efficient to collect. 187 | match self { 188 | Self::One(one) => Iterator::chain(Some(one).into_iter(), &[]), 189 | Self::Many(many) => Iterator::chain(None.into_iter(), many), 190 | } 191 | } 192 | } 193 | 194 | #[cfg(test)] 195 | mod tests { 196 | use alloc::{vec, vec::Vec}; 197 | 198 | use crate::headers::{HeaderName, HeaderValue}; 199 | 200 | use super::HeaderMap; 201 | 202 | #[test] 203 | fn manual() { 204 | let mut headers = HeaderMap::new(); 205 | headers.append( 206 | HeaderName::from_static("Nats-Message-Id"), 207 | HeaderValue::from_static("abcd"), 208 | ); 209 | headers.append( 210 | HeaderName::from_static("Nats-Sequence"), 211 | HeaderValue::from_static("1"), 212 | ); 213 | headers.append( 214 | HeaderName::from_static("Nats-Message-Id"), 215 | HeaderValue::from_static("1234"), 216 | ); 217 | headers.append( 218 | HeaderName::from_static("Nats-Time-Stamp"), 219 | HeaderValue::from_static("0"), 220 | ); 221 | headers.remove(&HeaderName::from_static("Nats-Time-Stamp")); 222 | 223 | verify_header_map(&headers); 224 | } 225 | 226 | #[test] 227 | fn collect() { 228 | let headers = [ 229 | ( 230 | HeaderName::from_static("Nats-Message-Id"), 231 | HeaderValue::from_static("abcd"), 232 | ), 233 | ( 234 | HeaderName::from_static("Nats-Sequence"), 235 | HeaderValue::from_static("1"), 236 | ), 237 | ( 238 | HeaderName::from_static("Nats-Message-Id"), 239 | HeaderValue::from_static("1234"), 240 | ), 241 | ] 242 | .into_iter() 243 | .collect::(); 244 | 245 | verify_header_map(&headers); 246 | } 247 | 248 | fn verify_header_map(headers: &HeaderMap) { 249 | assert_eq!( 250 | [ 251 | HeaderName::from_static("Nats-Message-Id"), 252 | HeaderName::from_static("Nats-Sequence") 253 | ] 254 | .as_slice(), 255 | headers.keys().cloned().collect::>().as_slice() 256 | ); 257 | 258 | let raw_headers = headers 259 | .iter() 260 | .map(|(name, values)| (name.clone(), values.cloned().collect::>())) 261 | 
.collect::>(); 262 | assert_eq!( 263 | [ 264 | ( 265 | HeaderName::from_static("Nats-Message-Id"), 266 | vec![ 267 | HeaderValue::from_static("abcd"), 268 | HeaderValue::from_static("1234") 269 | ] 270 | ), 271 | ( 272 | HeaderName::from_static("Nats-Sequence"), 273 | vec![HeaderValue::from_static("1")] 274 | ), 275 | ] 276 | .as_slice(), 277 | raw_headers.as_slice(), 278 | ); 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /watermelon-proto/src/headers/mod.rs: -------------------------------------------------------------------------------- 1 | pub use self::map::HeaderMap; 2 | pub use self::name::HeaderName; 3 | pub use self::value::HeaderValue; 4 | 5 | mod map; 6 | mod name; 7 | mod value; 8 | 9 | pub mod error { 10 | pub use super::name::HeaderNameValidateError; 11 | pub use super::value::HeaderValueValidateError; 12 | } 13 | -------------------------------------------------------------------------------- /watermelon-proto/src/headers/name.rs: -------------------------------------------------------------------------------- 1 | use alloc::string::String; 2 | use core::{ 3 | fmt::{self, Display}, 4 | ops::Deref, 5 | }; 6 | use unicase::UniCase; 7 | 8 | use bytestring::ByteString; 9 | 10 | /// A string that can be used to represent an header name 11 | /// 12 | /// `HeaderName` contains a string that is guaranteed [^1] to 13 | /// contain a valid header name that meets the following requirements: 14 | /// 15 | /// * The value is not empty 16 | /// * The value has a length less than or equal to 64 [^2] 17 | /// * The value does not contain any whitespace characters or `:` 18 | /// 19 | /// `HeaderName` can be constructed from [`HeaderName::from_static`] 20 | /// or any of the `TryFrom` implementations. 21 | /// 22 | /// [^1]: Because [`HeaderName::from_dangerous_value`] is safe to call, 23 | /// unsafe code must not assume any of the above invariants. 
24 | /// [^2]: Messages coming from the NATS server are allowed to violate this rule. 25 | #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] 26 | pub struct HeaderName(UniCase); 27 | 28 | impl HeaderName { 29 | /// Client-defined unique identifier for a message that will be used by the server apply de-duplication within the configured Jetstream _Duplicate Window_ 30 | pub const MESSAGE_ID: Self = Self::new_internal("Nats-Msg-Id"); 31 | /// Have Jetstream assert that the published message is received by the expected stream 32 | pub const EXPECTED_STREAM: Self = Self::new_internal("Nats-Expected-Stream"); 33 | /// Have Jetstream assert that the last expected [`HeaderName::MESSAGE_ID`] matches this ID 34 | pub const EXPECTED_LAST_MESSAGE_ID: Self = Self::new_internal("Nats-Expected-Last-Msg-Id"); 35 | /// Have Jetstream assert that the last sequence ID matches this ID 36 | pub const EXPECTED_LAST_SEQUENCE: Self = Self::new_internal("Nats-Expected-Last-Sequence"); 37 | /// Purge all prior messages in the stream (`all` value) or at the subject-level (`sub` value) 38 | pub const ROLLUP: Self = Self::new_internal("Nats-Rollup"); 39 | 40 | /// Name of the stream the message was republished from 41 | pub const STREAM: Self = Self::new_internal("Nats-Stream"); 42 | /// Original subject to which the message was republished from 43 | pub const SUBJECT: Self = Self::new_internal("Nats-Subject"); 44 | /// Original sequence ID the message was republished from 45 | pub const SEQUENCE: Self = Self::new_internal("Nats-Sequence"); 46 | /// Last sequence ID of the message having the same subject, or zero if this is the first message for the subject 47 | pub const LAST_SEQUENCE: Self = Self::new_internal("Nats-Last-Sequence"); 48 | /// The original RFC3339 timestamp of the message 49 | pub const TIMESTAMP: Self = Self::new_internal("Nats-Time-Stamp"); 50 | 51 | /// Origin stream name, subject, sequence number, subject filter and destination transform of the message being sourced 
52 | pub const STREAM_SOURCE: Self = Self::new_internal("Nats-Stream-Source"); 53 | 54 | /// Size of the message payload in bytes for an headers-only message 55 | pub const MESSAGE_SIZE: Self = Self::new_internal("Nats-Msg-Size"); 56 | 57 | /// Construct `HeaderName` from a static string 58 | /// 59 | /// # Panics 60 | /// 61 | /// Will panic if `value` isn't a valid `HeaderName` 62 | #[must_use] 63 | pub fn from_static(value: &'static str) -> Self { 64 | Self::try_from(ByteString::from_static(value)).expect("invalid HeaderName") 65 | } 66 | 67 | /// Construct a `HeaderName` from a string, without checking invariants 68 | /// 69 | /// This method bypasses invariants checks implemented by [`HeaderName::from_static`] 70 | /// and all `TryFrom` implementations. 71 | /// 72 | /// # Security 73 | /// 74 | /// While calling this method can eliminate the runtime performance cost of 75 | /// checking the string, constructing `HeaderName` with an invalid string and 76 | /// then calling the NATS server with it can cause serious security issues. 77 | /// When in doubt use the [`HeaderName::from_static`] or any of the `TryFrom` 78 | /// implementations. 
79 | #[expect( 80 | clippy::missing_panics_doc, 81 | reason = "The header validation is only made in debug" 82 | )] 83 | #[must_use] 84 | pub fn from_dangerous_value(value: ByteString) -> Self { 85 | if cfg!(debug_assertions) { 86 | if let Err(err) = validate_header_name(&value) { 87 | panic!("HeaderName {value:?} isn't valid {err:?}"); 88 | } 89 | } 90 | Self(UniCase::new(value)) 91 | } 92 | 93 | const fn new_internal(value: &'static str) -> Self { 94 | if value.is_ascii() { 95 | Self(UniCase::ascii(ByteString::from_static(value))) 96 | } else { 97 | Self(UniCase::unicode(ByteString::from_static(value))) 98 | } 99 | } 100 | 101 | #[must_use] 102 | pub fn as_str(&self) -> &str { 103 | &self.0 104 | } 105 | } 106 | 107 | impl Display for HeaderName { 108 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 109 | Display::fmt(&self.0, f) 110 | } 111 | } 112 | 113 | impl TryFrom for HeaderName { 114 | type Error = HeaderNameValidateError; 115 | 116 | fn try_from(value: ByteString) -> Result { 117 | validate_header_name(&value)?; 118 | Ok(Self::from_dangerous_value(value)) 119 | } 120 | } 121 | 122 | impl TryFrom for HeaderName { 123 | type Error = HeaderNameValidateError; 124 | 125 | fn try_from(value: String) -> Result { 126 | validate_header_name(&value)?; 127 | Ok(Self::from_dangerous_value(value.into())) 128 | } 129 | } 130 | 131 | impl From for ByteString { 132 | fn from(value: HeaderName) -> Self { 133 | value.0.into_inner() 134 | } 135 | } 136 | 137 | impl AsRef<[u8]> for HeaderName { 138 | fn as_ref(&self) -> &[u8] { 139 | self.as_str().as_bytes() 140 | } 141 | } 142 | 143 | impl AsRef for HeaderName { 144 | fn as_ref(&self) -> &str { 145 | self.as_str() 146 | } 147 | } 148 | 149 | impl Deref for HeaderName { 150 | type Target = str; 151 | 152 | fn deref(&self) -> &Self::Target { 153 | self.as_str() 154 | } 155 | } 156 | 157 | /// An error encountered while validating [`HeaderName`] 158 | #[derive(Debug, thiserror::Error)] 159 | pub enum 
HeaderNameValidateError { 160 | /// The value is empty 161 | #[error("HeaderName is empty")] 162 | Empty, 163 | /// The value has a length greater than 64 164 | #[error("HeaderName is too long")] 165 | TooLong, 166 | /// The value contains an Unicode whitespace character or `:` 167 | #[error("HeaderName contained an illegal whitespace character")] 168 | IllegalCharacter, 169 | } 170 | 171 | fn validate_header_name(header_name: &str) -> Result<(), HeaderNameValidateError> { 172 | if header_name.is_empty() { 173 | return Err(HeaderNameValidateError::Empty); 174 | } 175 | 176 | if header_name.len() > 64 { 177 | // This is an arbitrary limit, but I guess the server must also have one 178 | return Err(HeaderNameValidateError::TooLong); 179 | } 180 | 181 | if header_name.chars().any(|c| c.is_whitespace() || c == ':') { 182 | // The theoretical security limit is just ` `, `\t`, `\r`, `\n` and `:`. 183 | // Let's be more careful. 184 | return Err(HeaderNameValidateError::IllegalCharacter); 185 | } 186 | 187 | Ok(()) 188 | } 189 | 190 | #[cfg(test)] 191 | mod tests { 192 | use core::cmp::Ordering; 193 | 194 | use super::HeaderName; 195 | 196 | #[test] 197 | fn eq() { 198 | let cased = HeaderName::from_static("Nats-Message-Id"); 199 | let lowercase = HeaderName::from_static("nats-message-id"); 200 | assert_eq!(cased, lowercase); 201 | assert_eq!(cased.cmp(&lowercase), Ordering::Equal); 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /watermelon-proto/src/headers/value.rs: -------------------------------------------------------------------------------- 1 | use alloc::string::String; 2 | use core::{ 3 | fmt::{self, Display}, 4 | ops::Deref, 5 | }; 6 | 7 | use bytestring::ByteString; 8 | 9 | /// A string that can be used to represent an header value 10 | /// 11 | /// `HeaderValue` contains a string that is guaranteed [^1] to 12 | /// contain a valid header value that meets the following requirements: 13 | /// 14 | /// * The 
value is not empty 15 | /// * The value has a length less than or equal to 1024 [^2] 16 | /// * The value does not contain any whitespace characters 17 | /// 18 | /// `HeaderValue` can be constructed from [`HeaderValue::from_static`] 19 | /// or any of the `TryFrom` implementations. 20 | /// 21 | /// [^1]: Because [`HeaderValue::from_dangerous_value`] is safe to call, 22 | /// unsafe code must not assume any of the above invariants. 23 | /// [^2]: Messages coming from the NATS server are allowed to violate this rule. 24 | #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] 25 | pub struct HeaderValue(ByteString); 26 | 27 | impl HeaderValue { 28 | /// Construct `HeaderValue` from a static string 29 | /// 30 | /// # Panics 31 | /// 32 | /// Will panic if `value` isn't a valid `HeaderValue` 33 | #[must_use] 34 | pub fn from_static(value: &'static str) -> Self { 35 | Self::try_from(ByteString::from_static(value)).expect("invalid HeaderValue") 36 | } 37 | 38 | /// Construct a `HeaderValue` from a string, without checking invariants 39 | /// 40 | /// This method bypasses invariants checks implemented by [`HeaderValue::from_static`] 41 | /// and all `TryFrom` implementations. 42 | /// 43 | /// # Security 44 | /// 45 | /// While calling this method can eliminate the runtime performance cost of 46 | /// checking the string, constructing `HeaderValue` with an invalid string and 47 | /// then calling the NATS server with it can cause serious security issues. 48 | /// When in doubt use the [`HeaderValue::from_static`] or any of the `TryFrom` 49 | /// implementations. 
50 | #[must_use] 51 | #[expect( 52 | clippy::missing_panics_doc, 53 | reason = "The header validation is only made in debug" 54 | )] 55 | pub fn from_dangerous_value(value: ByteString) -> Self { 56 | if cfg!(debug_assertions) { 57 | if let Err(err) = validate_header_value(&value) { 58 | panic!("HeaderValue {value:?} isn't valid {err:?}"); 59 | } 60 | } 61 | Self(value) 62 | } 63 | 64 | #[must_use] 65 | pub fn as_str(&self) -> &str { 66 | &self.0 67 | } 68 | } 69 | 70 | impl Display for HeaderValue { 71 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 72 | Display::fmt(&self.0, f) 73 | } 74 | } 75 | 76 | impl TryFrom for HeaderValue { 77 | type Error = HeaderValueValidateError; 78 | 79 | fn try_from(value: ByteString) -> Result { 80 | validate_header_value(&value)?; 81 | Ok(Self::from_dangerous_value(value)) 82 | } 83 | } 84 | 85 | impl TryFrom for HeaderValue { 86 | type Error = HeaderValueValidateError; 87 | 88 | fn try_from(value: String) -> Result { 89 | validate_header_value(&value)?; 90 | Ok(Self::from_dangerous_value(value.into())) 91 | } 92 | } 93 | 94 | impl From for ByteString { 95 | fn from(value: HeaderValue) -> Self { 96 | value.0 97 | } 98 | } 99 | 100 | impl AsRef<[u8]> for HeaderValue { 101 | fn as_ref(&self) -> &[u8] { 102 | self.as_str().as_bytes() 103 | } 104 | } 105 | 106 | impl AsRef for HeaderValue { 107 | fn as_ref(&self) -> &str { 108 | self.as_str() 109 | } 110 | } 111 | 112 | impl Deref for HeaderValue { 113 | type Target = str; 114 | 115 | fn deref(&self) -> &Self::Target { 116 | self.as_str() 117 | } 118 | } 119 | 120 | /// An error encountered while validating [`HeaderValue`] 121 | #[derive(Debug, thiserror::Error)] 122 | pub enum HeaderValueValidateError { 123 | /// The value is empty 124 | #[error("HeaderValue is empty")] 125 | Empty, 126 | /// The value has a length greater than 64 127 | #[error("HeaderValue is too long")] 128 | TooLong, 129 | /// The value contains an Unicode whitespace character 130 | 
#[error("HeaderValue contained an illegal whitespace character")] 131 | IllegalCharacter, 132 | } 133 | 134 | fn validate_header_value(header_value: &str) -> Result<(), HeaderValueValidateError> { 135 | if header_value.is_empty() { 136 | return Err(HeaderValueValidateError::Empty); 137 | } 138 | 139 | if header_value.len() > 1024 { 140 | // This is an arbitrary limit, but I guess the server must also have one 141 | return Err(HeaderValueValidateError::TooLong); 142 | } 143 | 144 | if header_value.chars().any(char::is_whitespace) { 145 | // The theoretical security limit is just ` `, `\t`, `\r` and `\n`. 146 | // Let's be more careful. 147 | return Err(HeaderValueValidateError::IllegalCharacter); 148 | } 149 | 150 | Ok(()) 151 | } 152 | -------------------------------------------------------------------------------- /watermelon-proto/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(feature = "std"), no_std)] 2 | #![forbid(unsafe_code)] 3 | 4 | extern crate alloc; 5 | 6 | pub use self::connect::{Connect, NonStandardConnect}; 7 | pub use self::message::{MessageBase, ServerMessage}; 8 | pub use self::queue_group::QueueGroup; 9 | pub use self::server_addr::{Host, Protocol, ServerAddr, Transport}; 10 | pub use self::server_info::{NonStandardServerInfo, ServerInfo}; 11 | pub use self::status_code::StatusCode; 12 | pub use self::subject::Subject; 13 | pub use self::subscription_id::SubscriptionId; 14 | 15 | mod connect; 16 | pub mod headers; 17 | mod message; 18 | pub mod proto; 19 | mod queue_group; 20 | mod server_addr; 21 | mod server_error; 22 | mod server_info; 23 | mod status_code; 24 | mod subject; 25 | mod subscription_id; 26 | #[cfg(test)] 27 | mod tests; 28 | mod util; 29 | 30 | pub mod error { 31 | pub use super::queue_group::QueueGroupValidateError; 32 | pub use super::server_addr::ServerAddrError; 33 | pub use super::server_error::ServerError; 34 | pub use super::status_code::StatusCodeError; 35 | 
pub use super::subject::SubjectValidateError; 36 | pub use super::util::ParseUintError; 37 | } 38 | -------------------------------------------------------------------------------- /watermelon-proto/src/message.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | 3 | use crate::{StatusCode, Subject, headers::HeaderMap, subscription_id::SubscriptionId}; 4 | 5 | #[derive(Debug, Clone, PartialEq, Eq)] 6 | pub struct MessageBase { 7 | pub subject: Subject, 8 | pub reply_subject: Option, 9 | pub headers: HeaderMap, 10 | pub payload: Bytes, 11 | } 12 | 13 | #[derive(Debug, Clone, PartialEq, Eq)] 14 | pub struct ServerMessage { 15 | pub status_code: Option, 16 | pub subscription_id: SubscriptionId, 17 | pub base: MessageBase, 18 | } 19 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/client.rs: -------------------------------------------------------------------------------- 1 | use alloc::boxed::Box; 2 | use core::num::NonZero; 3 | 4 | use crate::{ 5 | Subject, connect::Connect, message::MessageBase, queue_group::QueueGroup, 6 | subscription_id::SubscriptionId, 7 | }; 8 | 9 | #[derive(Debug)] 10 | pub enum ClientOp { 11 | Connect { 12 | connect: Box, 13 | }, 14 | Publish { 15 | message: MessageBase, 16 | }, 17 | Subscribe { 18 | id: SubscriptionId, 19 | subject: Subject, 20 | queue_group: Option, 21 | }, 22 | Unsubscribe { 23 | id: SubscriptionId, 24 | max_messages: Option>, 25 | }, 26 | Ping, 27 | Pong, 28 | } 29 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/decoder/framed.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | 3 | use crate::{proto::ServerOp, util::CrlfFinder}; 4 | 5 | use super::{DecoderError, DecoderStatus}; 6 | 7 | /// Decodes a frame of bytes into a [`ServerOp`]. 
8 | /// 9 | /// # Errors 10 | /// 11 | /// It returns an error in case the frame is incomplete or if a decoding error occurs. 12 | pub fn decode_frame(frame: &mut Bytes) -> Result { 13 | let mut status = DecoderStatus::ControlLine { last_bytes_read: 0 }; 14 | match super::decode(&CrlfFinder::new(), &mut status, frame) { 15 | Ok(Some(server_op)) => Ok(server_op), 16 | Ok(None) => Err(FrameDecoderError::IncompleteFrame), 17 | Err(err) => Err(FrameDecoderError::Decoder(err)), 18 | } 19 | } 20 | 21 | #[derive(Debug, thiserror::Error)] 22 | pub enum FrameDecoderError { 23 | #[error("incomplete frame")] 24 | IncompleteFrame, 25 | #[error("decoder error")] 26 | Decoder(#[source] DecoderError), 27 | } 28 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/decoder/stream.rs: -------------------------------------------------------------------------------- 1 | use bytes::{BufMut, BytesMut}; 2 | 3 | use crate::{ 4 | proto::{ServerOp, error::DecoderError}, 5 | util::CrlfFinder, 6 | }; 7 | 8 | use super::DecoderStatus; 9 | 10 | const INITIAL_READ_BUF_CAPACITY: usize = 64 * 1024; 11 | 12 | #[derive(Debug)] 13 | pub struct StreamDecoder { 14 | read_buf: BytesMut, 15 | status: DecoderStatus, 16 | crlf: CrlfFinder, 17 | } 18 | 19 | impl StreamDecoder { 20 | #[must_use] 21 | pub fn new() -> Self { 22 | Self { 23 | read_buf: BytesMut::with_capacity(INITIAL_READ_BUF_CAPACITY), 24 | status: DecoderStatus::ControlLine { last_bytes_read: 0 }, 25 | crlf: CrlfFinder::new(), 26 | } 27 | } 28 | 29 | #[must_use] 30 | pub fn read_buf(&mut self) -> &mut (impl BufMut + use<>) { 31 | &mut self.read_buf 32 | } 33 | 34 | /// Decodes the next frame of bytes into a [`ServerOp`]. 35 | /// 36 | /// A `None` variant is returned in case no progress is made, 37 | /// 38 | /// # Errors 39 | /// 40 | /// It returns an error if a decoding error occurs. 
41 | pub fn decode(&mut self) -> Result, DecoderError> { 42 | super::decode(&self.crlf, &mut self.status, &mut self.read_buf) 43 | } 44 | } 45 | 46 | impl Default for StreamDecoder { 47 | fn default() -> Self { 48 | Self::new() 49 | } 50 | } 51 | 52 | #[cfg(test)] 53 | mod tests { 54 | use bytes::{BufMut as _, Bytes}; 55 | use claims::{assert_matches, assert_ok_eq}; 56 | 57 | use crate::{ 58 | Subject, 59 | error::ServerError, 60 | headers::HeaderMap, 61 | message::{MessageBase, ServerMessage}, 62 | proto::{error::DecoderError, server::ServerOp}, 63 | }; 64 | 65 | use super::StreamDecoder; 66 | 67 | #[test] 68 | fn decode_ping() { 69 | let mut decoder = StreamDecoder::new(); 70 | decoder.read_buf().put_slice(b"PING\r\n"); 71 | assert_ok_eq!(decoder.decode(), Some(ServerOp::Ping)); 72 | assert_ok_eq!(decoder.decode(), None); 73 | } 74 | 75 | #[test] 76 | fn decode_pong() { 77 | let mut decoder = StreamDecoder::new(); 78 | decoder.read_buf().put_slice(b"PONG\r\n"); 79 | assert_ok_eq!(decoder.decode(), Some(ServerOp::Pong)); 80 | assert_ok_eq!(decoder.decode(), None); 81 | } 82 | 83 | #[test] 84 | fn decode_ok() { 85 | let mut decoder = StreamDecoder::new(); 86 | decoder.read_buf().put_slice(b"+OK\r\n"); 87 | assert_ok_eq!(decoder.decode(), Some(ServerOp::Success)); 88 | assert_ok_eq!(decoder.decode(), None); 89 | } 90 | 91 | #[test] 92 | fn decode_error() { 93 | let mut decoder = StreamDecoder::new(); 94 | decoder 95 | .read_buf() 96 | .put_slice(b"-ERR 'Authorization Violation'\r\n"); 97 | assert_ok_eq!( 98 | decoder.decode(), 99 | Some(ServerOp::Error { 100 | error: ServerError::AuthorizationViolation 101 | }) 102 | ); 103 | assert_ok_eq!(decoder.decode(), None); 104 | } 105 | 106 | #[test] 107 | fn decode_msg() { 108 | let mut decoder = StreamDecoder::new(); 109 | decoder 110 | .read_buf() 111 | .put_slice(b"MSG hello.world 1 12\r\nHello World!\r\n"); 112 | assert_ok_eq!( 113 | decoder.decode(), 114 | Some(ServerOp::Message { 115 | message: ServerMessage { 116 | 
status_code: None, 117 | subscription_id: 1.into(), 118 | base: MessageBase { 119 | subject: Subject::from_static("hello.world"), 120 | reply_subject: None, 121 | headers: HeaderMap::new(), 122 | payload: Bytes::from_static(b"Hello World!") 123 | } 124 | } 125 | }) 126 | ); 127 | assert_ok_eq!(decoder.decode(), None); 128 | } 129 | 130 | #[test] 131 | fn head_too_long() { 132 | let mut decoder = StreamDecoder::new(); 133 | decoder.read_buf().put_bytes(0, 20000); 134 | assert_matches!( 135 | decoder.decode(), 136 | Err(DecoderError::HeadTooLong { len: 20000 }) 137 | ); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/encoder/framed.rs: -------------------------------------------------------------------------------- 1 | use bytes::BytesMut; 2 | 3 | use crate::proto::ClientOp; 4 | 5 | use super::FrameEncoder; 6 | 7 | #[derive(Debug)] 8 | pub struct FramedEncoder { 9 | buf: BytesMut, 10 | } 11 | 12 | impl FramedEncoder { 13 | #[must_use] 14 | pub fn new() -> Self { 15 | Self { 16 | buf: BytesMut::new(), 17 | } 18 | } 19 | 20 | pub fn encode(&mut self, item: &ClientOp) -> BytesMut { 21 | struct Encoder<'a>(&'a mut FramedEncoder); 22 | 23 | impl FrameEncoder for Encoder<'_> { 24 | fn small_write(&mut self, buf: &[u8]) { 25 | self.0.buf.extend_from_slice(buf); 26 | } 27 | } 28 | 29 | super::encode(&mut Encoder(self), item); 30 | self.buf.split() 31 | } 32 | } 33 | 34 | impl Default for FramedEncoder { 35 | fn default() -> Self { 36 | Self::new() 37 | } 38 | } 39 | 40 | #[cfg(test)] 41 | mod tests { 42 | use core::num::NonZero; 43 | 44 | use bytes::Bytes; 45 | 46 | use super::FramedEncoder; 47 | use crate::{ 48 | MessageBase, QueueGroup, Subject, 49 | headers::{HeaderMap, HeaderName, HeaderValue}, 50 | proto::ClientOp, 51 | tests::ToBytes as _, 52 | }; 53 | 54 | #[test] 55 | fn encode_ping() { 56 | let mut encoder = FramedEncoder::new(); 57 | assert_eq!( 58 | 
encoder.encode(&ClientOp::Ping).to_bytes(), 59 | "PING\r\n".as_bytes() 60 | ); 61 | } 62 | 63 | #[test] 64 | fn encode_pong() { 65 | let mut encoder = FramedEncoder::new(); 66 | assert_eq!( 67 | encoder.encode(&ClientOp::Pong).to_bytes(), 68 | "PONG\r\n".as_bytes() 69 | ); 70 | } 71 | 72 | #[test] 73 | fn encode_subscribe() { 74 | let mut encoder = FramedEncoder::new(); 75 | assert_eq!( 76 | encoder 77 | .encode(&ClientOp::Subscribe { 78 | id: 1.into(), 79 | subject: Subject::from_static("hello.world"), 80 | queue_group: None, 81 | }) 82 | .to_bytes(), 83 | "SUB hello.world 1\r\n".as_bytes() 84 | ); 85 | } 86 | 87 | #[test] 88 | fn encode_subscribe_with_queue_group() { 89 | let mut encoder = FramedEncoder::new(); 90 | assert_eq!( 91 | encoder 92 | .encode(&ClientOp::Subscribe { 93 | id: 1.into(), 94 | subject: Subject::from_static("hello.world"), 95 | queue_group: Some(QueueGroup::from_static("stuff")), 96 | }) 97 | .to_bytes(), 98 | "SUB hello.world stuff 1\r\n".as_bytes() 99 | ); 100 | } 101 | 102 | #[test] 103 | fn encode_unsubscribe() { 104 | let mut encoder = FramedEncoder::new(); 105 | assert_eq!( 106 | encoder 107 | .encode(&ClientOp::Unsubscribe { 108 | id: 1.into(), 109 | max_messages: None, 110 | }) 111 | .to_bytes(), 112 | "UNSUB 1\r\n".as_bytes() 113 | ); 114 | } 115 | 116 | #[test] 117 | fn encode_unsubscribe_with_max_messages() { 118 | let mut encoder = FramedEncoder::new(); 119 | assert_eq!( 120 | encoder 121 | .encode(&ClientOp::Unsubscribe { 122 | id: 1.into(), 123 | max_messages: Some(NonZero::new(5).unwrap()), 124 | }) 125 | .to_bytes(), 126 | "UNSUB 1 5\r\n".as_bytes() 127 | ); 128 | } 129 | 130 | #[test] 131 | fn encode_publish() { 132 | let mut encoder = FramedEncoder::new(); 133 | assert_eq!( 134 | encoder 135 | .encode(&ClientOp::Publish { 136 | message: MessageBase { 137 | subject: Subject::from_static("hello.world"), 138 | reply_subject: None, 139 | headers: HeaderMap::new(), 140 | payload: Bytes::from_static(b"Hello World!"), 141 | }, 142 
| }) 143 | .to_bytes(), 144 | "PUB hello.world 12\r\nHello World!\r\n".as_bytes() 145 | ); 146 | } 147 | 148 | #[test] 149 | fn encode_publish_with_reply_subject() { 150 | let mut encoder = FramedEncoder::new(); 151 | assert_eq!( 152 | encoder 153 | .encode(&ClientOp::Publish { 154 | message: MessageBase { 155 | subject: Subject::from_static("hello.world"), 156 | reply_subject: Some(Subject::from_static("_INBOX.1234")), 157 | headers: HeaderMap::new(), 158 | payload: Bytes::from_static(b"Hello World!"), 159 | }, 160 | }) 161 | .to_bytes(), 162 | "PUB hello.world _INBOX.1234 12\r\nHello World!\r\n".as_bytes() 163 | ); 164 | } 165 | 166 | #[test] 167 | fn encode_publish_with_headers() { 168 | let mut encoder = FramedEncoder::new(); 169 | assert_eq!( 170 | encoder.encode(&ClientOp::Publish { 171 | message: MessageBase { 172 | subject: Subject::from_static("hello.world"), 173 | reply_subject: None, 174 | headers: [ 175 | ( 176 | HeaderName::from_static("Nats-Message-Id"), 177 | HeaderValue::from_static("abcd"), 178 | ), 179 | ( 180 | HeaderName::from_static("Nats-Sequence"), 181 | HeaderValue::from_static("1"), 182 | ), 183 | ] 184 | .into_iter() 185 | .collect(), 186 | payload: Bytes::from_static(b"Hello World!"), 187 | }, 188 | }).to_bytes(), 189 | "HPUB hello.world 53 65\r\nNATS/1.0\r\nNats-Message-Id: abcd\r\nNats-Sequence: 1\r\n\r\nHello World!\r\n".as_bytes() 190 | ); 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/encoder/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "std")] 2 | use std::io; 3 | 4 | use bytes::Bytes; 5 | 6 | use crate::MessageBase; 7 | use crate::headers::HeaderMap; 8 | 9 | pub use self::framed::FramedEncoder; 10 | pub use self::stream::StreamEncoder; 11 | 12 | use super::ClientOp; 13 | 14 | mod framed; 15 | mod stream; 16 | 17 | pub(super) trait FrameEncoder { 18 | fn small_write(&mut self, buf: 
&[u8]); 19 | 20 | fn write(&mut self, buf: B) 21 | where 22 | B: Into + AsRef<[u8]>, 23 | { 24 | self.small_write(buf.as_ref()); 25 | } 26 | 27 | #[cfg(feature = "std")] 28 | fn small_io_writer(&mut self) -> SmallIoWriter<'_, Self> { 29 | SmallIoWriter(self) 30 | } 31 | } 32 | 33 | #[cfg(feature = "std")] 34 | pub(super) struct SmallIoWriter<'a, E: ?Sized>(&'a mut E); 35 | 36 | #[cfg(feature = "std")] 37 | impl io::Write for SmallIoWriter<'_, E> 38 | where 39 | E: FrameEncoder, 40 | { 41 | fn write(&mut self, buf: &[u8]) -> io::Result { 42 | self.0.small_write(buf); 43 | Ok(buf.len()) 44 | } 45 | 46 | fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { 47 | self.0.small_write(buf); 48 | Ok(()) 49 | } 50 | 51 | fn flush(&mut self) -> io::Result<()> { 52 | Ok(()) 53 | } 54 | } 55 | 56 | pub(super) fn encode(encoder: &mut E, item: &ClientOp) { 57 | match item { 58 | ClientOp::Publish { message } => { 59 | let MessageBase { 60 | subject, 61 | reply_subject, 62 | headers, 63 | payload, 64 | } = &message; 65 | let verb_and_space = if headers.is_empty() { "PUB " } else { "HPUB " }; 66 | encoder.small_write(verb_and_space.as_bytes()); 67 | encoder.small_write(subject.as_bytes()); 68 | encoder.small_write(b" "); 69 | 70 | if let Some(reply_subject) = reply_subject { 71 | encoder.small_write(reply_subject.as_bytes()); 72 | encoder.small_write(b" "); 73 | } 74 | 75 | let mut buffer = itoa::Buffer::new(); 76 | if headers.is_empty() { 77 | encoder.small_write(buffer.format(payload.len()).as_bytes()); 78 | encoder.small_write(b"\r\n"); 79 | } else { 80 | let headers_len = encode_headers(headers).fold(0, |len, s| len + s.len()); 81 | encoder.small_write(buffer.format(headers_len).as_bytes()); 82 | encoder.small_write(b" "); 83 | 84 | let total_len = headers_len + payload.len(); 85 | encoder.small_write(buffer.format(total_len).as_bytes()); 86 | encoder.small_write(b"\r\n"); 87 | 88 | encode_headers(headers).for_each(|s| { 89 | encoder.small_write(s.as_bytes()); 90 | }); 91 | 
} 92 | 93 | encoder.write(IntoBytes(payload)); 94 | encoder.small_write(b"\r\n"); 95 | } 96 | ClientOp::Subscribe { 97 | id, 98 | subject, 99 | queue_group, 100 | } => { 101 | // `SUB {subject} [{queue_group} ]id\r\n` 102 | encoder.small_write(b"SUB "); 103 | encoder.small_write(subject.as_bytes()); 104 | encoder.small_write(b" "); 105 | 106 | if let Some(queue_group) = queue_group { 107 | encoder.small_write(queue_group.as_bytes()); 108 | encoder.small_write(b" "); 109 | } 110 | 111 | let mut buffer = itoa::Buffer::new(); 112 | encoder.small_write(buffer.format(u64::from(*id)).as_bytes()); 113 | encoder.small_write(b"\r\n"); 114 | } 115 | ClientOp::Unsubscribe { id, max_messages } => { 116 | // `UNSUB {id}[ {max_messages}]\r\n` 117 | encoder.small_write(b"UNSUB "); 118 | 119 | let mut buffer = itoa::Buffer::new(); 120 | encoder.small_write(buffer.format(u64::from(*id)).as_bytes()); 121 | 122 | if let Some(max_messages) = *max_messages { 123 | encoder.small_write(b" "); 124 | encoder.small_write(buffer.format(max_messages.get()).as_bytes()); 125 | } 126 | 127 | encoder.small_write(b"\r\n"); 128 | } 129 | ClientOp::Connect { connect } => { 130 | encoder.small_write(b"CONNECT "); 131 | #[cfg(feature = "std")] 132 | serde_json::to_writer(encoder.small_io_writer(), &connect) 133 | .expect("serialize `Connect`"); 134 | #[cfg(not(feature = "std"))] 135 | encoder.write(serde_json::to_vec(&connect).expect("serialize `Connect`")); 136 | encoder.small_write(b"\r\n"); 137 | } 138 | ClientOp::Ping => { 139 | encoder.small_write(b"PING\r\n"); 140 | } 141 | ClientOp::Pong => { 142 | encoder.small_write(b"PONG\r\n"); 143 | } 144 | } 145 | } 146 | 147 | struct IntoBytes<'a>(&'a Bytes); 148 | 149 | impl<'a> From> for Bytes { 150 | fn from(value: IntoBytes<'a>) -> Self { 151 | Bytes::clone(value.0) 152 | } 153 | } 154 | 155 | impl AsRef<[u8]> for IntoBytes<'_> { 156 | fn as_ref(&self) -> &[u8] { 157 | self.0 158 | } 159 | } 160 | 161 | fn encode_headers(headers: &HeaderMap) -> impl 
Iterator { 162 | let head = ["NATS/1.0\r\n"]; 163 | let headers = headers.iter().flat_map(|(name, values)| { 164 | values.flat_map(|value| [name.as_str(), ": ", value.as_str(), "\r\n"]) 165 | }); 166 | let footer = ["\r\n"]; 167 | 168 | head.into_iter().chain(headers).chain(footer) 169 | } 170 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/mod.rs: -------------------------------------------------------------------------------- 1 | pub use self::client::ClientOp; 2 | pub use self::decoder::{StreamDecoder, decode_frame}; 3 | pub use self::encoder::{FramedEncoder, StreamEncoder}; 4 | pub use self::server::ServerOp; 5 | 6 | mod client; 7 | mod decoder; 8 | mod encoder; 9 | mod server; 10 | 11 | pub mod error { 12 | pub use super::decoder::{DecoderError, FrameDecoderError}; 13 | } 14 | -------------------------------------------------------------------------------- /watermelon-proto/src/proto/server.rs: -------------------------------------------------------------------------------- 1 | use alloc::boxed::Box; 2 | 3 | use crate::{ServerInfo, error::ServerError, message::ServerMessage}; 4 | 5 | #[derive(Debug, PartialEq, Eq)] 6 | pub enum ServerOp { 7 | Info { info: Box }, 8 | Message { message: ServerMessage }, 9 | Success, 10 | Error { error: ServerError }, 11 | Ping, 12 | Pong, 13 | } 14 | -------------------------------------------------------------------------------- /watermelon-proto/src/queue_group.rs: -------------------------------------------------------------------------------- 1 | use alloc::string::String; 2 | use core::{ 3 | fmt::{self, Display}, 4 | ops::Deref, 5 | }; 6 | use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; 7 | 8 | use bytestring::ByteString; 9 | 10 | /// A string that can be used to represent an queue group 11 | /// 12 | /// `QueueGroup` contains a string that is guaranteed [^1] to 13 | /// contain a valid header name that meets the following requirements: 14 | 
/// 15 | /// * The value is not empty 16 | /// * The value has a length less than or equal to 64 [^2] 17 | /// * The value does not contain any whitespace characters or `:` 18 | /// 19 | /// `QueueGroup` can be constructed from [`QueueGroup::from_static`] 20 | /// or any of the `TryFrom` implementations. 21 | /// 22 | /// [^1]: Because [`QueueGroup::from_dangerous_value`] is safe to call, 23 | /// unsafe code must not assume any of the above invariants. 24 | /// [^2]: Messages coming from the NATS server are allowed to violate this rule. 25 | #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] 26 | pub struct QueueGroup(ByteString); 27 | 28 | impl QueueGroup { 29 | /// Construct `QueueGroup` from a static string 30 | /// 31 | /// # Panics 32 | /// 33 | /// Will panic if `value` isn't a valid `QueueGroup` 34 | #[must_use] 35 | pub fn from_static(value: &'static str) -> Self { 36 | Self::try_from(ByteString::from_static(value)).expect("invalid QueueGroup") 37 | } 38 | 39 | /// Construct a `QueueGroup` from a string, without checking invariants 40 | /// 41 | /// This method bypasses invariants checks implemented by [`QueueGroup::from_static`] 42 | /// and all `TryFrom` implementations. 43 | /// 44 | /// # Security 45 | /// 46 | /// While calling this method can eliminate the runtime performance cost of 47 | /// checking the string, constructing `QueueGroup` with an invalid string and 48 | /// then calling the NATS server with it can cause serious security issues. 49 | /// When in doubt use the [`QueueGroup::from_static`] or any of the `TryFrom` 50 | /// implementations. 
51 | #[must_use] 52 | #[expect( 53 | clippy::missing_panics_doc, 54 | reason = "The queue group validation is only made in debug" 55 | )] 56 | pub fn from_dangerous_value(value: ByteString) -> Self { 57 | if cfg!(debug_assertions) { 58 | if let Err(err) = validate_queue_group(&value) { 59 | panic!("QueueGroup {value:?} isn't valid {err:?}"); 60 | } 61 | } 62 | Self(value) 63 | } 64 | 65 | #[must_use] 66 | pub fn as_str(&self) -> &str { 67 | &self.0 68 | } 69 | } 70 | 71 | impl Display for QueueGroup { 72 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 73 | Display::fmt(&self.0, f) 74 | } 75 | } 76 | 77 | impl TryFrom for QueueGroup { 78 | type Error = QueueGroupValidateError; 79 | 80 | fn try_from(value: ByteString) -> Result { 81 | validate_queue_group(&value)?; 82 | Ok(Self::from_dangerous_value(value)) 83 | } 84 | } 85 | 86 | impl TryFrom for QueueGroup { 87 | type Error = QueueGroupValidateError; 88 | 89 | fn try_from(value: String) -> Result { 90 | validate_queue_group(&value)?; 91 | Ok(Self::from_dangerous_value(value.into())) 92 | } 93 | } 94 | 95 | impl From for ByteString { 96 | fn from(value: QueueGroup) -> Self { 97 | value.0 98 | } 99 | } 100 | 101 | impl AsRef<[u8]> for QueueGroup { 102 | fn as_ref(&self) -> &[u8] { 103 | self.as_str().as_bytes() 104 | } 105 | } 106 | 107 | impl AsRef for QueueGroup { 108 | fn as_ref(&self) -> &str { 109 | self.as_str() 110 | } 111 | } 112 | 113 | impl Deref for QueueGroup { 114 | type Target = str; 115 | 116 | fn deref(&self) -> &Self::Target { 117 | self.as_str() 118 | } 119 | } 120 | 121 | impl Serialize for QueueGroup { 122 | fn serialize(&self, serializer: S) -> Result { 123 | self.as_str().serialize(serializer) 124 | } 125 | } 126 | 127 | impl<'de> Deserialize<'de> for QueueGroup { 128 | fn deserialize>(deserializer: D) -> Result { 129 | let s = ByteString::deserialize(deserializer)?; 130 | s.try_into().map_err(de::Error::custom) 131 | } 132 | } 133 | 134 | /// An error encountered while validating 
[`QueueGroup`] 135 | #[derive(Debug, thiserror::Error)] 136 | #[cfg_attr(test, derive(PartialEq, Eq))] 137 | pub enum QueueGroupValidateError { 138 | /// The value is empty 139 | #[error("QueueGroup is empty")] 140 | Empty, 141 | /// The value has a length greater than 64 142 | #[error("QueueGroup is too long")] 143 | TooLong, 144 | /// The value contains an Unicode whitespace character 145 | #[error("QueueGroup contained an illegal whitespace character")] 146 | IllegalCharacter, 147 | } 148 | 149 | fn validate_queue_group(queue_group: &str) -> Result<(), QueueGroupValidateError> { 150 | if queue_group.is_empty() { 151 | return Err(QueueGroupValidateError::Empty); 152 | } 153 | 154 | if queue_group.len() > 64 { 155 | // This is an arbitrary limit, but I guess the server must also have one 156 | return Err(QueueGroupValidateError::TooLong); 157 | } 158 | 159 | if queue_group.chars().any(char::is_whitespace) { 160 | // The theoretical security limit is just ` `, `\t`, `\r` and `\n`. 161 | // Let's be more careful. 
162 | return Err(QueueGroupValidateError::IllegalCharacter); 163 | } 164 | 165 | Ok(()) 166 | } 167 | 168 | #[cfg(test)] 169 | mod tests { 170 | use bytestring::ByteString; 171 | 172 | use super::{QueueGroup, QueueGroupValidateError}; 173 | 174 | #[test] 175 | fn valid_queue_groups() { 176 | let queue_groups = ["importer", "importer.thing", "blablabla:itworks"]; 177 | for queue_group in queue_groups { 178 | let q = QueueGroup::try_from(ByteString::from_static(queue_group)).unwrap(); 179 | assert_eq!(queue_group, q.as_str()); 180 | } 181 | } 182 | 183 | #[test] 184 | fn invalid_queue_groups() { 185 | let queue_groups = [ 186 | ("", QueueGroupValidateError::Empty), 187 | ( 188 | "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 189 | QueueGroupValidateError::TooLong, 190 | ), 191 | ("importer ", QueueGroupValidateError::IllegalCharacter), 192 | ("importer .thing", QueueGroupValidateError::IllegalCharacter), 193 | (" importer", QueueGroupValidateError::IllegalCharacter), 194 | ("importer.thing ", QueueGroupValidateError::IllegalCharacter), 195 | ( 196 | "importer.thing.works ", 197 | QueueGroupValidateError::IllegalCharacter, 198 | ), 199 | ( 200 | "importer.thing.works\r", 201 | QueueGroupValidateError::IllegalCharacter, 202 | ), 203 | ( 204 | "importer.thing.works\n", 205 | QueueGroupValidateError::IllegalCharacter, 206 | ), 207 | ( 208 | "importer.thing.works\t", 209 | QueueGroupValidateError::IllegalCharacter, 210 | ), 211 | ( 212 | "importer.thi ng.works", 213 | QueueGroupValidateError::IllegalCharacter, 214 | ), 215 | ( 216 | "importer.thi\rng.works", 217 | QueueGroupValidateError::IllegalCharacter, 218 | ), 219 | ( 220 | "importer.thi\nng.works", 221 | QueueGroupValidateError::IllegalCharacter, 222 | ), 223 | ( 224 | "importer.thi\tng.works", 225 | QueueGroupValidateError::IllegalCharacter, 226 | ), 227 | ( 228 | "importer.thing .works", 229 | QueueGroupValidateError::IllegalCharacter, 230 | ), 231 | ( 232 | 
"importer.thing\r.works", 233 | QueueGroupValidateError::IllegalCharacter, 234 | ), 235 | ( 236 | "importer.thing\n.works", 237 | QueueGroupValidateError::IllegalCharacter, 238 | ), 239 | ( 240 | "importer.thing\t.works", 241 | QueueGroupValidateError::IllegalCharacter, 242 | ), 243 | (" ", QueueGroupValidateError::IllegalCharacter), 244 | ("\r", QueueGroupValidateError::IllegalCharacter), 245 | ("\n", QueueGroupValidateError::IllegalCharacter), 246 | ("\t", QueueGroupValidateError::IllegalCharacter), 247 | ]; 248 | for (queue_group, expected_err) in queue_groups { 249 | let err = QueueGroup::try_from(ByteString::from_static(queue_group)).unwrap_err(); 250 | assert_eq!(expected_err, err); 251 | } 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /watermelon-proto/src/server_error.rs: -------------------------------------------------------------------------------- 1 | use bytestring::ByteString; 2 | 3 | #[derive(Debug, PartialEq, Eq, thiserror::Error)] 4 | pub enum ServerError { 5 | #[error("subject is invalid")] 6 | InvalidSubject, 7 | #[error("permissions violation for publish")] 8 | PublishPermissionViolation, 9 | #[error("permissions violation for subscription")] 10 | SubscribePermissionViolation, 11 | 12 | #[error("unknown protocol operation")] 13 | UnknownProtocolOperation, 14 | 15 | #[error("attempted to connect to route port")] 16 | ConnectionAttemptedToWrongPort, 17 | 18 | #[error("authorization violation")] 19 | AuthorizationViolation, 20 | #[error("authorization timeout")] 21 | AuthorizationTimeout, 22 | #[error("invalid client protocol")] 23 | InvalidClientProtocol, 24 | #[error("maximum control line exceeded")] 25 | MaximumControlLineExceeded, 26 | #[error("parser error")] 27 | ParseError, 28 | #[error("secure connection, tls required")] 29 | TlsRequired, 30 | #[error("stale connection")] 31 | StaleConnection, 32 | #[error("maximum connections exceeded")] 33 | MaximumConnectionsExceeded, 34 | 
#[error("slow consumer")] 35 | SlowConsumer, 36 | #[error("maximum payload violation")] 37 | MaximumPayloadViolation, 38 | 39 | #[error("unknown error: {raw_message}")] 40 | Other { raw_message: ByteString }, 41 | } 42 | 43 | impl ServerError { 44 | pub fn is_fatal(&self) -> Option { 45 | match self { 46 | Self::InvalidSubject 47 | | Self::PublishPermissionViolation 48 | | Self::SubscribePermissionViolation => Some(false), 49 | 50 | Self::UnknownProtocolOperation 51 | | Self::ConnectionAttemptedToWrongPort 52 | | Self::AuthorizationViolation 53 | | Self::AuthorizationTimeout 54 | | Self::InvalidClientProtocol 55 | | Self::MaximumControlLineExceeded 56 | | Self::ParseError 57 | | Self::TlsRequired 58 | | Self::StaleConnection 59 | | Self::MaximumConnectionsExceeded 60 | | Self::SlowConsumer 61 | | Self::MaximumPayloadViolation => Some(true), 62 | 63 | Self::Other { .. } => None, 64 | } 65 | } 66 | /// Classifies a raw server error message (trimmed, ASCII case-insensitive); anything unrecognized becomes [`Self::Other`]. 67 | pub(crate) fn parse(raw_message: ByteString) -> Self { 68 | const PUBLISH_PERMISSIONS: &str = "Permissions Violation for Publish"; 69 | const SUBSCRIPTION_PERMISSIONS: &str = "Permissions Violation for Subscription"; 70 | // Prefixes are compared as bytes: slicing the `&str` at a fixed index would panic whenever that index is not a char boundary (e.g. a non-ASCII message sent by the server). 71 | let m = raw_message.trim(); 72 | if m.eq_ignore_ascii_case("Invalid Subject") { 73 | Self::InvalidSubject 74 | } else if m.len() > PUBLISH_PERMISSIONS.len() 75 | && m.as_bytes()[..PUBLISH_PERMISSIONS.len()].eq_ignore_ascii_case(PUBLISH_PERMISSIONS.as_bytes()) 76 | { 77 | Self::PublishPermissionViolation 78 | } else if m.len() > SUBSCRIPTION_PERMISSIONS.len() 79 | && m.as_bytes()[..SUBSCRIPTION_PERMISSIONS.len()].eq_ignore_ascii_case(SUBSCRIPTION_PERMISSIONS.as_bytes()) 80 | { 81 | Self::SubscribePermissionViolation 82 | } else if 
m.eq_ignore_ascii_case("Unknown Protocol Operation") { 83 | Self::UnknownProtocolOperation 84 | } else if m.eq_ignore_ascii_case("Attempted To Connect To Route Port") { 85 | Self::ConnectionAttemptedToWrongPort 86 | } else if m.eq_ignore_ascii_case("Authorization Violation") { 87 | Self::AuthorizationViolation 88 | } else if m.eq_ignore_ascii_case("Authorization
Timeout") { 89 | Self::AuthorizationTimeout 90 | } else if m.eq_ignore_ascii_case("Invalid Client Protocol") { 91 | Self::InvalidClientProtocol 92 | } else if m.eq_ignore_ascii_case("Maximum Control Line Exceeded") { 93 | Self::MaximumControlLineExceeded 94 | } else if m.eq_ignore_ascii_case("Parser Error") { 95 | Self::ParseError 96 | } else if m.eq_ignore_ascii_case("Secure Connection - TLS Required") { 97 | Self::TlsRequired 98 | } else if m.eq_ignore_ascii_case("Stale Connection") { 99 | Self::StaleConnection 100 | } else if m.eq_ignore_ascii_case("Maximum Connections Exceeded") { 101 | Self::MaximumConnectionsExceeded 102 | } else if m.eq_ignore_ascii_case("Slow Consumer") { 103 | Self::SlowConsumer 104 | } else if m.eq_ignore_ascii_case("Maximum Payload Violation") { 105 | Self::MaximumPayloadViolation 106 | } else { 107 | Self::Other { raw_message } 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /watermelon-proto/src/server_info.rs: -------------------------------------------------------------------------------- 1 | use alloc::{string::String, vec::Vec}; 2 | use core::{net::IpAddr, num::NonZero}; 3 | 4 | use serde::Deserialize; 5 | 6 | use crate::ServerAddr; 7 | 8 | #[derive(Debug, PartialEq, Eq, Deserialize)] 9 | #[allow(clippy::struct_excessive_bools)] 10 | pub struct ServerInfo { 11 | #[serde(rename = "server_id")] 12 | pub id: String, 13 | #[serde(rename = "server_name")] 14 | pub name: String, 15 | pub version: String, 16 | #[serde(rename = "go")] 17 | pub go_version: String, 18 | pub host: IpAddr, 19 | pub port: NonZero, 20 | #[serde(rename = "headers")] 21 | pub supports_headers: bool, 22 | pub max_payload: NonZero, 23 | #[serde(rename = "proto")] 24 | pub protocol_version: u32, 25 | #[serde(default)] 26 | pub client_id: Option, 27 | #[serde(default)] 28 | pub auth_required: bool, 29 | #[serde(default)] 30 | pub tls_required: bool, 31 | #[serde(default)] 32 | pub tls_verify: bool, 33 |
#[serde(default)] 34 | pub tls_available: bool, 35 | #[serde(default)] 36 | pub connect_urls: Vec, 37 | #[serde(default, rename = "ws_connect_urls")] 38 | pub websocket_connect_urls: Vec, 39 | #[serde(default, rename = "ldm")] 40 | pub lame_duck_mode: bool, 41 | #[serde(default)] 42 | pub git_commit: Option, 43 | #[serde(default, rename = "jetstream")] 44 | pub supports_jetstream: bool, 45 | #[serde(default)] 46 | pub ip: Option, 47 | #[serde(default)] 48 | pub client_ip: Option, 49 | #[serde(default)] 50 | pub nonce: Option, 51 | #[serde(default, rename = "cluster")] 52 | pub cluster_name: Option, 53 | #[serde(default)] 54 | pub domain: Option, 55 | 56 | #[serde(flatten)] 57 | pub non_standard: NonStandardServerInfo, 58 | } 59 | 60 | #[derive(Debug, PartialEq, Eq, Deserialize, Default)] 61 | #[non_exhaustive] 62 | pub struct NonStandardServerInfo { 63 | #[cfg(feature = "non-standard-zstd")] 64 | #[serde(default, rename = "m4ss_zstd")] 65 | pub zstd: bool, 66 | } 67 | -------------------------------------------------------------------------------- /watermelon-proto/src/status_code.rs: -------------------------------------------------------------------------------- 1 | use core::{ 2 | fmt::{self, Display, Formatter}, 3 | num::NonZero, 4 | str::FromStr, 5 | }; 6 | 7 | use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; 8 | 9 | use crate::util; 10 | 11 | /// A NATS status code 12 | /// 13 | /// Constants are provided for the status codes known to be used 14 | /// by the NATS Server. 15 | /// 16 | /// Values are guaranteed to be in range `100..1000`. 17 | #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] 18 | pub struct StatusCode(NonZero); 19 | 20 | impl StatusCode { 21 | /// The Jetstream consumer heartbeat timeout has been reached with no new messages to deliver 22 | /// 23 | /// See [ADR-9]. 24 |
/// 25 | /// [ADR-9]: https://github.com/nats-io/nats-architecture-and-design/blob/main/adr/ADR-9.md 26 | pub const IDLE_HEARTBEAT: StatusCode = Self::new_internal(100); 27 | /// The request has successfully been sent 28 | pub const OK: StatusCode = Self::new_internal(200); 29 | /// The requested Jetstream resource doesn't exist 30 | pub const NOT_FOUND: StatusCode = Self::new_internal(404); 31 | /// The pull consumer batch reached the timeout 32 | pub const TIMEOUT: StatusCode = Self::new_internal(408); 33 | /// The request was sent to a subject that does not appear to have any subscribers listening 34 | pub const NO_RESPONDERS: StatusCode = Self::new_internal(503); 35 | 36 | /// Decodes a status code from a slice of ASCII characters. 37 | /// 38 | /// The ASCII representation is expected to be in the form of `"NNN"`, where `N` is a numeric 39 | /// digit, and the decoded value must lie in `100..1000`. 40 | /// 41 | /// # Errors 42 | /// 43 | /// It returns an error if the slice of bytes does not contain a valid status code. 44 | pub fn from_ascii_bytes(buf: &[u8]) -> Result { 45 | if buf.len() != 3 { 46 | return Err(StatusCodeError); 47 | } 48 | 49 | let value = util::parse_u16(buf).map_err(|_| StatusCodeError)?; 50 |
51 | // Delegate to `TryFrom` so the documented `100..1000` range is enforced: 52 | // zero-padded input such as "050" has the right length but is not a valid code. 53 | Self::try_from(value) 54 | } 55 | 56 | const fn new_internal(val: u16) -> Self { 57 | Self(NonZero::new(val).unwrap()) 58 | } 59 | } 60 | 61 | impl FromStr for StatusCode { 62 | type Err = StatusCodeError; 63 | 64 | fn from_str(s: &str) -> Result { 65 | Self::from_ascii_bytes(s.as_bytes()) 66 | } 67 | } 68 | 69 | impl Display for StatusCode { 70 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 71 | Display::fmt(&self.0, f) 72 | } 73 | } 74 | 75 | impl TryFrom for StatusCode { 76 | type Error = StatusCodeError; 77 | 78 | fn try_from(value: u16) -> Result { 79 | if (100..1000).contains(&value) { 80 | Ok(Self(NonZero::new(value).unwrap())) 81 | } else { 82 | Err(StatusCodeError) 83 | } 84 | } 85 | } 86 | 87 | impl From for u16 { 88 | fn from(value: StatusCode) -> Self { 89 | value.0.get() 90 | } 91 | } 92 | 93 | impl Serialize for StatusCode { 94 | fn serialize(&self, serializer: S) -> Result { 95 | u16::from(*self).serialize(serializer) 96 | } 97 | } 98 | 99 | impl<'de> Deserialize<'de> for StatusCode { 100 | fn deserialize>(deserializer: D) -> Result { 101 | let n = u16::deserialize(deserializer)?; 102 | n.try_into().map_err(de::Error::custom) 103 | } 104 | } 105 | 106 | /// An error encountered while parsing [`StatusCode`] 107 | #[derive(Debug, thiserror::Error)] 108 | #[non_exhaustive] 109 | #[error("invalid status code")] 110 | pub struct StatusCodeError; 111 | 112 | #[cfg(test)] 113 | mod tests { 114 | use alloc::string::ToString; 115 | 116 | use claims::assert_err; 117 | 118 | use super::StatusCode; 119 | 120 | #[test] 121 | fn valid_status_codes() { 122 | let status_codes = [100, 200, 404, 408, 409, 503]; 123 | 124 | for status_code in status_codes { 125 | assert_eq!( 126 | status_code, 127 | u16::from(StatusCode::try_from(status_code).unwrap()) 128 | ); 129 | 130 | let s = status_code.to_string(); 131 | 
assert_eq!( 132 | status_code, 133 | u16::from(StatusCode::from_ascii_bytes(s.as_bytes()).unwrap()) 134 |
); 135 | } 136 | } 137 | 138 | #[test] 139 | fn invalid_status_codes() { 140 | let status_codes = [0, 5, 55, 9999]; 141 | 142 | for status_code in status_codes { 143 | assert_err!(StatusCode::try_from(status_code)); 144 | 145 | let s = status_code.to_string(); 146 | assert_err!(StatusCode::from_ascii_bytes(s.as_bytes())); 147 | } 148 | 149 | // Three digits long, but below the valid `100..1000` range once decoded. 150 | for s in ["000", "050", "099"] { 151 | assert_err!(StatusCode::from_ascii_bytes(s.as_bytes())); 152 | } 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /watermelon-proto/src/subscription_id.rs: -------------------------------------------------------------------------------- 1 | use core::fmt::{self, Display}; 2 | 3 | use crate::util::{self, ParseUintError}; 4 | 5 | #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] 6 | pub struct SubscriptionId(u64); 7 | 8 | impl SubscriptionId { 9 | pub const MIN: Self = SubscriptionId(1); 10 | pub const MAX: Self = SubscriptionId(u64::MAX); 11 | 12 | /// Converts a slice of ASCII bytes to a `SubscriptionId`. 13 | /// 14 | /// # Errors 15 | /// 16 | /// It returns an error if the bytes do not contain a valid numeric value. 17 |
pub fn from_ascii_bytes(buf: &[u8]) -> Result { 18 | util::parse_u64(buf).map(Self) 19 | } 20 | } 21 | 22 | impl From for SubscriptionId { 23 | fn from(value: u64) -> Self { 24 | Self(value) 25 | } 26 | } 27 | 28 | impl From for u64 { 29 | fn from(value: SubscriptionId) -> Self { 30 | value.0 31 | } 32 | } 33 | 34 | impl Display for SubscriptionId { 35 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 36 | Display::fmt(&self.0, f) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /watermelon-proto/src/tests.rs: -------------------------------------------------------------------------------- 1 | use bytes::{Buf, Bytes}; 2 | 3 | pub(crate) trait ToBytes: Buf { 4 | fn to_bytes(mut self) -> Bytes 5 | where 6 | Self: Sized, 7 | { 8 | self.copy_to_bytes(self.remaining()) 9 | } 10 | } 11 | 12 | impl ToBytes for T {} 13 | -------------------------------------------------------------------------------- /watermelon-proto/src/util/crlf.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug)] 2 | pub(crate) struct CrlfFinder(memchr::memmem::Finder<'static>); 3 | 4 | impl CrlfFinder { 5 | pub(crate) fn new() -> Self { 6 | Self(memchr::memmem::Finder::new(b"\r\n")) 7 | } 8 | 9 | pub(crate) fn find(&self, haystack: &[u8]) -> Option { 10 | self.0.find(haystack) 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /watermelon-proto/src/util/lines_iter.rs: -------------------------------------------------------------------------------- 1 | use core::{iter, mem}; 2 | 3 | use bytes::{Buf, Bytes}; 4 | 5 | use super::CrlfFinder; 6 | /// Iterates over the `\r\n`-separated lines of `bytes`, with terminators stripped. An empty input yields nothing, and a trailing `\r\n` does not produce a final empty line. 
7 | pub(crate) fn lines_iter(crlf: &CrlfFinder, mut bytes: Bytes) -> impl Iterator + '_ { 8 | iter::from_fn(move || { 9 | if bytes.is_empty() { 10 | return None; 11 | } 12 | 13 | Some(match crlf.find(&bytes) { 14 | Some(i) => { 15 | let chunk = bytes.split_to(i); 16 | bytes.advance("\r\n".len()); 17
| chunk 18 | } 19 | None => mem::take(&mut bytes), 20 | }) 21 | }) 22 | } 23 | 24 | #[cfg(test)] 25 | mod tests { 26 | use bytes::{Bytes, BytesMut}; 27 | 28 | use crate::util::CrlfFinder; 29 | 30 | use super::lines_iter; 31 | 32 | #[test] 33 | fn iterate_lines() { 34 | let expected_chunks = ["", "abcd", "12334534", "alkfdasfsd", "", "-"]; 35 | let mut combined_chunk = expected_chunks 36 | .iter() 37 | .fold(BytesMut::new(), |mut buf, chunk| { 38 | buf.extend_from_slice(chunk.as_bytes()); 39 | buf.extend_from_slice(b"\r\n"); 40 | buf 41 | }); 42 | combined_chunk.truncate(combined_chunk.len() - "\r\n".len()); 43 | let combined_chunk = combined_chunk.freeze(); 44 | 45 | let expected_chunks = expected_chunks 46 | .iter() 47 | .map(|c| Bytes::from_static(c.as_bytes())); 48 | assert!(expected_chunks.eq(lines_iter(&CrlfFinder::new(), combined_chunk))); 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /watermelon-proto/src/util/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) use self::buf_list::BufList; 2 | pub(crate) use self::crlf::CrlfFinder; 3 | pub(crate) use self::lines_iter::lines_iter; 4 | pub(crate) use self::split_spaces::split_spaces; 5 | pub use self::uint::ParseUintError; 6 | pub(crate) use self::uint::{parse_u16, parse_u64, parse_usize}; 7 | 8 | mod buf_list; 9 | mod crlf; 10 | mod lines_iter; 11 | mod split_spaces; 12 | mod uint; 13 | -------------------------------------------------------------------------------- /watermelon-proto/src/util/split_spaces.rs: -------------------------------------------------------------------------------- 1 | use core::{iter, mem}; 2 | 3 | use bytes::{Buf, Bytes}; 4 | /// Splits `bytes` into tokens separated by runs of spaces and/or tabs (each run counts as one separator). NOTE(review): a leading space or tab yields an empty first token; trailing ones yield nothing extra. 
5 | pub(crate) fn split_spaces(mut bytes: Bytes) -> impl Iterator { 6 | iter::from_fn(move || { 7 | if bytes.is_empty() { 8 | return None; 9 | } 10 | 11 | let Some(i) = memchr::memchr2(b' ', b'\t', &bytes) else { 12 | return Some(mem::take(&mut bytes)); 13 |
}; 14 | 15 | let chunk = bytes.split_to(i); 16 | 17 | let spaces = bytes 18 | .iter() 19 | .take_while(|b| matches!(b, b' ' | b'\t')) 20 | .count(); 21 | debug_assert!(spaces > 0); 22 | bytes.advance(spaces); 23 | Some(chunk) 24 | }) 25 | } 26 | 27 | #[cfg(test)] 28 | mod tests { 29 | use bytes::Bytes; 30 | 31 | use super::split_spaces; 32 | 33 | #[test] 34 | fn combinations() { 35 | let tests: &[(&str, &[&str])] = &[ 36 | ("", &[]), 37 | ("0123456789abcdef", &["0123456789abcdef"]), 38 | ("012345 6789abcdef", &["012345", "6789abcdef"]), 39 | ("012345\t6789abcdef", &["012345", "6789abcdef"]), 40 | ("012345 6789abcdef", &["012345", "6789abcdef"]), 41 | ("012345 6789abcdef", &["012345", "6789abcdef"]), 42 | ("012345\t\t6789abcdef", &["012345", "6789abcdef"]), 43 | ("012345\t\t\t\t6789abcdef", &["012345", "6789abcdef"]), 44 | ("012345 \t \t\t\t 6789abcdef", &["012345", "6789abcdef"]), 45 | ("012345 678 9abcdef", &["012345", "678", "9abcdef"]), 46 | ("012345 678\t9abcdef", &["012345", "678", "9abcdef"]), 47 | ("012345\t678 9abcdef", &["012345", "678", "9abcdef"]), 48 | ("012345\t678\t9abcdef", &["012345", "678", "9abcdef"]), 49 | ("012345\t678\t 9abcdef", &["012345", "678", "9abcdef"]), 50 | ("012345 \t678\t 9abcdef", &["012345", "678", "9abcdef"]), 51 | ]; 52 | 53 | for (input, output) in tests { 54 | let spaces = split_spaces(Bytes::from_static(input.as_bytes())).collect::>(); 55 | assert_eq!(spaces, output.to_vec()); 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /watermelon-proto/src/util/uint.rs: -------------------------------------------------------------------------------- 1 | macro_rules!
parse_unsigned { 2 | ($name:ident, $num:ty) => { 3 | pub(crate) fn $name(buf: &[u8]) -> Result<$num, ParseUintError> { 4 | let mut val: $num = 0; 5 | 6 | for &b in buf { 7 | if !b.is_ascii_digit() { 8 | return Err(ParseUintError::InvalidByte(b)); 9 | } 10 | 11 | val = val.checked_mul(10).ok_or(ParseUintError::Overflow)?; 12 | val = val 13 | .checked_add(<$num>::from(b - b'0')) 14 | .ok_or(ParseUintError::Overflow)?; 15 | } 16 | 17 | Ok(val) 18 | } 19 | }; 20 | } 21 | 22 | parse_unsigned!(parse_u16, u16); 23 | parse_unsigned!(parse_u64, u64); 24 | parse_unsigned!(parse_usize, usize); 25 | 26 | #[derive(Debug, thiserror::Error)] 27 | pub enum ParseUintError { 28 | #[error("invalid byte {0:?}")] 29 | InvalidByte(u8), 30 | #[error("overflow")] 31 | Overflow, 32 | } 33 | 34 | #[cfg(test)] 35 | mod tests { 36 | use alloc::string::ToString; 37 | 38 | use claims::assert_ok_eq; 39 | 40 | use super::{parse_u16, parse_u64, parse_usize}; 41 | 42 | #[test] 43 | fn parse_u16_range() { 44 | for n in 0..=u16::MAX { 45 | let s = n.to_string(); 46 | assert_ok_eq!(parse_u16(s.as_bytes()), n); 47 | assert_ok_eq!(parse_usize(s.as_bytes()), usize::from(n)); 48 | assert_ok_eq!(parse_u64(s.as_bytes()), u64::from(n)); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /watermelon/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "watermelon" 3 | version = "0.4.1" 4 | description = "High level actor based implementation NATS Core and NATS Jetstream client implementation" 5 | categories = ["api-bindings", "network-programming"] 6 | keywords = ["nats", "client", "jetstream"] 7 | edition.workspace = true 8 | license.workspace = true 9 | repository.workspace = true 10 | rust-version.workspace = true 11 | 12 | [package.metadata.docs.rs] 13 | features = ["websocket", "non-standard-zstd"] 14 | 15 | [dependencies] 16 | tokio = { version = "1.44", features = ["rt", "sync", 
"time"] } 17 | arc-swap = "1" 18 | futures-core = "0.3" 19 | bytes = "1" 20 | serde = { version = "1.0.113", features = ["derive"] } 21 | serde_json = "1" 22 | pin-project-lite = "0.2" 23 | jiff = { version = "0.2.1", default-features = false, features = ["serde"] } 24 | 25 | # random number generation 26 | rand = { version = "0.9", default-features = false, features = ["thread_rng"], optional = true } 27 | getrandom = { version = "0.3.1", optional = true } 28 | 29 | # from-env 30 | envy = { version = "0.4", optional = true } 31 | 32 | # portable-atomic 33 | portable-atomic = { version = "1", optional = true } 34 | 35 | watermelon-mini = { version = "0.3.2", path = "../watermelon-mini", default-features = false } 36 | watermelon-net = { version = "0.2", path = "../watermelon-net", default-features = false } 37 | watermelon-proto = { version = "0.1.3", path = "../watermelon-proto" } 38 | watermelon-nkeys = { version = "0.1", path = "../watermelon-nkeys", default-features = false } 39 | 40 | thiserror = "2" 41 | 42 | [dev-dependencies] 43 | tokio = { version = "1.36", features = ["test-util"] } 44 | futures-util = { version = "0.3", default-features = false } 45 | claims = "0.8" 46 | 47 | [features] 48 | default = ["aws-lc-rs", "from-env", "rand"] 49 | websocket = ["watermelon-mini/websocket"] 50 | aws-lc-rs = ["watermelon-mini/aws-lc-rs", "watermelon-nkeys/aws-lc-rs"] 51 | ring = ["watermelon-mini/ring", "watermelon-nkeys/ring"] 52 | fips = ["watermelon-mini/fips", "watermelon-nkeys/fips"] 53 | from-env = ["dep:envy"] 54 | rand = ["dep:rand", "watermelon-mini/rand", "watermelon-net/rand"] 55 | getrandom = ["dep:getrandom", "watermelon-mini/getrandom", "watermelon-net/getrandom"] 56 | portable-atomic = ["dep:portable-atomic"] 57 | non-standard-zstd = ["watermelon-mini/non-standard-zstd", "watermelon-net/non-standard-zstd", "watermelon-proto/non-standard-zstd"] 58 | 59 | [lints] 60 | workspace = true 61 | 
-------------------------------------------------------------------------------- /watermelon/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /watermelon/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /watermelon/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /watermelon/src/client/commands/mod.rs: -------------------------------------------------------------------------------- 1 | pub use self::publish::{ 2 | ClientPublish, DoClientPublish, DoOwnedClientPublish, OwnedClientPublish, Publish, 3 | PublishBuilder, 4 | }; 5 | pub use self::request::{ 6 | ClientRequest, DoClientRequest, DoOwnedClientRequest, OwnedClientRequest, Request, 7 | RequestBuilder, ResponseError, ResponseFut, 8 | }; 9 | 10 | mod publish; 11 | mod request; 12 | -------------------------------------------------------------------------------- /watermelon/src/client/commands/publish.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt::{self, Debug}, 3 | future::IntoFuture, 4 | }; 5 | 6 | use bytes::Bytes; 7 | use watermelon_proto::{ 8 | MessageBase, Subject, 9 | headers::{HeaderMap, HeaderName, HeaderValue}, 10 | }; 11 | 12 | use crate::{ 13 | client::{Client, ClientClosedError, TryCommandError}, 14 | handler::HandlerCommand, 15 | util::BoxFuture, 16 | }; 17 | 18 | use super::Request; 19 | 20 | /// A publishable message 21 | #[derive(Debug, Clone)] 22 | pub struct Publish { 23 | pub(super) subject: Subject, 24 | pub(super) reply_subject: Option, 25 | pub(super) headers: HeaderMap, 26 | pub(super) payload: 
Bytes,
27 | } 28 | 29 | /// A constructor for a publishable message 30 | /// 31 | /// Obtained from [`Publish::builder`]. 32 | #[derive(Debug)] 33 | pub struct PublishBuilder { 34 | publish: Publish, 35 | } 36 | 37 | /// A constructor for a publishable message to be sent using the given client 38 | /// 39 | /// Obtained from [`Client::publish`]. 40 | pub struct ClientPublish<'a> { 41 | client: &'a Client, 42 | publish: Publish, 43 | } 44 | 45 | /// A publishable message ready to be published to the given client 46 | #[must_use = "futures do nothing unless you `.await` or poll them"] 47 | pub struct DoClientPublish<'a> { 48 | client: &'a Client, 49 | publish: Publish, 50 | } 51 | 52 | /// A constructor for a publishable message to be sent using the given owned client 53 | /// 54 | /// Obtained from [`Client::publish_owned`]. 55 | pub struct OwnedClientPublish { 56 | client: Client, 57 | publish: Publish, 58 | } 59 | 60 | /// A publishable message ready to be published to the given owned client 61 | #[must_use = "futures do nothing unless you `.await` or poll them"] 62 | pub struct DoOwnedClientPublish { 63 | client: Client, 64 | publish: Publish, 65 | } 66 | 67 | macro_rules!
publish { 68 | () => { 69 | #[must_use] 70 | pub fn reply_subject(mut self, reply_subject: Option) -> Self { 71 | self.publish_mut().reply_subject = reply_subject; 72 | self 73 | } 74 | 75 | #[must_use] 76 | pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self { 77 | self.publish_mut().headers.insert(name, value); 78 | self 79 | } 80 | 81 | #[must_use] 82 | pub fn headers(mut self, headers: HeaderMap) -> Self { 83 | self.publish_mut().headers = headers; 84 | self 85 | } 86 | }; 87 | } 88 | 89 | impl Publish { 90 | /// Build a new [`Publish`] 91 | #[must_use] 92 | pub fn builder(subject: Subject) -> PublishBuilder { 93 | PublishBuilder::subject(subject) 94 | } 95 | 96 | /// Publish this message to `client` 97 | pub fn client(self, client: &Client) -> DoClientPublish<'_> { 98 | DoClientPublish { 99 | client, 100 | publish: self, 101 | } 102 | } 103 | 104 | /// Publish this message to `client`, taking ownership of it 105 | pub fn client_owned(self, client: Client) -> DoOwnedClientPublish { 106 | DoOwnedClientPublish { 107 | client, 108 | publish: self, 109 | } 110 | } 111 | /// Convert this message into a [`Request`] with no response timeout set 112 | pub fn into_request(self) -> Request { 113 | Request { 114 | publish: self, 115 | response_timeout: None, 116 | } 117 | } 118 | 119 | fn into_message_base(self) -> MessageBase { 120 | let Self { 121 | subject, 122 | reply_subject, 123 | headers, 124 | payload, 125 | } = self; 126 | MessageBase { 127 | subject, 128 | reply_subject, 129 | headers, 130 | payload, 131 | } 132 | } 133 | } 134 | 135 | impl PublishBuilder { 136 | #[must_use] 137 | pub fn subject(subject: Subject) -> Self { 138 | Self { 139 | publish: Publish { 140 | subject, 141 | reply_subject: None, 142 | headers: HeaderMap::new(), 143 | payload: Bytes::new(), 144 | }, 145 | } 146 | } 147 | 148 | publish!(); 149 | 150 | #[must_use] 151 | pub fn payload(mut self, payload: Bytes) -> Publish { 152 | self.publish.payload = payload; 153 | 
self.publish 154 | } 155 | 156 | fn publish_mut(&mut self) -> &mut Publish { 157
| &mut self.publish 158 | } 159 | } 160 | 161 | impl<'a> ClientPublish<'a> { 162 | pub(crate) fn build(client: &'a Client, subject: Subject) -> Self { 163 | Self { 164 | client, 165 | publish: PublishBuilder::subject(subject).publish, 166 | } 167 | } 168 | 169 | publish!(); 170 | /// Set the payload, turning this into a message ready to be published to the borrowed client 171 | pub fn payload(mut self, payload: Bytes) -> DoClientPublish<'a> { 172 | self.publish.payload = payload; 173 | self.publish.client(self.client) 174 | } 175 | 176 | /// Convert this into [`OwnedClientPublish`] 177 | #[must_use] 178 | pub fn to_owned(self) -> OwnedClientPublish { 179 | OwnedClientPublish { 180 | client: self.client.clone(), 181 | publish: self.publish, 182 | } 183 | } 184 | 185 | fn publish_mut(&mut self) -> &mut Publish { 186 | &mut self.publish 187 | } 188 | } 189 | 190 | impl OwnedClientPublish { 191 | pub(crate) fn build(client: Client, subject: Subject) -> Self { 192 | Self { 193 | client, 194 | publish: PublishBuilder::subject(subject).publish, 195 | } 196 | } 197 | 198 | publish!(); 199 | /// Set the payload, turning this into a message ready to be published to the owned client 200 | pub fn payload(mut self, payload: Bytes) -> DoOwnedClientPublish { 201 | self.publish.payload = payload; 202 | self.publish.client_owned(self.client) 203 | } 204 | 205 | fn publish_mut(&mut self) -> &mut Publish { 206 | &mut self.publish 207 | } 208 | } 209 | 210 | impl DoClientPublish<'_> { 211 | /// Publish this message if there's enough immediately available space in the internal buffers 212 | /// 213 | /// This method will publish the given message only if there's enough 214 | /// immediately available space to enqueue it in the client's 215 | /// networking stack. 216 | /// 217 | /// # Errors 218 | /// 219 | /// It returns an error if the client's buffer is full or if the client has been closed.
220 | pub fn try_publish(self) -> Result<(), TryCommandError> { 221 | try_publish(self.client, self.publish) 222 | } 223 | } 224 | 225 | impl<'a> IntoFuture for DoClientPublish<'a> { 226 | type Output = Result<(), ClientClosedError>; 227 | type IntoFuture = BoxFuture<'a, Self::Output>; 228 | 229 | fn into_future(self) -> Self::IntoFuture { 230 | Box::pin(async move { publish(self.client, self.publish).await }) 231 | } 232 | } 233 | 234 | impl DoOwnedClientPublish { 235 | /// Publish this message if there's enough immediately available space in the internal buffers 236 | /// 237 | /// This method will publish the given message only if there's enough 238 | /// immediately available space to enqueue it in the client's 239 | /// networking stack. 240 | /// 241 | /// # Errors 242 | /// 243 | /// It returns an error if the client's buffer is full or if the client has been closed. 244 | pub fn try_publish(self) -> Result<(), TryCommandError> { 245 | try_publish(&self.client, self.publish) 246 | } 247 | } 248 | 249 | impl IntoFuture for DoOwnedClientPublish { 250 | type Output = Result<(), ClientClosedError>; 251 | type IntoFuture = BoxFuture<'static, Self::Output>; 252 | 253 | fn into_future(self) -> Self::IntoFuture { 254 | Box::pin(async move { publish(&self.client, self.publish).await }) 255 | } 256 | } 257 | /// Enqueue `publish` on `client` through `Client::try_enqueue_command`, without waiting 258 | fn try_publish(client: &Client, publish: Publish) -> Result<(), TryCommandError> { 259 | client.try_enqueue_command(HandlerCommand::Publish { 260 | message: publish.into_message_base(), 261 | }) 262 | } 263 | /// Enqueue `publish` on `client`, awaiting `Client::enqueue_command` 264 | async fn publish(client: &Client, publish: Publish) -> Result<(), ClientClosedError> { 265 | client 266 | .enqueue_command(HandlerCommand::Publish { 267 | message: publish.into_message_base(), 268 | }) 269 | .await 270 | } 271 | 272 | impl Debug for ClientPublish<'_> { 273 | fn fmt(&self, f: &mut 
fmt::Formatter<'_>) -> fmt::Result { 274 | f.debug_struct("ClientPublish") 275 | .field("publish", &self.publish) 276 | .finish_non_exhaustive() 277 | }
278 | } 279 | 280 | impl Debug for DoClientPublish<'_> { 281 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 282 | f.debug_struct("DoClientPublish") 283 | .field("publish", &self.publish) 284 | .finish_non_exhaustive() 285 | } 286 | } 287 | 288 | impl Debug for OwnedClientPublish { 289 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 290 | f.debug_struct("OwnedClientPublish") 291 | .field("publish", &self.publish) 292 | .finish_non_exhaustive() 293 | } 294 | } 295 | 296 | impl Debug for DoOwnedClientPublish { 297 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 298 | f.debug_struct("DoOwnedClientPublish") 299 | .field("publish", &self.publish) 300 | .finish_non_exhaustive() 301 | } 302 | } 303 | -------------------------------------------------------------------------------- /watermelon/src/client/from_env.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use serde::{Deserialize, Deserializer, de}; 4 | use watermelon_nkeys::KeyPair; 5 | use watermelon_proto::Subject; 6 | 7 | #[derive(Debug, Deserialize)] 8 | pub(super) struct FromEnv { 9 | #[serde(flatten)] 10 | pub(super) auth: AuthenticationMethod, 11 | pub(super) inbox_prefix: Option, 12 | } 13 | 14 | #[derive(Debug, Deserialize)] 15 | #[serde(untagged)] 16 | pub(super) enum AuthenticationMethod { 17 | Creds { 18 | #[serde(rename = "nats_jwt")] 19 | jwt: String, 20 | #[serde(rename = "nats_nkey", deserialize_with = "deserialize_nkey")] 21 | nkey: KeyPair, 22 | }, 23 | CredsFile { 24 | #[serde(rename = "nats_creds_file")] 25 | creds_file: PathBuf, 26 | }, 27 | UserAndPassword { 28 | #[serde(rename = "nats_username")] 29 | username: String, 30 | #[serde(rename = "nats_password")] 31 | password: String, 32 | }, 33 | None, 34 | } 35 | 36 | fn deserialize_nkey<'de, D>(deserializer: D) -> Result 37 | where 38 | D: Deserializer<'de>, 39 | { 40 | let secret = String::deserialize(deserializer)?; 41 |
KeyPair::from_encoded_seed(&secret).map_err(de::Error::custom) 42 | } 43 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/commands/consumer_batch.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | pin::Pin, 3 | task::{Context, Poll}, 4 | time::Duration, 5 | }; 6 | 7 | use futures_core::{FusedStream, Future, Stream}; 8 | use pin_project_lite::pin_project; 9 | use serde_json::json; 10 | use tokio::time::{Sleep, sleep}; 11 | use watermelon_proto::{ServerMessage, StatusCode, error::ServerError}; 12 | 13 | use crate::{ 14 | client::{Consumer, JetstreamClient, JetstreamError}, 15 | subscription::Subscription, 16 | }; 17 | 18 | pin_project! { 19 | /// A consumer batch request 20 | /// 21 | /// Obtained from [`JetstreamClient::consumer_batch`]. 22 | #[derive(Debug)] 23 | #[must_use = "streams do nothing unless polled"] 24 | pub struct ConsumerBatch { 25 | subscription: Subscription, 26 | #[pin] 27 | timeout: Sleep, 28 | pending_msgs: usize, 29 | } 30 | } 31 | 32 | #[derive(Debug, thiserror::Error)] 33 | pub enum ConsumerBatchError { 34 | #[error("an error returned by the server")] 35 | ServerError(#[source] ServerError), 36 | #[error("unexpected status code")] 37 | UnexpectedStatus(ServerMessage), 38 | } 39 | 40 | impl ConsumerBatch { 41 | pub(crate) fn new( 42 | consumer: &Consumer, 43 | client: JetstreamClient, 44 | expires: Duration, 45 | max_msgs: usize, 46 | ) -> impl Future> + use<> { 47 | let subject = format!( 48 | "{}.CONSUMER.MSG.NEXT.{}.{}", 49 | client.prefix, consumer.stream_name, consumer.config.name 50 | ) 51 | .try_into(); 52 | 53 | async move { 54 | let subject = subject.map_err(JetstreamError::Subject)?; 55 | let incoming_subject = client.client.create_inbox_subject(); 56 | let payload = serde_json::to_vec(&if expires.is_zero() { 57 | json!({ 58 | "batch": max_msgs, 59 | "no_wait": true, 60 | }) 61 | } else { 62 | json!({ 63 | "batch":
max_msgs, 64 | "expires": i64::try_from(expires.as_nanos()).unwrap_or(i64::MAX), // saturate: `json!` panics on a `u128` above `u64::MAX` (e.g. `Duration::MAX`); the server presumably decodes a signed 64-bit nanosecond count (Go `time.Duration`) - TODO confirm 65 | "no_wait": true 66 | }) 67 | }) 68 | .map_err(JetstreamError::Json)?; 69 | 70 | let subscription = client 71 | .client 72 | .subscribe(incoming_subject.clone(), None) 73 | .await 74 | .map_err(JetstreamError::ClientClosed)?; 75 | client 76 | .client 77 | .publish(subject) 78 | .reply_subject(Some(incoming_subject.clone())) 79 | .payload(payload.into()) 80 | .await 81 | .map_err(JetstreamError::ClientClosed)?; 82 | 83 | let timeout = sleep(expires.saturating_add(client.request_timeout)); 84 | Ok(Self { 85 | subscription, 86 | timeout, 87 | pending_msgs: max_msgs, 88 | }) 89 | } 90 | } 91 | } 92 | 93 | impl Stream for ConsumerBatch { 94 | type Item = Result; 95 | 96 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 97 | let this = self.project(); 98 | 99 | if *this.pending_msgs == 0 { 100 | return Poll::Ready(None); 101 | } 102 | 103 | match Pin::new(this.subscription).poll_next(cx) { 104 | Poll::Pending => match this.timeout.poll(cx) { 105 | Poll::Pending => Poll::Pending, 106 | Poll::Ready(()) => { 107 | *this.pending_msgs = 0; 108 | Poll::Ready(None) 109 | } 110 | }, 111 | Poll::Ready(Some(Ok(msg))) => match msg.status_code { 112 | None | Some(StatusCode::OK) => { 113 | *this.pending_msgs -= 1; 114 | 115 | Poll::Ready(Some(Ok(msg))) 116 | } 117 | Some(StatusCode::IDLE_HEARTBEAT) => { 118 | cx.waker().wake_by_ref(); 119 | Poll::Pending 120 | } 121 | Some(StatusCode::TIMEOUT | StatusCode::NOT_FOUND) => { 122 | *this.pending_msgs = 0; 123 | Poll::Ready(None) 124 | } 125 | _ => Poll::Ready(Some(Err(ConsumerBatchError::UnexpectedStatus(msg)))), 126 | }, 127 | Poll::Ready(Some(Err(err))) => { 128 | *this.pending_msgs = 0; 129 | Poll::Ready(Some(Err(ConsumerBatchError::ServerError(err)))) 130 | } 131 | 
Poll::Ready(None) => { 132 | *this.pending_msgs = 0; 133 | Poll::Ready(None) 134 | } 135 | } 136 | } 137 | } 138 | 139 | impl FusedStream for ConsumerBatch { 140 | fn is_terminated(&self) -> bool { 141
| self.pending_msgs == 0 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/commands/consumer_list.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::VecDeque, 3 | fmt::Display, 4 | future::Future, 5 | pin::Pin, 6 | task::{Context, Poll}, 7 | }; 8 | 9 | use futures_core::{FusedStream, Stream}; 10 | use serde::Deserialize; 11 | use serde_json::json; 12 | use watermelon_proto::Subject; 13 | 14 | use crate::{ 15 | client::{self, JetstreamClient, jetstream::JetstreamError}, 16 | util::BoxFuture, 17 | }; 18 | 19 | /// A request to list consumers of a stream 20 | /// 21 | /// Obtained from [`JetstreamClient::consumers`]. 22 | #[must_use = "streams do nothing unless polled"] 23 | pub struct Consumers { 24 | client: JetstreamClient, 25 | offset: u32, 26 | partial_subject: Subject, 27 | fetch: Option>>, 28 | buffer: VecDeque, 29 | exhausted: bool, 30 | } 31 | 32 | #[derive(Debug, Deserialize)] 33 | struct ConsumersResponse { 34 | limit: u32, 35 | consumers: VecDeque, 36 | } 37 | 38 | impl Consumers { 39 | pub(crate) fn new(client: JetstreamClient, stream_name: impl Display) -> Self { 40 | let partial_subject = format!("CONSUMER.LIST.{stream_name}") 41 | .try_into() 42 | .expect("stream name is valid"); 43 | Self { 44 | client, 45 | offset: 0, 46 | partial_subject, 47 | fetch: None, 48 | buffer: VecDeque::new(), 49 | exhausted: false, 50 | } 51 | } 52 | } 53 | 54 | impl Stream for Consumers { 55 | type Item = Result; 56 | 57 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 58 | let this = self.get_mut(); 59 | 60 | if let Some(consumer) = this.buffer.pop_front() { 61 | return Poll::Ready(Some(Ok(consumer))); 62 | } 63 | 64 | if this.exhausted { 65 | return Poll::Ready(None); 66 | } 67 | 68 | let fetch = this.fetch.get_or_insert_with(|| { 69 | let client = this.client.clone(); 70 | let partial_subject = 
this.partial_subject.clone(); 71 | let offset = this.offset; 72 | 73 | Box::pin(async move { 74 | let response_fut = client 75 | .client() 76 | .request(client.subject_for_request(&partial_subject)) 77 | .response_timeout(client.request_timeout) 78 | .payload( 79 | serde_json::to_vec(&json!({ 80 | "offset": offset, 81 | })) 82 | .unwrap() 83 | .into(), 84 | ) 85 | .await 86 | .map_err(JetstreamError::ClientClosed)?; 87 | let response = response_fut.await.map_err(JetstreamError::ResponseError)?; 88 | let payload = 89 | serde_json::from_slice(&response.base.payload).map_err(JetstreamError::Json)?; 90 | Ok(payload) 91 | }) 92 | }); 93 | 94 | match Pin::new(fetch).poll(cx) { 95 | Poll::Pending => Poll::Pending, 96 | Poll::Ready(Ok(response)) => { 97 | this.fetch = None; 98 | this.buffer = response.consumers; 99 | if this.buffer.len() < response.limit as usize { 100 | this.exhausted = true; 101 | } else if !this.buffer.is_empty() { 102 | this.offset += 1; 103 | } 104 | 105 | cx.waker().wake_by_ref(); 106 | Poll::Pending 107 | } 108 | Poll::Ready(Err(err)) => { 109 | this.fetch = None; 110 | Poll::Ready(Some(Err(err))) 111 | } 112 | } 113 | } 114 | } 115 | 116 | impl FusedStream for Consumers { 117 | fn is_terminated(&self) -> bool { 118 | self.buffer.is_empty() && self.exhausted 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/commands/consumer_stream.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::Future, 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | time::Duration, 6 | }; 7 | 8 | use futures_core::{FusedStream, Stream}; 9 | use pin_project_lite::pin_project; 10 | use watermelon_proto::ServerMessage; 11 | 12 | use crate::{ 13 | client::{Consumer, JetstreamClient, JetstreamError}, 14 | util::BoxFuture, 15 | }; 16 | 17 | use super::{ConsumerBatch, consumer_batch::ConsumerBatchError}; 18 | 19 | pin_project! 
{ 20 | /// A consumer stream of batch requests 21 | /// 22 | /// Obtained from [`JetstreamClient::consumer_stream`]. 23 | #[must_use = "streams do nothing unless polled"] 24 | pub struct ConsumerStream { 25 | #[pin] 26 | status: ConsumerStreamStatus, 27 | consumer: Consumer, 28 | client: JetstreamClient, 29 | 30 | expires: Duration, 31 | max_msgs: usize, 32 | } 33 | } 34 | 35 | pin_project! { 36 | #[project = ConsumerStreamStatusProj] 37 | enum ConsumerStreamStatus { 38 | Polling { 39 | future: BoxFuture<'static, Result>, 40 | }, 41 | RunningBatch { 42 | #[pin] 43 | batch: ConsumerBatch, 44 | }, 45 | Broken, 46 | } 47 | } 48 | 49 | #[derive(Debug, thiserror::Error)] 50 | pub enum ConsumerStreamError { 51 | #[error("consumer batch error")] 52 | BatchError(#[source] ConsumerBatchError), 53 | #[error("jetstream error")] 54 | Jetstream(#[source] JetstreamError), 55 | } 56 | 57 | impl ConsumerStream { 58 | pub(crate) fn new( 59 | consumer: Consumer, 60 | client: JetstreamClient, 61 | expires: Duration, 62 | max_msgs: usize, 63 | ) -> Self { 64 | let poll_fut = { 65 | let client = client.clone(); 66 | Box::pin(ConsumerBatch::new(&consumer, client, expires, max_msgs)) 67 | }; 68 | 69 | Self { 70 | status: ConsumerStreamStatus::Polling { future: poll_fut }, 71 | consumer, 72 | client, 73 | 74 | expires, 75 | max_msgs, 76 | } 77 | } 78 | } 79 | 80 | impl Stream for ConsumerStream { 81 | type Item = Result; 82 | 83 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 84 | let mut this = self.project(); 85 | match this.status.as_mut().project() { 86 | ConsumerStreamStatusProj::RunningBatch { batch } => match batch.poll_next(cx) { 87 | Poll::Pending => Poll::Pending, 88 | Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(msg))), 89 | Poll::Ready(Some(Err(err))) => { 90 | this.status.set(ConsumerStreamStatus::Broken); 91 | Poll::Ready(Some(Err(ConsumerStreamError::BatchError(err)))) 92 | } 93 | Poll::Ready(None) => { 94 | 
this.status.set(ConsumerStreamStatus::Polling { 95 | future: Box::pin(ConsumerBatch::new( 96 | this.consumer, 97 | this.client.clone(), 98 | *this.expires, 99 | *this.max_msgs, 100 | )), 101 | }); 102 | 103 | cx.waker().wake_by_ref(); 104 | Poll::Pending 105 | } 106 | }, 107 | ConsumerStreamStatusProj::Polling { future: fut } => match Pin::new(fut).poll(cx) { 108 | Poll::Pending => Poll::Pending, 109 | Poll::Ready(Ok(batch)) => { 110 | this.status 111 | .set(ConsumerStreamStatus::RunningBatch { batch }); 112 | 113 | cx.waker().wake_by_ref(); 114 | Poll::Pending 115 | } 116 | Poll::Ready(Err(err)) => { 117 | this.status.set(ConsumerStreamStatus::Broken); 118 | Poll::Ready(Some(Err(ConsumerStreamError::Jetstream(err)))) 119 | } 120 | }, 121 | ConsumerStreamStatusProj::Broken => Poll::Ready(None), 122 | } 123 | } 124 | } 125 | 126 | impl FusedStream for ConsumerStream { 127 | fn is_terminated(&self) -> bool { 128 | matches!(self.status, ConsumerStreamStatus::Broken) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/commands/mod.rs: -------------------------------------------------------------------------------- 1 | pub use self::consumer_batch::ConsumerBatch; 2 | pub use self::consumer_list::Consumers; 3 | pub use self::consumer_stream::{ConsumerStream, ConsumerStreamError}; 4 | pub use self::stream_list::Streams; 5 | 6 | mod consumer_batch; 7 | mod consumer_list; 8 | mod consumer_stream; 9 | mod stream_list; 10 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/commands/stream_list.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::VecDeque, 3 | future::Future, 4 | pin::Pin, 5 | task::{Context, Poll}, 6 | }; 7 | 8 | use futures_core::{FusedStream, Stream}; 9 | use serde::Deserialize; 10 | use serde_json::json; 11 | use watermelon_proto::Subject; 12 | 13 
| use crate::{ 14 | client::{self, JetstreamClient, jetstream::JetstreamError}, 15 | util::BoxFuture, 16 | }; 17 | 18 | /// A request to list streams 19 | /// 20 | /// Obtained from [`JetstreamClient::streams`]. 21 | #[must_use = "streams do nothing unless polled"] 22 | pub struct Streams { 23 | client: JetstreamClient, 24 | offset: u32, 25 | fetch: Option>>, 26 | buffer: VecDeque, 27 | exhausted: bool, 28 | } 29 | 30 | #[derive(Debug, Deserialize)] 31 | struct StreamsResponse { 32 | limit: u32, 33 | streams: VecDeque, 34 | } 35 | 36 | impl Streams { 37 | pub(crate) fn new(client: JetstreamClient) -> Self { 38 | Self { 39 | client, 40 | offset: 0, 41 | fetch: None, 42 | buffer: VecDeque::new(), 43 | exhausted: false, 44 | } 45 | } 46 | } 47 | 48 | impl Stream for Streams { 49 | type Item = Result; 50 | 51 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 52 | let this = self.get_mut(); 53 | 54 | if let Some(stream) = this.buffer.pop_front() { 55 | return Poll::Ready(Some(Ok(stream))); 56 | } 57 | 58 | if this.exhausted { 59 | return Poll::Ready(None); 60 | } 61 | 62 | let fetch = this.fetch.get_or_insert_with(|| { 63 | let client = this.client.clone(); 64 | let offset = this.offset; 65 | 66 | Box::pin(async move { 67 | let response_fut = client 68 | .client() 69 | .request(client.subject_for_request(&Subject::from_static("STREAM.LIST"))) 70 | .response_timeout(client.request_timeout) 71 | .payload( 72 | serde_json::to_vec(&json!({ 73 | "offset": offset, 74 | })) 75 | .unwrap() 76 | .into(), 77 | ) 78 | .await 79 | .map_err(JetstreamError::ClientClosed)?; 80 | let response = response_fut.await.map_err(JetstreamError::ResponseError)?; 81 | let payload = 82 | serde_json::from_slice(&response.base.payload).map_err(JetstreamError::Json)?; 83 | Ok(payload) 84 | }) 85 | }); 86 | 87 | match Pin::new(fetch).poll(cx) { 88 | Poll::Pending => Poll::Pending, 89 | Poll::Ready(Ok(response)) => { 90 | this.fetch = None; 91 | this.buffer = response.streams; 92 | 
if this.buffer.len() < response.limit as usize { 93 | this.exhausted = true; 94 | } else if !this.buffer.is_empty() { 95 | this.offset += 1; 96 | } 97 | 98 | cx.waker().wake_by_ref(); 99 | Poll::Pending 100 | } 101 | Poll::Ready(Err(err)) => { 102 | this.fetch = None; 103 | Poll::Ready(Some(Err(err))) 104 | } 105 | } 106 | } 107 | } 108 | 109 | impl FusedStream for Streams { 110 | fn is_terminated(&self) -> bool { 111 | self.buffer.is_empty() && self.exhausted 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/resources/mod.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | 3 | pub use self::consumer::{ 4 | AckPolicy, Consumer, ConsumerConfig, ConsumerDurability, ConsumerSpecificConfig, 5 | ConsumerStorage, DeliverPolicy, ReplayPolicy, 6 | }; 7 | pub use self::stream::{ 8 | Compression, DiscardPolicy, RetentionPolicy, Storage, Stream, StreamConfig, StreamState, 9 | }; 10 | 11 | use super::JetstreamApiError; 12 | 13 | mod consumer; 14 | mod stream; 15 | 16 | #[derive(Debug, Deserialize)] 17 | #[serde(untagged)] 18 | pub(crate) enum Response { 19 | Response(T), 20 | Error { error: JetstreamApiError }, 21 | } 22 | 23 | mod nullable_number { 24 | use std::{any::type_name, fmt::Display}; 25 | 26 | use serde::{ 27 | Deserialize, Deserializer, Serialize, Serializer, 28 | de::{self, DeserializeOwned}, 29 | ser, 30 | }; 31 | 32 | pub(crate) trait NullableNumber: Copy + Display { 33 | const NULL_VALUE: Self::SignedValue; 34 | type SignedValue: Copy 35 | + TryFrom 36 | + TryInto 37 | + Display 38 | + Eq 39 | + Serialize 40 | + DeserializeOwned; 41 | } 42 | 43 | impl NullableNumber for u32 { 44 | const NULL_VALUE: Self::SignedValue = -1; 45 | type SignedValue = i32; 46 | } 47 | 48 | impl NullableNumber for u64 { 49 | const NULL_VALUE: Self::SignedValue = -1; 50 | type SignedValue = i64; 51 | } 52 | 53 | 
#[expect(clippy::ref_option)] 54 | pub(crate) fn serialize(num: &Option, serializer: S) -> Result 55 | where 56 | S: Serializer, 57 | N: NullableNumber, 58 | { 59 | match *num { 60 | Some(num) => num.try_into().map_err(|_| { 61 | ser::Error::custom(format!( 62 | "{num} can't be converted to {}", 63 | type_name::() 64 | )) 65 | })?, 66 | None => N::NULL_VALUE, 67 | } 68 | .serialize(serializer) 69 | } 70 | 71 | pub(crate) fn deserialize<'de, D: Deserializer<'de>, N: NullableNumber>( 72 | deserializer: D, 73 | ) -> Result, D::Error> { 74 | let num = N::SignedValue::deserialize(deserializer)?; 75 | Ok(if num == N::NULL_VALUE { 76 | None 77 | } else { 78 | Some(num.try_into().map_err(|_| { 79 | de::Error::custom(format!("{num} can't be converted to {}", type_name::())) 80 | })?) 81 | }) 82 | } 83 | } 84 | 85 | mod option_nonzero { 86 | use std::num::NonZero; 87 | 88 | use serde::{Deserialize, Deserializer, Serialize, Serializer, de::DeserializeOwned}; 89 | 90 | pub(crate) trait NonZeroNumber: Copy { 91 | type Inner: Copy + Default + From + TryInto + Serialize + DeserializeOwned; 92 | } 93 | 94 | impl NonZeroNumber for NonZero { 95 | type Inner = u32; 96 | } 97 | 98 | #[expect(clippy::ref_option)] 99 | pub(crate) fn serialize(num: &Option, serializer: S) -> Result 100 | where 101 | S: Serializer, 102 | N: NonZeroNumber, 103 | { 104 | match *num { 105 | Some(num) => >::from(num), 106 | None => Default::default(), 107 | } 108 | .serialize(serializer) 109 | } 110 | 111 | pub(crate) fn deserialize<'de, D: Deserializer<'de>, N: NonZeroNumber>( 112 | deserializer: D, 113 | ) -> Result, D::Error> { 114 | let num = ::deserialize(deserializer)?; 115 | Ok(num.try_into().ok()) 116 | } 117 | } 118 | 119 | mod nullable_datetime { 120 | use jiff::Timestamp; 121 | use serde::{Deserialize, Deserializer}; 122 | 123 | // 0001-01-01T00:00:00Z 124 | const GOLANG_ZERO: Timestamp = Timestamp::constant(-62_135_596_800, 0); 125 | 126 | pub(crate) fn deserialize<'de, D: Deserializer<'de>>( 127 
| deserializer: D, 128 | ) -> Result, D::Error> { 129 | let datetime = ::deserialize(deserializer)?; 130 | Ok(if datetime == GOLANG_ZERO { 131 | None 132 | } else { 133 | Some(datetime) 134 | }) 135 | } 136 | } 137 | 138 | mod duration { 139 | use std::time::Duration; 140 | 141 | use serde::{Deserialize, Deserializer, Serialize, Serializer}; 142 | 143 | pub(crate) fn serialize(duration: &Duration, serializer: S) -> Result 144 | where 145 | S: Serializer, 146 | { 147 | duration.as_nanos().serialize(serializer) 148 | } 149 | 150 | pub(crate) fn deserialize<'de, D: Deserializer<'de>>( 151 | deserializer: D, 152 | ) -> Result { 153 | Ok(Duration::from_nanos(u64::deserialize(deserializer)?)) 154 | } 155 | } 156 | 157 | mod duration_vec { 158 | use std::time::Duration; 159 | 160 | use serde::{Deserialize, Deserializer, Serializer}; 161 | 162 | #[expect( 163 | clippy::ptr_arg, 164 | reason = "this must follow the signature expected by serde" 165 | )] 166 | pub(crate) fn serialize(durations: &Vec, serializer: S) -> Result 167 | where 168 | S: Serializer, 169 | { 170 | serializer.collect_seq(durations.iter().map(std::time::Duration::as_nanos)) 171 | } 172 | 173 | pub(crate) fn deserialize<'de, D: Deserializer<'de>>( 174 | deserializer: D, 175 | ) -> Result, D::Error> { 176 | let durations = as Deserialize>::deserialize(deserializer)?; 177 | Ok(durations.into_iter().map(Duration::from_nanos).collect()) 178 | } 179 | } 180 | 181 | mod compression { 182 | #[derive(Debug, Serialize, Deserialize)] 183 | #[serde(rename_all = "snake_case")] 184 | enum CompressionInner { 185 | None, 186 | S2, 187 | } 188 | 189 | use serde::{Deserialize, Deserializer, Serialize, Serializer}; 190 | 191 | use super::Compression; 192 | 193 | #[expect(clippy::ref_option)] 194 | pub(crate) fn serialize( 195 | compression: &Option, 196 | serializer: S, 197 | ) -> Result 198 | where 199 | S: Serializer, 200 | { 201 | match compression { 202 | None => CompressionInner::None, 203 | Some(Compression::S2) => 
CompressionInner::S2, 204 | } 205 | .serialize(serializer) 206 | } 207 | 208 | pub(crate) fn deserialize<'de, D: Deserializer<'de>>( 209 | deserializer: D, 210 | ) -> Result, D::Error> { 211 | Ok(match CompressionInner::deserialize(deserializer)? { 212 | CompressionInner::None => None, 213 | CompressionInner::S2 => Some(Compression::S2), 214 | }) 215 | } 216 | } 217 | 218 | mod opposite_bool { 219 | use std::ops::Not; 220 | 221 | use serde::{Deserialize, Deserializer, Serialize, Serializer}; 222 | 223 | #[expect( 224 | clippy::trivially_copy_pass_by_ref, 225 | reason = "this must follow the signature expected by serde" 226 | )] 227 | pub(crate) fn serialize(val: &bool, serializer: S) -> Result 228 | where 229 | S: Serializer, 230 | { 231 | val.not().serialize(serializer) 232 | } 233 | 234 | pub(crate) fn deserialize<'de, D: Deserializer<'de>>( 235 | deserializer: D, 236 | ) -> Result { 237 | bool::deserialize(deserializer).map(Not::not) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /watermelon/src/client/jetstream/resources/stream.rs: -------------------------------------------------------------------------------- 1 | use std::{num::NonZero, time::Duration}; 2 | 3 | use jiff::Timestamp; 4 | use serde::{Deserialize, Serialize}; 5 | use watermelon_proto::Subject; 6 | 7 | use super::{compression, duration, nullable_datetime, nullable_number, opposite_bool}; 8 | 9 | /// A Jetstream stream 10 | #[derive(Debug, Deserialize)] 11 | pub struct Stream { 12 | pub config: StreamConfig, 13 | #[serde(rename = "created")] 14 | pub created_at: Timestamp, 15 | // TODO: `cluster` 16 | } 17 | 18 | /// The state of the stream 19 | #[derive(Debug, Deserialize)] 20 | pub struct StreamState { 21 | pub messages: u64, 22 | pub bytes: u64, 23 | pub first_sequence: u64, 24 | #[serde(with = "nullable_datetime", rename = "first_ts")] 25 | pub first_sequence_timestamp: Option, 26 | pub last_sequence: u64, 27 | #[serde(with = 
"nullable_datetime", rename = "last_ts")] 28 | pub last_sequence_timestamp: Option, 29 | pub consumer_count: u32, 30 | } 31 | 32 | /// A Jetstream stream configuration 33 | #[derive(Debug, Serialize, Deserialize)] 34 | #[expect( 35 | clippy::struct_excessive_bools, 36 | reason = "it is the actual config of a Jetstream" 37 | )] 38 | pub struct StreamConfig { 39 | pub name: String, 40 | pub subjects: Vec, 41 | #[serde(with = "nullable_number")] 42 | pub max_consumers: Option, 43 | #[serde(with = "nullable_number", rename = "max_msgs")] 44 | pub max_messages: Option, 45 | #[serde(with = "nullable_number")] 46 | pub max_bytes: Option, 47 | #[serde(with = "duration")] 48 | pub max_age: Duration, 49 | #[serde(with = "nullable_number", rename = "max_msgs_per_subject")] 50 | pub max_messages_per_subject: Option, 51 | #[serde(with = "nullable_number", rename = "max_msg_size")] 52 | pub max_message_size: Option, 53 | #[serde(rename = "discard")] 54 | pub discard_policy: DiscardPolicy, 55 | pub storage: Storage, 56 | #[serde(rename = "num_replicas")] 57 | pub replicas: NonZero, 58 | #[serde(with = "duration")] 59 | pub duplicate_window: Duration, 60 | #[serde(with = "compression")] 61 | pub compression: Option, 62 | pub allow_direct: bool, 63 | pub mirror_direct: bool, 64 | pub sealed: bool, 65 | #[serde(with = "opposite_bool", rename = "deny_delete")] 66 | pub allow_delete: bool, 67 | #[serde(with = "opposite_bool", rename = "deny_purge")] 68 | pub allow_purge: bool, 69 | pub allow_rollup_hdrs: bool, 70 | // TODO: `consumer_limits` https://github.com/nats-io/nats-server/blob/e25d973a8f389ce3aa415e4bcdfba1f7d0834f7f/server/stream.go#L99 71 | } 72 | 73 | /// A streams retention policy 74 | #[derive(Debug, Serialize, Deserialize)] 75 | #[serde(rename_all = "snake_case")] 76 | pub enum RetentionPolicy { 77 | Limits, 78 | Interest, 79 | WorkQueue, 80 | } 81 | 82 | /// A streams discard policy 83 | #[derive(Debug, Serialize, Deserialize)] 84 | #[serde(rename_all = "snake_case")] 
85 | pub enum DiscardPolicy { 86 | Old, 87 | New, 88 | } 89 | 90 | /// Whether the disk is stored on disk or in memory 91 | #[derive(Debug, Serialize, Deserialize)] 92 | #[serde(rename_all = "snake_case")] 93 | pub enum Storage { 94 | File, 95 | Memory, 96 | } 97 | 98 | /// The compression algorithm used by a stream 99 | #[derive(Debug)] 100 | pub enum Compression { 101 | S2, 102 | } 103 | -------------------------------------------------------------------------------- /watermelon/src/client/quick_info.rs: -------------------------------------------------------------------------------- 1 | use crate::util::atomic::{AtomicU32, Ordering}; 2 | 3 | const IS_CONNECTED: u32 = 1 << 0; 4 | #[cfg(feature = "non-standard-zstd")] 5 | const IS_ZSTD_COMPRESSED: u32 = 1 << 1; 6 | const IS_LAMEDUCK: u32 = 1 << 2; 7 | const IS_FAILED_UNSUBSCRIBE: u32 = 1 << 31; 8 | 9 | #[derive(Debug)] 10 | pub(crate) struct RawQuickInfo(AtomicU32); 11 | 12 | /// Client information 13 | /// 14 | /// Obtained from [`Client::quick_info`]. 
15 | /// 16 | /// [`Client::quick_info`]: crate::core::Client::quick_info 17 | #[derive(Debug, Copy, Clone, PartialEq, Eq)] 18 | #[cfg_attr(feature = "non-standard-zstd", expect(clippy::struct_excessive_bools))] 19 | pub struct QuickInfo { 20 | pub(crate) is_connected: bool, 21 | #[cfg(feature = "non-standard-zstd")] 22 | pub(crate) is_zstd_compressed: bool, 23 | pub(crate) is_lameduck: bool, 24 | pub(crate) is_failed_unsubscribe: bool, 25 | } 26 | 27 | impl RawQuickInfo { 28 | pub(crate) fn new() -> Self { 29 | Self(AtomicU32::new( 30 | QuickInfo { 31 | is_connected: false, 32 | #[cfg(feature = "non-standard-zstd")] 33 | is_zstd_compressed: false, 34 | is_lameduck: false, 35 | is_failed_unsubscribe: false, 36 | } 37 | .encode(), 38 | )) 39 | } 40 | 41 | pub(crate) fn get(&self) -> QuickInfo { 42 | QuickInfo::decode(self.0.load(Ordering::Acquire)) 43 | } 44 | 45 | pub(crate) fn store(&self, mut f: F) 46 | where 47 | F: FnMut(QuickInfo) -> QuickInfo, 48 | { 49 | let prev_params = self.get(); 50 | self.0.store(f(prev_params).encode(), Ordering::Release); 51 | } 52 | 53 | pub(crate) fn store_is_connected(&self, val: bool) { 54 | self.store_bit(IS_CONNECTED, val); 55 | } 56 | pub(crate) fn store_is_lameduck(&self, val: bool) { 57 | self.store_bit(IS_LAMEDUCK, val); 58 | } 59 | pub(crate) fn store_is_failed_unsubscribe(&self, val: bool) { 60 | self.store_bit(IS_FAILED_UNSUBSCRIBE, val); 61 | } 62 | 63 | #[expect( 64 | clippy::inline_always, 65 | reason = "we want this to be inlined inside the store_* functions" 66 | )] 67 | #[inline(always)] 68 | fn store_bit(&self, mask: u32, val: bool) { 69 | debug_assert_eq!(mask.count_ones(), 1); 70 | 71 | if val { 72 | self.0.fetch_or(mask, Ordering::AcqRel); 73 | } else { 74 | self.0.fetch_and(!mask, Ordering::AcqRel); 75 | } 76 | } 77 | } 78 | 79 | impl QuickInfo { 80 | /// Returns `true` if the client is currently connected to the NATS server 81 | #[must_use] 82 | pub fn is_connected(&self) -> bool { 83 | self.is_connected 84 | 
} 85 | 86 | /// Returns `true` if the client connection is zstd compressed 87 | #[cfg(feature = "non-standard-zstd")] 88 | #[must_use] 89 | pub fn is_zstd_compressed(&self) -> bool { 90 | self.is_zstd_compressed 91 | } 92 | 93 | /// Returns `true` if the client is currently in Lame Duck Mode 94 | #[must_use] 95 | pub fn is_lameduck(&self) -> bool { 96 | self.is_lameduck 97 | } 98 | 99 | fn encode(self) -> u32 { 100 | let mut val = 0; 101 | 102 | if self.is_connected { 103 | val |= IS_CONNECTED; 104 | } 105 | 106 | #[cfg(feature = "non-standard-zstd")] 107 | if self.is_zstd_compressed { 108 | val |= IS_ZSTD_COMPRESSED; 109 | } 110 | 111 | if self.is_lameduck { 112 | val |= IS_LAMEDUCK; 113 | } 114 | 115 | if self.is_failed_unsubscribe { 116 | val |= IS_FAILED_UNSUBSCRIBE; 117 | } 118 | 119 | val 120 | } 121 | 122 | fn decode(val: u32) -> Self { 123 | Self { 124 | is_connected: (val & IS_CONNECTED) != 0, 125 | #[cfg(feature = "non-standard-zstd")] 126 | is_zstd_compressed: (val & IS_ZSTD_COMPRESSED) != 0, 127 | is_lameduck: (val & IS_LAMEDUCK) != 0, 128 | is_failed_unsubscribe: (val & IS_FAILED_UNSUBSCRIBE) != 0, 129 | } 130 | } 131 | } 132 | 133 | #[cfg(test)] 134 | mod tests { 135 | use super::{QuickInfo, RawQuickInfo}; 136 | 137 | #[test] 138 | fn set_get() { 139 | let quick_info = RawQuickInfo::new(); 140 | let mut expected = QuickInfo { 141 | is_connected: false, 142 | #[cfg(feature = "non-standard-zstd")] 143 | is_zstd_compressed: false, 144 | is_lameduck: false, 145 | is_failed_unsubscribe: false, 146 | }; 147 | 148 | for is_connected in [false, true] { 149 | quick_info.store_is_connected(is_connected); 150 | expected.is_connected = is_connected; 151 | 152 | for is_lameduck in [false, true] { 153 | quick_info.store_is_lameduck(is_lameduck); 154 | expected.is_lameduck = is_lameduck; 155 | 156 | for is_failed_unsubscribe in [false, true] { 157 | quick_info.store_is_failed_unsubscribe(is_failed_unsubscribe); 158 | expected.is_failed_unsubscribe = 
is_failed_unsubscribe; 159 | 160 | assert_eq!(expected, quick_info.get()); 161 | } 162 | } 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /watermelon/src/client/tests.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::BTreeSet, sync::Arc}; 2 | 3 | use arc_swap::ArcSwapOption; 4 | use tokio::sync::mpsc; 5 | use watermelon_proto::{ServerInfo, Subject}; 6 | 7 | use crate::{ 8 | client::{RawQuickInfo, create_inbox_subject}, 9 | handler::HandlerCommand, 10 | }; 11 | 12 | #[derive(Debug)] 13 | pub(crate) struct TestHandler { 14 | pub(crate) receiver: mpsc::Receiver, 15 | pub(crate) _info: Arc>, 16 | pub(crate) quick_info: Arc, 17 | } 18 | 19 | #[test] 20 | fn unique_create_inbox_subject() { 21 | const ITERATIONS: usize = if cfg!(miri) { 100 } else { 100_000 }; 22 | 23 | let prefix = Subject::from_static("abcd"); 24 | let subjects = (0..ITERATIONS) 25 | .map(|_| create_inbox_subject(&prefix)) 26 | .collect::>(); 27 | assert_eq!(subjects.len(), ITERATIONS); 28 | } 29 | -------------------------------------------------------------------------------- /watermelon/src/handler/delayed.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::Future, 3 | mem, 4 | pin::Pin, 5 | task::{Context, Poll}, 6 | time::Duration, 7 | }; 8 | 9 | use tokio::time::{Instant, Sleep, sleep}; 10 | 11 | /// A delay mechanism that enforces a minimum duration between operations. 12 | /// 13 | /// `Delayed` ensures that successive calls to `poll_can_proceed` are separated 14 | /// by at least the specified duration. 
15 | #[derive(Debug)] 16 | pub(super) struct Delayed { 17 | inner: Option, 18 | } 19 | 20 | #[derive(Debug)] 21 | struct DelayedInner { 22 | // INVARIANT: `duration != Duration::ZERO` 23 | duration: Duration, 24 | delay: Pin>, 25 | delay_consumed: bool, 26 | } 27 | 28 | impl Delayed { 29 | /// Create a new `Delayed` with the specified duration. 30 | /// 31 | /// If `duration` is zero, the delay is effectively disabled and all 32 | /// calls to `poll_can_proceed` will return `Poll::Ready(())` immediately. 33 | pub(super) fn new(duration: Duration) -> Self { 34 | let inner = if duration.is_zero() { 35 | None 36 | } else { 37 | Some(DelayedInner { 38 | duration, 39 | delay: Box::pin(sleep(duration)), 40 | delay_consumed: true, 41 | }) 42 | }; 43 | 44 | Self { inner } 45 | } 46 | 47 | /// Poll whether the operation can proceed based on the configured delay. 48 | /// 49 | /// This method implements a rate-limiting mechanism: 50 | /// 51 | /// - On first call or after a delay has elapsed, returns `Poll::Pending` 52 | /// - If called again before the delay duration has passed, returns `Poll::Pending` 53 | /// - Automatically resets the delay timer when the delay is polled again *after* it 54 | /// has previously completed. 55 | /// 56 | /// When configured with [`Duration::ZERO`], this method always returns `Poll::Ready(())`. 
// ---- /watermelon/src/handler/delayed.rs (continued: tail of `impl Delayed`) ----
    /// Poll whether the caller is allowed to proceed.
    ///
    /// Returns `Poll::Ready(())` immediately when `inner` is `None` (judging by
    /// `zero_interval_always_ready`, a zero interval behaves this way — presumably
    /// `new` leaves `inner` empty in that case; confirm against `Delayed::new`).
    ///
    /// Otherwise the caller is held back until `inner.delay` fires. Each time the
    /// delay is consumed, it is re-armed to `inner.duration` on the *next* poll.
    pub(super) fn poll_can_proceed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        let Some(inner) = &mut self.inner else {
            return Poll::Ready(());
        };

        // Re-arm lazily, on the first poll after the delay was consumed. Time
        // that passes between calls therefore does not count toward the delay
        // (this is what `delay_behaviour` checks).
        if mem::take(&mut inner.delay_consumed) {
            inner.delay.as_mut().reset(Instant::now() + inner.duration);
        }

        if inner.delay.as_mut().poll(cx).is_ready() {
            inner.delay_consumed = true;
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        future,
        task::{Context, Waker},
        time::Duration,
    };

    use claims::assert_ready;
    use tokio::time::{Instant, sleep};

    use super::Delayed;

    #[test]
    fn zero_interval_always_ready() {
        let mut delayed = Delayed::new(Duration::ZERO);

        for _ in 0..100 {
            let mut cx = Context::from_waker(Waker::noop());
            assert_ready!(delayed.poll_can_proceed(&mut cx));
        }
    }

    #[tokio::test(start_paused = true)]
    async fn delay_behaviour() {
        const INTERVAL: Duration = Duration::from_millis(250);

        let mut delayed = Delayed::new(INTERVAL);
        let before = Instant::now();
        future::poll_fn(|cx| delayed.poll_can_proceed(cx)).await;
        assert_eq!(before.elapsed(), INTERVAL);

        sleep(INTERVAL * 3).await;

        let before = Instant::now();
        future::poll_fn(|cx| delayed.poll_can_proceed(cx)).await;
        assert_eq!(before.elapsed(), INTERVAL);
    }
}

// ---- /watermelon/src/handler/pinger.rs ----
use std::{pin::Pin, task::Context, time::Duration};

use tokio::time::{Instant, Sleep, sleep};

/// Time after creation, the last `reset`, or the last PING at which the next PING is due
const INTERVAL: Duration = Duration::from_secs(10);
/// Number of unacknowledged PINGs at which `Pinger::poll` stops sending and
/// reports `PingOutcome::TooManyInFlightPings` instead
const MAX_PENDING_PINGS: u8 = 2;

#[derive(Debug)]
pub(super) struct Pinger {
    /// Timer that fires when the next PING is due
    interval: Pin<Box<Sleep>>,
    /// Number of unacknowledged pings
    pending_pings: u8,
}

#[derive(Debug)]
pub(super) enum PingOutcome {
    /// Either the PING was sent or no ping was needed
    Ok,
    /// Too many unacknowledged pings are in flight
    TooManyInFlightPings,
}

impl Pinger {
    /// Create a new pinger whose first PING is due after `INTERVAL`
    pub(super) fn new() -> Self {
        Self {
            interval: Box::pin(sleep(INTERVAL)),
            pending_pings: 0,
        }
    }

    /// Reset the ping timer to trigger after the full interval
    pub(super) fn reset(&mut self) {
        self.interval.as_mut().reset(Instant::now() + INTERVAL);
    }

    /// Handle a received PONG response
    ///
    /// The counter saturates at zero, so a PONG with no matching PING cannot underflow it.
    pub(super) fn handle_pong(&mut self) {
        self.pending_pings = self.pending_pings.saturating_sub(1);
    }

    /// Poll for ping readiness and send a PING if the interval has elapsed
    ///
    /// `send_ping` is called at most once, and only when the timer has fired and
    /// fewer than `MAX_PENDING_PINGS` PINGs are still unacknowledged.
    pub(super) fn poll(&mut self, cx: &mut Context<'_>, send_ping: impl FnOnce()) -> PingOutcome {
        if self.interval.as_mut().poll(cx).is_pending() {
            PingOutcome::Ok
        } else {
            self.do_ping(cx, send_ping)
        }
    }

    /// Slow path of `poll`, taken only once the timer has fired
    #[cold]
    fn do_ping(&mut self, cx: &mut Context<'_>, send_ping: impl FnOnce()) -> PingOutcome {
        if self.pending_pings >= MAX_PENDING_PINGS {
            return PingOutcome::TooManyInFlightPings;
        }

        send_ping();
        self.pending_pings += 1;

        // Re-arm the timer and poll it so that `cx`'s waker is registered for
        // the next PING; retry if the freshly reset deadline already reports ready.
        loop {
            self.reset();
            if self.interval.as_mut().poll(cx).is_pending() {
                break;
            }
        }

        PingOutcome::Ok
    }
}

#[cfg(test)]
mod tests {
    use std::{
        task::{Context, Waker},
        time::Duration,
    };

    use claims::assert_matches;
    use tokio::time::advance;

    use crate::handler::PingOutcome;

    use super::Pinger;

    #[tokio::test(start_paused = true)]
    async fn e2e_ping() {
        let mut cx = Context::from_waker(Waker::noop());

        let mut pinger = Pinger::new();

        // does nothing initially
        assert_matches!(pinger.poll(&mut cx, || unreachable!()), PingOutcome::Ok);
        assert_eq!(0, pinger.pending_pings);

        // pings
        advance(Duration::from_secs(10)).await;
        let mut has_pinged = false;
        assert_matches!(
            pinger.poll(&mut cx, || { has_pinged = true }),
            PingOutcome::Ok
        );
        assert!(has_pinged);
        assert_eq!(1, pinger.pending_pings);

        // does nothing again
        assert_matches!(pinger.poll(&mut cx, || unreachable!()), PingOutcome::Ok);
        assert_eq!(1, pinger.pending_pings);

        // pings again
        advance(Duration::from_secs(10)).await;
        let mut has_pinged = false;
        assert_matches!(
            pinger.poll(&mut cx, || { has_pinged = true }),
            PingOutcome::Ok
        );
        assert!(has_pinged);
        assert_eq!(2, pinger.pending_pings);

        // receive PONG
        pinger.handle_pong();
        assert_eq!(1, pinger.pending_pings);

        // let some time go by and reset
        advance(Duration::from_secs(5)).await;
        pinger.reset();

        // make sure we don't ping again given that not enough time has elapsed
        // since the reset
        advance(Duration::from_secs(5)).await;
        assert_matches!(pinger.poll(&mut cx, || unreachable!()), PingOutcome::Ok);
        assert_eq!(1, pinger.pending_pings);

        // pings again
        advance(Duration::from_secs(5)).await;
        let mut has_pinged = false;
        assert_matches!(
            pinger.poll(&mut cx, || { has_pinged = true }),
            PingOutcome::Ok
        );
        assert!(has_pinged);
        assert_eq!(2, pinger.pending_pings);

        // reaches too many in flight pings
        advance(Duration::from_secs(10)).await;
        assert_matches!(
            pinger.poll(&mut cx, || unreachable!()),
            PingOutcome::TooManyInFlightPings
        );
    }
}

// ---- /watermelon/src/lib.rs ----
-------------------------------------------------------------------------------- 1 | #![forbid(unsafe_code)] 2 | 3 | pub use watermelon_proto as proto; 4 | 5 | mod client; 6 | mod handler; 7 | mod multiplexed_subscription; 8 | mod subscription; 9 | #[cfg(test)] 10 | pub(crate) mod tests; 11 | mod util; 12 | 13 | pub mod core { 14 | //! NATS Core functionality implementation 15 | 16 | pub use crate::client::{Client, ClientBuilder, Echo, QuickInfo}; 17 | pub(crate) use crate::multiplexed_subscription::MultiplexedSubscription; 18 | pub use crate::subscription::Subscription; 19 | pub use watermelon_mini::AuthenticationMethod; 20 | 21 | pub mod publish { 22 | //! Utilities for publishing messages 23 | 24 | pub use crate::client::{ 25 | ClientPublish, DoClientPublish, DoOwnedClientPublish, OwnedClientPublish, Publish, 26 | PublishBuilder, 27 | }; 28 | } 29 | 30 | pub mod request { 31 | //! Utilities for publishing messages and awaiting for a response 32 | 33 | pub use crate::client::{ 34 | ClientRequest, DoClientRequest, DoOwnedClientRequest, OwnedClientRequest, Request, 35 | RequestBuilder, ResponseFut, 36 | }; 37 | } 38 | 39 | pub mod error { 40 | //! NATS Core specific errors 41 | 42 | pub use crate::{ 43 | client::{ClientClosedError, ResponseError, TryCommandError}, 44 | handler::ConnectHandlerError, 45 | }; 46 | pub use watermelon_mini::ConnectError; 47 | } 48 | } 49 | 50 | pub mod jetstream { 51 | //! NATS Jetstream functionality implementation 52 | //! 53 | //! Relies on NATS Core to communicate with the NATS server 54 | 55 | pub use crate::client::{ 56 | AckPolicy, Compression, Consumer, ConsumerBatch, ConsumerConfig, ConsumerDurability, 57 | ConsumerSpecificConfig, ConsumerStorage, ConsumerStream, Consumers, DeliverPolicy, 58 | DiscardPolicy, JetstreamClient, ReplayPolicy, RetentionPolicy, Storage, Stream, 59 | StreamConfig, StreamState, Streams, 60 | }; 61 | 62 | pub mod error { 63 | //! 
NATS Jetstream specific errors 64 | 65 | pub use crate::client::{ 66 | ConsumerStreamError, JetstreamApiError, JetstreamError, JetstreamErrorCode, 67 | }; 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /watermelon/src/multiplexed_subscription.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::Future, 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | use tokio::sync::oneshot; 8 | use watermelon_proto::{ServerMessage, Subject}; 9 | 10 | use crate::{client::ClientClosedError, core::Client}; 11 | 12 | #[derive(Debug)] 13 | pub(crate) struct MultiplexedSubscription { 14 | subscription: Option, 15 | } 16 | 17 | #[derive(Debug)] 18 | struct Inner { 19 | reply_subject: Subject, 20 | receiver: oneshot::Receiver, 21 | client: Client, 22 | } 23 | 24 | impl MultiplexedSubscription { 25 | pub(crate) fn new( 26 | reply_subject: Subject, 27 | receiver: oneshot::Receiver, 28 | client: Client, 29 | ) -> Self { 30 | Self { 31 | subscription: Some(Inner { 32 | reply_subject, 33 | receiver, 34 | client, 35 | }), 36 | } 37 | } 38 | } 39 | 40 | impl Future for MultiplexedSubscription { 41 | type Output = Result; 42 | 43 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 44 | let subscription = self 45 | .subscription 46 | .as_mut() 47 | .expect("MultiplexedSubscription polled after completing"); 48 | 49 | match Pin::new(&mut subscription.receiver).poll(cx) { 50 | Poll::Pending => Poll::Pending, 51 | Poll::Ready(result) => { 52 | self.subscription = None; 53 | Poll::Ready(result.map_err(|_| ClientClosedError)) 54 | } 55 | } 56 | } 57 | } 58 | 59 | impl Drop for MultiplexedSubscription { 60 | fn drop(&mut self) { 61 | let Some(subscription) = self.subscription.take() else { 62 | return; 63 | }; 64 | 65 | subscription 66 | .client 67 | .lazy_unsubscribe_multiplexed(subscription.reply_subject); 68 | } 69 | } 70 | 
-------------------------------------------------------------------------------- /watermelon/src/tests.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | sync::Arc, 3 | task::{Wake, Waker}, 4 | }; 5 | 6 | use crate::util::atomic::{AtomicUsize, Ordering}; 7 | 8 | #[derive(Debug)] 9 | pub(crate) struct FlagWaker(AtomicUsize); 10 | 11 | impl FlagWaker { 12 | pub(crate) fn new() -> (Arc, Waker) { 13 | let this = Arc::new(Self(AtomicUsize::new(0))); 14 | let waker = Waker::from(Arc::clone(&this)); 15 | (this, waker) 16 | } 17 | 18 | pub(crate) fn wakes(&self) -> usize { 19 | self.0.load(Ordering::Acquire) 20 | } 21 | } 22 | 23 | impl Wake for FlagWaker { 24 | fn wake(self: Arc) { 25 | self.0.fetch_add(1, Ordering::AcqRel); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /watermelon/src/util/atomic.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "portable-atomic")] 2 | pub(crate) use portable_atomic::*; 3 | #[cfg(not(feature = "portable-atomic"))] 4 | pub(crate) use std::sync::atomic::*; 5 | -------------------------------------------------------------------------------- /watermelon/src/util/future.rs: -------------------------------------------------------------------------------- 1 | use std::{future::Future, pin::Pin}; 2 | 3 | /// An alternative to [`futures_core::future::BoxFuture`] that is also `Sync` 4 | pub(crate) type BoxFuture<'a, T> = Pin + Send + Sync + 'a>>; 5 | -------------------------------------------------------------------------------- /watermelon/src/util/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) use self::future::BoxFuture; 2 | 3 | pub(crate) mod atomic; 4 | mod future; 5 | --------------------------------------------------------------------------------