├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── examples ├── monitor.rs ├── psubscribe.rs ├── realistic.rs └── subscribe.rs └── src ├── client ├── builder.rs ├── connect.rs ├── mod.rs ├── paired.rs └── pubsub │ ├── inner.rs │ └── mod.rs ├── error.rs ├── lib.rs ├── reconnect.rs └── resp.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Continuous integration 4 | 5 | jobs: 6 | ci: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | rust: 11 | - stable 12 | - beta 13 | - nightly 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - uses: actions-rs/toolchain@v1 19 | with: 20 | profile: minimal 21 | toolchain: ${{ matrix.rust }} 22 | override: true 23 | components: rustfmt, clippy 24 | 25 | - uses: supercharge/redis-github-action@1.1.0 26 | with: 27 | redis-version: 6 28 | 29 | - uses: actions-rs/cargo@v1 30 | with: 31 | command: build 32 | 33 | - uses: actions-rs/cargo@v1 34 | with: 35 | command: test 36 | 37 | - uses: actions-rs/cargo@v1 38 | with: 39 | command: fmt 40 | args: --all -- --check 41 | 42 | - uses: actions-rs/cargo@v1 43 | with: 44 | command: clippy -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | **/*.rs.bk 3 | Cargo.lock 4 | dump.rdb 5 | .vscode/ 6 | .idea/ -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "redis-async" 3 | version = "0.17.2" 4 | authors = ["Ben Ashford "] 5 | license = "MIT OR Apache-2.0" 6 | readme = "README.md" 7 | description = "An asynchronous futures based Redis client for Rust using Tokio" 8 | repository = "https://github.com/benashford/redis-async-rs" 
9 | keywords = ["redis", "tokio"] 10 | edition = "2021" 11 | 12 | [dependencies] 13 | bytes = "1.0" 14 | futures-channel = "^0.3.7" 15 | futures-sink = "^0.3.7" 16 | futures-util = { version = "^0.3.7", features = ["sink"] } 17 | log = "^0.4.11" 18 | native-tls = { version = "0.2", optional = true } 19 | pin-project = "1.0" 20 | socket2 = { version = "0.5", features = ["all"] } 21 | tokio = { version = "1.0", features = ["rt", "net", "time"] } 22 | tokio-native-tls = { version = "0.3.0", optional = true } 23 | tokio-rustls = { version = "0.26", optional = true } 24 | tokio-util = { version = "0.7", features = ["codec"] } 25 | webpki-roots = { version = "0.26", optional = true } 26 | 27 | [features] 28 | default = [] 29 | tls = [] 30 | with-rustls = ["tokio-rustls", "tls", "webpki-roots"] 31 | with-native-tls = ["native-tls", "tokio-native-tls", "tls"] 32 | 33 | [dev-dependencies] 34 | env_logger = "0.11" 35 | futures = "^0.3.7" 36 | tokio = { version = "1.0", features = ["full"] } 37 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright 2016 rust-lazysort developers 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # redis-async 2 | 3 | [![](http://meritbadge.herokuapp.com/redis-async)](https://crates.io/crates/redis-async) 4 | [![](https://img.shields.io/crates/d/redis-async.svg)](https://crates.io/crates/redis-async) 5 | [![](https://img.shields.io/crates/dv/redis-async.svg)](https://crates.io/crates/redis-async) 6 | [![](https://docs.rs/redis-async/badge.svg)](https://docs.rs/redis-async/) 7 | 8 | Using Tokio and Rust's futures to create an asynchronous Redis client. [Documentation](https://docs.rs/redis-async/) 9 | 10 | ## Releases 11 | 12 | The API is currently low-level and still subject to change. 13 | 14 | Initially I'm focussing on single-server Redis instances; another long-term goal is to support Redis clusters. This would make the implementation more complex, as it requires routing and the handling of error conditions such as `MOVED`. 15 | 16 | ### Recent changes 17 | 18 | Version 0.14 introduces experimental TLS support: use feature flag `with-rustls` for Rustls support, or `with-native-tls` for native TLS support. There are other minor changes to the public API to enable this; in particular, separate `host` and `port` arguments are required rather than a single `addr` argument. 19 | 20 | ## Other clients 21 | 22 | When starting this library there weren't any other Redis clients that used Tokio. However, the current situation is more competitive: 23 | 24 | - Redis-RS - https://github.com/mitsuhiko/redis-rs - the oldest Redis client for Rust now supports asynchronous operations using Tokio. 25 | - Fred - https://github.com/azuqua/fred.rs - this also supports Redis clusters. 26 | 27 | ## Usage 28 | 29 | There are three functions in `redis_async::client` which provide this functionality: one is a low-level interface (`connect`), the second is a high-level interface (`paired_connect`), and the third is dedicated to PUBSUB (`pubsub_connect`).
30 | 31 | ### Low-level interface 32 | 33 | The function `client::connect` returns a future that resolves to a connection which implements both `Sink` and `Stream`. These work independently of one another to allow pipelining. It is the responsibility of the caller to match responses to requests. It is also the responsibility of the caller to convert application data into instances of `resp::RespValue` and back (there are conversion traits available for common types). 34 | 35 | This is a very low-level API compared to most Redis clients, but it is intentionally so, for two reasons: 1) it is the lowest common denominator for a fully functional Redis client (i.e. one that is able to support all types of requests, including those that block and have streaming responses), and 2) it results in clean `Sink`s and `Stream`s which are composable with other Tokio-based libraries. 36 | 37 | This low-level connection will be permanently closed if the connection with the Redis server is lost; it is the responsibility of the caller to handle this and re-connect if necessary. 38 | 39 | For most practical purposes this low-level interface will not be needed; the only likely exception is the [`MONITOR`](https://redis.io/commands/monitor) command. 40 | 41 | #### Example 42 | 43 | An example of this low-level interface is in [`examples/monitor.rs`](examples/monitor.rs). This can be run with `cargo run --example monitor`; it will run until it is `Ctrl-C`'d and will show every command run against the Redis server. 44 | 45 | ### High-level interface 46 | 47 | `client::paired_connect` is used for most Redis commands (those for which one command returns one response; it's not suitable for PUBSUB, `MONITOR` or other similar commands). It allows a Redis command to be sent and a future returned for each command. 48 | 49 | Commands will be sent in the order that `send` is called, regardless of how the future is realised.
This is to allow us to take advantage of Redis's features by implicitly pipelining commands where appropriate. One side-effect of this is that for many commands, e.g. `SET`, we don't need to realise the future at all, as it can be assumed to be fire-and-forget; however, the final future of the final command does need to be realised (at least) to ensure that the correct behaviour is observed. 50 | 51 | In the event of a failure of communication to the Redis server, this connection will attempt to reconnect. Commands will not be automatically re-tried, however; it is for calling code to handle this and decide whether a particular command should be retried or not. 52 | 53 | #### Example 54 | 55 | See [`examples/realistic.rs`](examples/realistic.rs) for an example using completely artificial test data; it is realistic in the sense that it simulates a real-world pattern where certain operations depend on the results of others. 56 | 57 | This shows that the code can be written in a straight-line fashion: iterate through the outer loop and, for each item, call `INCR` on a counter and use the result to write the data to a unique key. When run, however, the various calls will be pipelined. 58 | 59 | To observe this, a tool like ngrep can be used to monitor the data sent to Redis; running `cargo run --release --example realistic` (the `--release` flag needs to be set for the buffers to fill faster than packets can be sent to the Redis server) shows the data flowing: 60 | 61 | ``` 62 | interface: lo0 (127.0.0.0/255.0.0.0) 63 | filter: (ip or ip6) and ( port 6379 ) 64 | ##### 65 | T 127.0.0.1:61112 -> 127.0.0.1:6379 [AP] 66 | *2..$4..INCR..$18..realistic_test_ctr..*2..$4..INCR..$18..realistic_test_ctr..*2..$4..INCR..$18.. 67 | realistic_test_ctr..*2..$4..INCR..$18..realistic_test_ctr..*2..$4..INCR..$18..realistic_test_ctr. 68 | .*2..$4..INCR..$18..realistic_test_ctr..*2..$4..INCR..$18..realistic_test_ctr..*2..$4..INCR..$18.
69 | .realistic_test_ctr..*2..$4..INCR..$18..realistic_test_ctr..*2..$4..INCR..$18..realistic_test_ctr 70 | .. 71 | ## 72 | T 127.0.0.1:6379 -> 127.0.0.1:61112 [AP] 73 | :1..:2..:3..:4..:5..:6..:7..:8..:9..:10.. 74 | ## 75 | T 127.0.0.1:61112 -> 127.0.0.1:6379 [AP] 76 | *3..$3..SET..$4..rt_1..$1..0..*3..$3..SET..$1..0..$4..rt_1..*3..$3..SET..$4..rt_2..$1..1..*3..$3. 77 | .SET..$1..1..$4..rt_2..*3..$3..SET..$4..rt_3..$1..2..*3..$3..SET..$1..2..$4..rt_3..*3..$3..SET..$ 78 | 4..rt_4..$1..3..*3..$3..SET..$1..3..$4..rt_4..*3..$3..SET..$4..rt_5..$1..4..*3..$3..SET..$1..4..$ 79 | 4..rt_5..*3..$3..SET..$4..rt_6..$1..5..*3..$3..SET..$1..5..$4..rt_6..*3..$3..SET..$4..rt_7..$1..6 80 | ..*3..$3..SET..$1..6..$4..rt_7..*3..$3..SET..$4..rt_8..$1..7..*3..$3..SET..$1..7..$4..rt_8..*3..$ 81 | 3..SET..$4..rt_9..$1..8..*3..$3..SET..$1..8..$4..rt_9..*3..$3..SET..$5..rt_10..$1..9..*3..$3..SET 82 | ..$1..9..$5..rt_10.. 83 | ## 84 | T 127.0.0.1:6379 -> 127.0.0.1:61112 [AP] 85 | +OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+OK..+O 86 | K.. 87 | ``` 88 | 89 | See note on 'Performance' for what impact this has. 90 | 91 | ### PUBSUB 92 | 93 | PUBSUB in Redis works differently. A connection will subscribe to one or more topics, then receive all messages that are published to that topic. As such the single-request/single-response model of `paired_connect` will not work. A specific `client::pubsub_connect` is provided for this purpose. 94 | 95 | It returns a future which resolves to a `PubsubConnection`, this provides a `subscribe` function that takes a topic as a parameter and returns a future which, once the subscription is confirmed, resolves to a stream that contains all messages published to that topic. 96 | 97 | In the event of a broken connection to the Redis server, this connection will attempt to reconnect. 
Any existing subscriptions, however, will be terminated; it is the responsibility of the calling code to re-subscribe to topics as necessary. 98 | 99 | #### Example 100 | 101 | See [`examples/subscribe.rs`](examples/subscribe.rs). This will listen on a topic (by default: `test-topic`) and print each message as it arrives. To run this example, use `cargo run --example subscribe`, then in a separate terminal open `redis-cli` to the same server and publish some messages (e.g. `PUBLISH test-topic TESTING`). [`examples/psubscribe.rs`](examples/psubscribe.rs) does the same for a pattern subscription (by default: `test.*`). 102 | 103 | ## Performance 104 | 105 | I've removed the benchmarks from this project, as the examples were all out-of-date. I intend, at some point, to create a separate benchmarking repository which can more fairly do side-by-side performance tests of this and other Redis clients. 106 | 107 | ## Next steps 108 | 109 | - Better documentation 110 | - Test all Redis commands 111 | - Decide on the best way of supporting [Redis transactions](https://redis.io/topics/transactions) 112 | - Decide on the best way of supporting blocking Redis commands 113 | - Ensure all edge-cases are complete (e.g. Redis commands that return sets, nil, etc.) 114 | - Comprehensive benchmarking against other Redis clients 115 | 116 | ## License 117 | 118 | Licensed under either of 119 | 120 | - Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 121 | - MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 122 | 123 | at your option. 124 | 125 | ### Contribution 126 | 127 | Unless you explicitly state otherwise, any contribution intentionally submitted 128 | for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any 129 | additional terms or conditions.
130 | -------------------------------------------------------------------------------- /examples/monitor.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2024 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | use std::env; 12 | 13 | use futures::{sink::SinkExt, stream::StreamExt}; 14 | use redis_async::{client, resp_array}; 15 | /// Sends `MONITOR` over a low-level connection (host: first CLI argument, default 127.0.0.1; port 6379) and prints, with `{:?}`, each value the server streams back after the first one. 16 | #[tokio::main] 17 | async fn main() { 18 | let addr = env::args() 19 | .nth(1) 20 | .unwrap_or_else(|| "127.0.0.1".to_string()); 21 | // With the `tls` feature enabled the example connects via `client::connect_tls`, otherwise via `client::connect`. 22 | #[cfg(not(feature = "tls"))] 23 | let mut connection = client::connect(&addr, 6379, None, None) 24 | .await 25 | .expect("Cannot connect to Redis"); 26 | 27 | #[cfg(feature = "tls")] 28 | let mut connection = client::connect_tls(&addr, 6379, None, None) 29 | .await 30 | .expect("Cannot connect to Redis"); 31 | 32 | connection 33 | .send(resp_array!["MONITOR"]) 34 | .await 35 | .expect("Cannot send MONITOR command"); 36 | // The first value read back is the server's reply to `MONITOR` itself rather than a monitored command, so it is skipped. 37 | let mut skip_one = connection.skip(1); 38 | 39 | while let Some(incoming) = skip_one.next().await { 40 | println!("{:?}", incoming.expect("Cannot read incoming value")); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /examples/psubscribe.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2022 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. 
This file may not be copied, modified, or distributed 8 | * except according to those terms.
9 | */ 10 | 11 | use std::env; 12 | 13 | use futures::StreamExt; 14 | 15 | use redis_async::{client, resp::FromResp}; 16 | /// Pattern-subscribes (`psubscribe`) to the pattern given as the first CLI argument (default `test.*`) on the host given as the second (default 127.0.0.1, port 6379) and prints each message until the stream ends or yields an error. 17 | #[tokio::main] 18 | async fn main() { 19 | env_logger::init(); 20 | let topic = env::args().nth(1).unwrap_or_else(|| "test.*".to_string()); 21 | let addr = env::args() 22 | .nth(2) 23 | .unwrap_or_else(|| "127.0.0.1".to_string()); 24 | 25 | let pubsub_con = client::pubsub_connect(addr, 6379) 26 | .await 27 | .expect("Cannot connect to Redis"); 28 | let mut msgs = pubsub_con 29 | .psubscribe(&topic) 30 | .await 31 | .expect("Cannot subscribe to topic"); 32 | // Each payload is printed as a `String` (`unwrap` panics if it cannot be converted); the first stream error is reported and ends the loop. 33 | while let Some(message) = msgs.next().await { 34 | match message { 35 | Ok(message) => println!("{}", String::from_resp(message).unwrap()), 36 | Err(e) => { 37 | eprintln!("ERROR: {}", e); 38 | break; 39 | } 40 | } 41 | } 42 | 43 | println!("The end"); 44 | } 45 | -------------------------------------------------------------------------------- /examples/realistic.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2022 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms.
9 | */ 10 | 11 | use std::env; 12 | 13 | use futures_util::future; 14 | 15 | 16 | 17 | use redis_async::{client, resp_array}; 18 | 19 | /// An artificial "realistic" non-trivial example: for each of `test_data_size` items it `INCR`s a shared counter, then uses the result to write two keys with `SET`; the code reads as straight-line logic while the commands go out over one `paired_connect` connection. 20 | #[tokio::main] 21 | async fn main() { 22 | // Create some completely arbitrary "test data" 23 | let test_data_size = 10; 24 | 25 | let addr = env::args() 26 | .nth(1) 27 | .unwrap_or_else(|| "127.0.0.1".to_string()); 28 | 29 | let connection = client::paired_connect(addr, 6379) 30 | .await 31 | .expect("Cannot open connection"); 32 | // `send` for each INCR is called outside the `async` block, so every INCR is queued as `join_all` collects the iterator, before any reply is awaited; each dependent SET is only issued once its INCR reply arrives. 33 | let futures = (0..test_data_size).map(|x| (x, x.to_string())).map(|data| { 34 | let connection_inner = connection.clone(); 35 | let incr_f = connection.send(resp_array!["INCR", "realistic_test_ctr"]); 36 | async move { 37 | let ctr: String = incr_f.await.expect("Cannot increment"); 38 | 39 | let key = format!("rt_{}", ctr); 40 | let d_val = data.0.to_string(); 41 | connection_inner.send_and_forget(resp_array!["SET", &key, d_val]); 42 | connection_inner 43 | .send(resp_array!["SET", data.1, key]) 44 | .await 45 | .expect("Cannot set") 46 | } 47 | }); 48 | let result: Vec = future::join_all(futures).await; 49 | println!("RESULT: {:?}", result); 50 | assert_eq!(result.len(), test_data_size); 51 | } 52 | -------------------------------------------------------------------------------- /examples/subscribe.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2022 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. 
This file may not be copied, modified, or distributed 8 | * except according to those terms.
9 | */ 10 | 11 | use std::env; 12 | 13 | use futures::StreamExt; 14 | 15 | use redis_async::{client, resp::FromResp}; 16 | /// Subscribes to the topic given as the first CLI argument (default `test-topic`) on the host given as the second (default 127.0.0.1, port 6379) and prints each message until the stream ends or yields an error. 17 | #[tokio::main] 18 | async fn main() { 19 | env_logger::init(); 20 | let topic = env::args() 21 | .nth(1) 22 | .unwrap_or_else(|| "test-topic".to_string()); 23 | let addr = env::args() 24 | .nth(2) 25 | .unwrap_or_else(|| "127.0.0.1".to_string()); 26 | 27 | let pubsub_con = client::pubsub_connect(addr, 6379) 28 | .await 29 | .expect("Cannot connect to Redis"); 30 | let mut msgs = pubsub_con 31 | .subscribe(&topic) 32 | .await 33 | .expect("Cannot subscribe to topic"); 34 | // Each payload is printed as a `String` (`unwrap` panics if it cannot be converted); the first stream error is reported and ends the loop. 35 | while let Some(message) = msgs.next().await { 36 | match message { 37 | Ok(message) => println!("{}", String::from_resp(message).unwrap()), 38 | Err(e) => { 39 | eprintln!("ERROR: {}", e); 40 | break; 41 | } 42 | } 43 | } 44 | 45 | println!("The end"); 46 | } 47 | -------------------------------------------------------------------------------- /src/client/builder.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020-2024 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms.
9 | */ 10 | 11 | use std::sync::Arc; 12 | use std::time::Duration; 13 | 14 | use crate::error; 15 | 16 | #[derive(Debug)] 17 | /// Builder for Redis connection settings: host, port, optional username/password, TLS (with the `tls` feature) and socket keepalive/timeout 18 | pub struct ConnectionBuilder { 19 | pub(crate) host: String, 20 | pub(crate) port: u16, 21 | pub(crate) username: Option>, 22 | pub(crate) password: Option>, 23 | #[cfg(feature = "tls")] 24 | pub(crate) tls: bool, 25 | pub(crate) socket_keepalive: Option, 26 | pub(crate) socket_timeout: Option, 27 | } 28 | // Defaults applied by `ConnectionBuilder::new`; each can be replaced or cleared with the corresponding setter. 29 | const DEFAULT_KEEPALIVE: Duration = Duration::from_secs(60); 30 | const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); 31 | // `new` starts with no credentials, TLS off and the default keepalive/timeout; as written it always returns `Ok`. 32 | impl ConnectionBuilder { 33 | pub fn new(host: impl Into, port: u16) -> Result { 34 | Ok(Self { 35 | host: host.into(), 36 | port, 37 | username: None, 38 | password: None, 39 | #[cfg(feature = "tls")] 40 | tls: false, 41 | socket_keepalive: Some(DEFAULT_KEEPALIVE), 42 | socket_timeout: Some(DEFAULT_TIMEOUT), 43 | }) 44 | } 45 | 46 | /// Set the password used when connecting 47 | pub fn password>>(&mut self, password: V) -> &mut Self { 48 | self.password = Some(password.into()); 49 | self 50 | } 51 | 52 | /// Set the username used when connecting 53 | pub fn username>>(&mut self, username: V) -> &mut Self { 54 | self.username = Some(username.into()); 55 | self 56 | } 57 | /// Enable TLS for this connection (only available with the `tls` feature) 58 | #[cfg(feature = "tls")] 59 | pub fn tls(&mut self) -> &mut Self { 60 | self.tls = true; 61 | self 62 | } 63 | 64 | /// Set the socket keepalive duration (default 60 seconds); `None` clears it 65 | pub fn socket_keepalive(&mut self, duration: Option) -> &mut Self { 66 | self.socket_keepalive = duration; 67 | self 68 | } 69 | 70 | /// Set the socket timeout duration (default 30 seconds); `None` clears it 71 | pub fn socket_timeout(&mut self, duration: Option) -> &mut Self { 72 | self.socket_timeout = duration; 73 | self 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/client/connect.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2024 Ben Ashford 3 | * 4 | *
Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | use std::time::Duration; 12 | 13 | use futures_util::{SinkExt, StreamExt}; 14 | use pin_project::pin_project; 15 | use tokio::{ 16 | io::{AsyncRead, AsyncWrite}, 17 | net::TcpStream, 18 | }; 19 | use tokio_util::codec::{Decoder, Framed}; 20 | 21 | use crate::{ 22 | error, 23 | resp::{self, RespCodec}, 24 | }; 25 | 26 | #[pin_project(project = RespConnectionInnerProj)] 27 | pub enum RespConnectionInner { 28 | #[cfg(feature = "with-rustls")] 29 | Tls { 30 | #[pin] 31 | stream: tokio_rustls::client::TlsStream, 32 | }, 33 | #[cfg(feature = "with-native-tls")] 34 | Tls { 35 | #[pin] 36 | stream: tokio_native_tls::TlsStream, 37 | }, 38 | Plain { 39 | #[pin] 40 | stream: TcpStream, 41 | }, 42 | } 43 | 44 | impl AsyncWrite for RespConnectionInner { 45 | fn poll_write( 46 | self: std::pin::Pin<&mut Self>, 47 | cx: &mut std::task::Context<'_>, 48 | buf: &[u8], 49 | ) -> std::task::Poll> { 50 | let this = self.project(); 51 | match this { 52 | #[cfg(feature = "tls")] 53 | RespConnectionInnerProj::Tls { stream } => stream.poll_write(cx, buf), 54 | RespConnectionInnerProj::Plain { stream } => stream.poll_write(cx, buf), 55 | } 56 | } 57 | 58 | fn poll_flush( 59 | self: std::pin::Pin<&mut Self>, 60 | cx: &mut std::task::Context<'_>, 61 | ) -> std::task::Poll> { 62 | let this = self.project(); 63 | match this { 64 | #[cfg(feature = "tls")] 65 | RespConnectionInnerProj::Tls { stream } => stream.poll_flush(cx), 66 | RespConnectionInnerProj::Plain { stream } => stream.poll_flush(cx), 67 | } 68 | } 69 | 70 | fn poll_shutdown( 71 | self: std::pin::Pin<&mut Self>, 72 | cx: &mut std::task::Context<'_>, 73 | ) -> std::task::Poll> { 74 | let this = self.project(); 75 | match this { 76 | #[cfg(feature = "tls")] 77 | RespConnectionInnerProj::Tls { stream } => stream.poll_shutdown(cx), 78 
| RespConnectionInnerProj::Plain { stream } => stream.poll_shutdown(cx), 79 | } 80 | } 81 | } 82 | 83 | impl AsyncRead for RespConnectionInner { 84 | fn poll_read( 85 | self: std::pin::Pin<&mut Self>, 86 | cx: &mut std::task::Context<'_>, 87 | buf: &mut tokio::io::ReadBuf<'_>, 88 | ) -> std::task::Poll> { 89 | let this = self.project(); 90 | match this { 91 | #[cfg(feature = "tls")] 92 | RespConnectionInnerProj::Tls { stream } => stream.poll_read(cx, buf), 93 | RespConnectionInnerProj::Plain { stream } => stream.poll_read(cx, buf), 94 | } 95 | } 96 | } 97 | 98 | pub type RespConnection = Framed; 99 | 100 | /// Connect to a Redis server and return a Future that resolves to a 101 | /// `RespConnection` for reading and writing asynchronously. 102 | /// 103 | /// Each `RespConnection` implements both `Sink` and `Stream` and read and 104 | /// writes `RESP` objects. 105 | /// 106 | /// This is a low-level interface to enable the creation of higher-level 107 | /// functionality. 108 | /// 109 | /// The sink and stream sides behave independently of each other, it is the 110 | /// responsibility of the calling application to determine what results are 111 | /// paired to a particular command. 112 | /// 113 | /// But since most Redis usages involve issue commands that result in one 114 | /// single result, this library also implements `paired_connect`. 
115 | pub async fn connect( 116 | host: &str, 117 | port: u16, 118 | socket_keepalive: Option, 119 | socket_timeout: Option, 120 | ) -> Result { 121 | let tcp_stream = TcpStream::connect((host, port)).await?; 122 | apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?; 123 | Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream })) 124 | } 125 | 126 | #[cfg(feature = "with-rustls")] 127 | pub async fn connect_tls( 128 | host: &str, 129 | port: u16, 130 | socket_keepalive: Option, 131 | socket_timeout: Option, 132 | ) -> Result { 133 | use std::sync::Arc; 134 | use tokio_rustls::{ 135 | rustls::{ClientConfig, RootCertStore}, 136 | TlsConnector, 137 | }; 138 | 139 | let mut root_store = RootCertStore::empty(); 140 | root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); 141 | let config = ClientConfig::builder() 142 | .with_root_certificates(root_store) 143 | .with_no_client_auth(); 144 | let connector = TlsConnector::from(Arc::new(config)); 145 | let addr = 146 | tokio::net::lookup_host((host, port)) 147 | .await? 148 | .next() 149 | .ok_or(error::Error::Connection( 150 | error::ConnectionReason::ConnectionFailed, 151 | ))?; 152 | let tcp_stream = TcpStream::connect(addr).await?; 153 | apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?; 154 | 155 | let stream = connector 156 | .connect( 157 | String::from(host) 158 | .try_into() 159 | .map_err(|_err| error::Error::InvalidDnsName)?, 160 | tcp_stream, 161 | ) 162 | .await?; 163 | Ok(RespCodec.framed(RespConnectionInner::Tls { stream })) 164 | } 165 | 166 | #[cfg(feature = "with-native-tls")] 167 | pub async fn connect_tls( 168 | host: &str, 169 | port: u16, 170 | socket_keepalive: Option, 171 | socket_timeout: Option, 172 | ) -> Result { 173 | let cx = native_tls::TlsConnector::builder().build()?; 174 | let cx = tokio_native_tls::TlsConnector::from(cx); 175 | 176 | let addr = 177 | tokio::net::lookup_host((host, port)) 178 | .await? 
179 | .next() 180 | .ok_or(error::Error::Connection( 181 | error::ConnectionReason::ConnectionFailed, 182 | ))?; 183 | let tcp_stream = TcpStream::connect(addr).await?; 184 | apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?; 185 | let stream = cx.connect(host, tcp_stream).await?; 186 | 187 | Ok(RespCodec.framed(RespConnectionInner::Tls { stream })) 188 | } 189 | 190 | pub async fn connect_with_auth( 191 | host: &str, 192 | port: u16, 193 | username: Option<&str>, 194 | password: Option<&str>, 195 | #[allow(unused_variables)] tls: bool, 196 | socket_keepalive: Option, 197 | socket_timeout: Option, 198 | ) -> Result { 199 | #[cfg(feature = "tls")] 200 | let mut connection = if tls { 201 | connect_tls(host, port, socket_keepalive, socket_timeout).await? 202 | } else { 203 | connect(host, port, socket_keepalive, socket_timeout).await? 204 | }; 205 | #[cfg(not(feature = "tls"))] 206 | let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?; 207 | 208 | if let Some(password) = password { 209 | let mut auth = resp_array!["AUTH"]; 210 | 211 | if let Some(username) = username { 212 | auth.push(username); 213 | } 214 | 215 | auth.push(password); 216 | 217 | connection.send(auth).await?; 218 | match connection.next().await { 219 | Some(Ok(value)) => match resp::FromResp::from_resp(value) { 220 | Ok(()) => (), 221 | Err(e) => return Err(e), 222 | }, 223 | Some(Err(e)) => return Err(e), 224 | None => { 225 | return Err(error::internal( 226 | "Connection closed before authentication complete", 227 | )) 228 | } 229 | } 230 | } 231 | 232 | Ok(connection) 233 | } 234 | 235 | /// Apply a custom keep-alive value to the connection 236 | fn apply_keepalive_and_timeouts( 237 | stream: &TcpStream, 238 | socket_keepalive: Option, 239 | socket_timeout: Option, 240 | ) -> Result<(), error::Error> { 241 | let sock_ref = socket2::SockRef::from(stream); 242 | 243 | if let Some(interval) = socket_keepalive { 244 | let keep_alive = 
socket2::TcpKeepalive::new() 245 | .with_time(interval) 246 | .with_interval(interval); 247 | // Not windows 248 | #[cfg(any( 249 | target_os = "android", 250 | target_os = "dragonfly", 251 | target_os = "freebsd", 252 | target_os = "fuchsia", 253 | target_os = "illumos", 254 | target_os = "ios", 255 | target_os = "linux", 256 | target_os = "macos", 257 | target_os = "netbsd", 258 | target_os = "tvos", 259 | target_os = "watchos", 260 | ))] 261 | let keep_alive = keep_alive.with_retries(1); 262 | sock_ref.set_tcp_keepalive(&keep_alive)?; 263 | } 264 | 265 | if let Some(timeout) = socket_timeout { 266 | sock_ref.set_read_timeout(Some(timeout))?; 267 | sock_ref.set_write_timeout(Some(timeout))?; 268 | } 269 | 270 | Ok(()) 271 | } 272 | 273 | #[cfg(test)] 274 | mod test { 275 | use futures_util::{ 276 | sink::SinkExt, 277 | stream::{self, StreamExt}, 278 | }; 279 | 280 | use crate::resp; 281 | 282 | #[tokio::test] 283 | async fn can_connect() { 284 | let mut connection = super::connect("127.0.0.1", 6379, None, None) 285 | .await 286 | .expect("Cannot connect"); 287 | connection 288 | .send(resp_array!["PING", "TEST"]) 289 | .await 290 | .expect("Cannot send PING"); 291 | let values: Vec<_> = connection 292 | .take(1) 293 | .map(|r| r.expect("Unexpected invalid data")) 294 | .collect() 295 | .await; 296 | 297 | assert_eq!(values.len(), 1); 298 | assert_eq!(values[0], "TEST".into()); 299 | } 300 | 301 | #[tokio::test] 302 | async fn complex_test() { 303 | let mut connection = super::connect("127.0.0.1", 6379, None, None) 304 | .await 305 | .expect("Cannot connect"); 306 | let mut ops = Vec::new(); 307 | ops.push(resp_array!["FLUSH"]); 308 | ops.extend((0..1000).map(|i| resp_array!["SADD", "test_set", format!("VALUE: {}", i)])); 309 | ops.push(resp_array!["SMEMBERS", "test_set"]); 310 | let mut ops_stream = stream::iter(ops).map(Ok); 311 | connection 312 | .send_all(&mut ops_stream) 313 | .await 314 | .expect("Cannot send"); 315 | let values: Vec<_> = connection 316 | 
.skip(1001) 317 | .take(1) 318 | .map(|r| r.expect("Unexpected invalid data")) 319 | .collect() 320 | .await; 321 | 322 | assert_eq!(values.len(), 1); 323 | let values = match &values[0] { 324 | resp::RespValue::Array(ref values) => values.clone(), 325 | _ => panic!("Not an array"), 326 | }; 327 | assert_eq!(values.len(), 1000); 328 | } 329 | } 330 | -------------------------------------------------------------------------------- /src/client/mod.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2020 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | //! The client API itself. 12 | //! 13 | //! This contains three main functions that return three specific types of client: 14 | //! 15 | //! * `connect` returns a pair of `Stream` and `Sink`, clients can write RESP messages to the 16 | //! `Sink` and read RESP messages from the `Stream`. Pairing requests to responses is up to the 17 | //! client. This is intended to be a low-level interface from which more user-friendly interfaces 18 | //! can be built. 19 | //! * `paired_connect` is used for most of the standard Redis commands, where one request results 20 | //! in one response. 21 | //! * `pubsub_connect` is used for Redis's PUBSUB functionality. 
22 | 23 | pub mod connect; 24 | #[macro_use] 25 | pub mod paired; 26 | mod builder; 27 | pub mod pubsub; 28 | 29 | pub use self::connect::connect; 30 | #[cfg(feature = "tls")] 31 | pub use self::connect::connect_tls; 32 | 33 | pub use self::{ 34 | builder::ConnectionBuilder, 35 | paired::{paired_connect, PairedConnection}, 36 | pubsub::{pubsub_connect, PubsubConnection}, 37 | }; 38 | -------------------------------------------------------------------------------- /src/client/paired.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2024 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | use std::collections::VecDeque; 12 | use std::future::Future; 13 | use std::marker::PhantomData; 14 | use std::mem; 15 | use std::pin::Pin; 16 | use std::sync::Arc; 17 | use std::task::{Context, Poll}; 18 | use std::time::Duration; 19 | 20 | use futures_channel::{mpsc, oneshot}; 21 | use futures_sink::Sink; 22 | use futures_util::{future::TryFutureExt, stream::StreamExt}; 23 | 24 | use super::{ 25 | connect::{connect_with_auth, RespConnection}, 26 | ConnectionBuilder, 27 | }; 28 | 29 | use crate::{ 30 | error, 31 | reconnect::{reconnect, Reconnect}, 32 | resp, 33 | }; 34 | 35 | /// The state of sending messages to a Redis server 36 | enum SendStatus { 37 | /// The connection is clear, more messages can be sent 38 | Ok, 39 | /// The connection has closed, nothing more should be sent 40 | End, 41 | /// The connection reported itself as full, it should be flushed before attempting to send the 42 | /// pending message again 43 | Full(resp::RespValue), 44 | } 45 | 46 | /// The state of receiving messages from a Redis server 47 | #[derive(Debug)] 48 | enum ReceiveStatus { 49 | /// Everything has been read, and the connection is closed, don't attempt to 
read any more 50 | ReadyFinished, 51 | /// Everything has been read, but the connection is open for future messages. 52 | ReadyMore, 53 | /// The connection is not ready 54 | NotReady, 55 | } 56 | 57 | type CommandResult = Result; 58 | type Responder = oneshot::Sender; 59 | type SendPayload = (resp::RespValue, Responder); 60 | 61 | // /// The PairedConnectionInner is a spawned future that is responsible for pairing commands and 62 | // /// results onto a `RespConnection` that is otherwise unpaired 63 | struct PairedConnectionInner { 64 | /// The underlying connection that talks the RESP protocol 65 | connection: RespConnection, 66 | /// The channel upon which commands are received 67 | out_rx: mpsc::UnboundedReceiver, 68 | /// The queue of waiting oneshot's for commands sent but results not yet received 69 | waiting: VecDeque, 70 | 71 | /// The status of the underlying connection 72 | send_status: SendStatus, 73 | } 74 | 75 | impl PairedConnectionInner { 76 | fn new( 77 | con: RespConnection, 78 | out_rx: mpsc::UnboundedReceiver<(resp::RespValue, Responder)>, 79 | ) -> Self { 80 | PairedConnectionInner { 81 | connection: con, 82 | out_rx, 83 | waiting: VecDeque::new(), 84 | send_status: SendStatus::Ok, 85 | } 86 | } 87 | 88 | fn impl_start_send( 89 | &mut self, 90 | cx: &mut Context, 91 | msg: resp::RespValue, 92 | ) -> Result { 93 | match Pin::new(&mut self.connection).poll_ready(cx) { 94 | Poll::Ready(Ok(())) => (), 95 | Poll::Ready(Err(e)) => return Err(e.into()), 96 | Poll::Pending => { 97 | self.send_status = SendStatus::Full(msg); 98 | return Ok(false); 99 | } 100 | } 101 | 102 | self.send_status = SendStatus::Ok; 103 | Pin::new(&mut self.connection).start_send(msg)?; 104 | Ok(true) 105 | } 106 | 107 | fn poll_start_send(&mut self, cx: &mut Context) -> Result { 108 | let mut status = SendStatus::Ok; 109 | mem::swap(&mut status, &mut self.send_status); 110 | 111 | let message = match status { 112 | SendStatus::End => { 113 | self.send_status = SendStatus::End; 
114 | return Ok(false); 115 | } 116 | SendStatus::Full(msg) => msg, 117 | SendStatus::Ok => match self.out_rx.poll_next_unpin(cx) { 118 | Poll::Ready(Some((msg, tx))) => { 119 | self.waiting.push_back(tx); 120 | msg 121 | } 122 | Poll::Ready(None) => { 123 | self.send_status = SendStatus::End; 124 | return Ok(false); 125 | } 126 | Poll::Pending => return Ok(false), 127 | }, 128 | }; 129 | 130 | self.impl_start_send(cx, message) 131 | } 132 | 133 | fn poll_complete(&mut self, cx: &mut Context) -> Result<(), error::Error> { 134 | let _ = Pin::new(&mut self.connection).poll_flush(cx)?; 135 | Ok(()) 136 | } 137 | 138 | fn receive(&mut self, cx: &mut Context) -> Result { 139 | if let SendStatus::End = self.send_status { 140 | if self.waiting.is_empty() { 141 | return Ok(ReceiveStatus::ReadyFinished); 142 | } 143 | } 144 | match self.connection.poll_next_unpin(cx) { 145 | Poll::Ready(None) => Err(error::unexpected("Connection to Redis closed unexpectedly")), 146 | Poll::Ready(Some(Ok(msg))) => { 147 | let tx = match self.waiting.pop_front() { 148 | Some(tx) => tx, 149 | None => panic!("Received unexpected message: {:?}", msg), 150 | }; 151 | let _ = tx.send(Ok(msg)); 152 | Ok(ReceiveStatus::ReadyMore) 153 | } 154 | Poll::Ready(Some(Err(e))) => Err(e), 155 | Poll::Pending => Ok(ReceiveStatus::NotReady), 156 | } 157 | } 158 | 159 | fn handle_error(&mut self, e: &error::Error) { 160 | for tx in self.waiting.drain(..) { 161 | let _ = tx.send(Err(error::internal(format!( 162 | "Failed due to underlying failure: {}", 163 | e 164 | )))); 165 | } 166 | 167 | log::error!("Internal error in PairedConnectionInner: {}", e); 168 | } 169 | } 170 | 171 | impl Future for PairedConnectionInner { 172 | type Output = (); 173 | 174 | #[allow(clippy::unit_arg)] 175 | fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { 176 | let mut_self = self.get_mut(); 177 | // If there's something to send, send it... 
178 | let mut sending = true; 179 | while sending { 180 | sending = match mut_self.poll_start_send(cx) { 181 | Ok(sending) => sending, 182 | Err(ref e) => return Poll::Ready(mut_self.handle_error(e)), 183 | }; 184 | } 185 | 186 | if let Err(ref e) = mut_self.poll_complete(cx) { 187 | return Poll::Ready(mut_self.handle_error(e)); 188 | }; 189 | 190 | // If there's something to receive, receive it... 191 | loop { 192 | match mut_self.receive(cx) { 193 | Ok(ReceiveStatus::NotReady) => return Poll::Pending, 194 | Ok(ReceiveStatus::ReadyMore) => (), 195 | Ok(ReceiveStatus::ReadyFinished) => return Poll::Ready(()), 196 | Err(ref e) => return Poll::Ready(mut_self.handle_error(e)), 197 | } 198 | } 199 | } 200 | } 201 | 202 | /// A shareable and cheaply cloneable connection to which Redis commands can be sent 203 | #[derive(Debug, Clone)] 204 | pub struct PairedConnection { 205 | out_tx_c: Arc>>, 206 | } 207 | 208 | async fn inner_conn_fn( 209 | host: String, 210 | port: u16, 211 | username: Option>, 212 | password: Option>, 213 | tls: bool, 214 | socket_keepalive: Option, 215 | socket_timeout: Option, 216 | ) -> Result, error::Error> { 217 | let username = username.as_ref().map(|u| u.as_ref()); 218 | let password = password.as_ref().map(|p| p.as_ref()); 219 | let connection = connect_with_auth( 220 | &host, 221 | port, 222 | username, 223 | password, 224 | tls, 225 | socket_keepalive, 226 | socket_timeout, 227 | ) 228 | .await?; 229 | let (out_tx, out_rx) = mpsc::unbounded(); 230 | let paired_connection_inner = PairedConnectionInner::new(connection, out_rx); 231 | tokio::spawn(paired_connection_inner); 232 | Ok(out_tx) 233 | } 234 | 235 | impl ConnectionBuilder { 236 | pub fn paired_connect(&self) -> impl Future> { 237 | let host = self.host.clone(); 238 | let port = self.port; 239 | let username = self.username.clone(); 240 | let password = self.password.clone(); 241 | 242 | let work_fn = |con: &mpsc::UnboundedSender, act| { 243 | con.unbounded_send(act).map_err(|e| 
e.into()) 244 | }; 245 | 246 | #[cfg(feature = "tls")] 247 | let tls = self.tls; 248 | #[cfg(not(feature = "tls"))] 249 | let tls = false; 250 | 251 | let socket_keepalive = self.socket_keepalive; 252 | let socket_timeout = self.socket_timeout; 253 | 254 | let conn_fn = move || { 255 | let con_f = inner_conn_fn( 256 | host.clone(), 257 | port, 258 | username.clone(), 259 | password.clone(), 260 | tls, 261 | socket_keepalive, 262 | socket_timeout, 263 | ); 264 | Box::pin(con_f) as Pin> + Send + Sync>> 265 | }; 266 | 267 | let reconnecting_con = reconnect(work_fn, conn_fn); 268 | reconnecting_con.map_ok(|con| PairedConnection { 269 | out_tx_c: Arc::new(con), 270 | }) 271 | } 272 | } 273 | 274 | /// The default starting point to use most default Redis functionality. 275 | /// 276 | /// Returns a future that resolves to a `PairedConnection`. The future will complete when the 277 | /// initial connection is established. 278 | /// 279 | /// Once the initial connection is established, the connection will attempt to reconnect should 280 | /// the connection be broken (e.g. the Redis server being restarted), but reconnections occur 281 | /// asynchronously, so all commands issued while the connection is unavailable will error, it is 282 | /// the client's responsibility to retry commands as applicable. Also, at least one command needs 283 | /// to be tried against the connection to trigger the re-connection attempt; this means at least 284 | /// one command will definitely fail in a disconnect/reconnect scenario. 285 | pub async fn paired_connect( 286 | host: impl Into, 287 | port: u16, 288 | ) -> Result { 289 | ConnectionBuilder::new(host, port)?.paired_connect().await 290 | } 291 | 292 | impl PairedConnection { 293 | /// Sends a command to Redis. 294 | /// 295 | /// The message must be in the format of a single RESP message, this can be constructed 296 | /// manually or with the `resp_array!` macro. 
Returned is a future that resolves to the value 297 | /// returned from Redis. The type must be one for which the `resp::FromResp` trait is defined. 298 | /// 299 | /// The future will fail for numerous reasons, including but not limited to: IO issues, conversion 300 | /// problems, and server-side errors being returned by Redis. 301 | /// 302 | /// Behind the scenes the message is queued up and sent to Redis asynchronously before the 303 | /// future is realised. As such, it is guaranteed that messages are sent in the same order 304 | /// that `send` is called. 305 | pub fn send(&self, msg: resp::RespValue) -> SendFuture 306 | where 307 | T: resp::FromResp + Unpin, 308 | { 309 | match &msg { 310 | resp::RespValue::Array(_) => (), 311 | _ => { 312 | return SendFuture::new(error::internal("Command must be a RespValue::Array")); 313 | } 314 | } 315 | 316 | let (tx, rx) = oneshot::channel(); 317 | match self.out_tx_c.do_work((msg, tx)) { 318 | Ok(()) => SendFuture::new(rx), 319 | Err(e) => SendFuture::new(e), 320 | } 321 | } 322 | 323 | #[inline] 324 | pub fn send_and_forget(&self, msg: resp::RespValue) { 325 | let send_f = self.send::(msg); 326 | let forget_f = async { 327 | if let Err(e) = send_f.await { 328 | log::error!("Error in send_and_forget: {}", e); 329 | } 330 | }; 331 | tokio::spawn(forget_f); 332 | } 333 | } 334 | 335 | #[derive(Debug)] 336 | enum SendFutureType { 337 | Wait(oneshot::Receiver>), 338 | Error(Option), 339 | } 340 | 341 | impl From>> for SendFutureType { 342 | fn from(from: oneshot::Receiver>) -> Self { 343 | Self::Wait(from) 344 | } 345 | } 346 | 347 | impl From for SendFutureType { 348 | fn from(e: error::Error) -> Self { 349 | Self::Error(Some(e)) 350 | } 351 | } 352 | 353 | #[derive(Debug)] 354 | pub struct SendFuture { 355 | send_type: SendFutureType, 356 | _phantom: PhantomData, 357 | } 358 | 359 | impl SendFuture { 360 | #[inline] 361 | fn new(send_type: impl Into) -> Self { 362 | Self { 363 | send_type: send_type.into(), 364 | 
_phantom: Default::default(), 365 | } 366 | } 367 | } 368 | 369 | impl Future for SendFuture 370 | where 371 | T: resp::FromResp + Unpin, 372 | { 373 | type Output = Result; 374 | 375 | #[inline] 376 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 377 | match self.get_mut().send_type { 378 | SendFutureType::Error(ref mut e) => match e.take() { 379 | Some(e) => Poll::Ready(Err(e)), 380 | None => panic!("Future polled several times after completion"), 381 | }, 382 | SendFutureType::Wait(ref mut rx) => match Pin::new(rx).poll(cx) { 383 | Poll::Ready(Ok(Ok(v))) => Poll::Ready(T::from_resp(v)), 384 | Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(e)), 385 | Poll::Ready(Err(_)) => Poll::Ready(Err(error::internal( 386 | "Connection closed before response received", 387 | ))), 388 | Poll::Pending => Poll::Pending, 389 | }, 390 | } 391 | } 392 | } 393 | 394 | #[cfg(test)] 395 | mod test { 396 | use super::ConnectionBuilder; 397 | 398 | #[tokio::test] 399 | async fn can_paired_connect() { 400 | let connection = super::paired_connect("127.0.0.1", 6379) 401 | .await 402 | .expect("Cannot establish connection"); 403 | 404 | let res_f = connection.send(resp_array!["PING", "TEST"]); 405 | connection.send_and_forget(resp_array!["SET", "X", "123"]); 406 | let wait_f = connection.send(resp_array!["GET", "X"]); 407 | 408 | let result_1: String = res_f.await.expect("Cannot read result of first thing"); 409 | let result_2: String = wait_f.await.expect("Cannot read result of second thing"); 410 | 411 | assert_eq!(result_1, "TEST"); 412 | assert_eq!(result_2, "123"); 413 | } 414 | 415 | #[tokio::test] 416 | async fn complex_paired_connect() { 417 | let connection = super::paired_connect("127.0.0.1", 6379) 418 | .await 419 | .expect("Cannot establish connection"); 420 | 421 | let value: String = connection 422 | .send(resp_array!["INCR", "CTR"]) 423 | .await 424 | .expect("Cannot increment counter"); 425 | let result: String = connection 426 | .send(resp_array!["SET", 
"LASTCTR", value]) 427 | .await 428 | .expect("Cannot set value"); 429 | 430 | assert_eq!(result, "OK"); 431 | } 432 | 433 | #[tokio::test] 434 | async fn sending_a_lot_of_data_test() { 435 | let connection = super::paired_connect("127.0.0.1", 6379) 436 | .await 437 | .expect("Cannot connect to Redis"); 438 | let mut futures = Vec::with_capacity(1000); 439 | for i in 0..1000 { 440 | let key = format!("X_{}", i); 441 | connection.send_and_forget(resp_array!["SET", &key, i.to_string()]); 442 | futures.push(connection.send(resp_array!["GET", key])); 443 | } 444 | let last_future = futures.remove(999); 445 | let result: String = last_future.await.expect("Cannot wait for result"); 446 | assert_eq!(result, "999"); 447 | } 448 | 449 | #[tokio::test] 450 | async fn test_builder() { 451 | let mut builder = 452 | ConnectionBuilder::new("127.0.0.1", 6379).expect("Cannot construct builder..."); 453 | builder.password("password"); 454 | builder.username(String::from("username")); 455 | let connection_result = builder.paired_connect().await; 456 | // Expecting an error as these aren't the correct username/password 457 | assert!(connection_result.is_err()); 458 | } 459 | } 460 | -------------------------------------------------------------------------------- /src/client/pubsub/inner.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2023 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 
9 | */ 10 | 11 | use std::collections::BTreeMap; 12 | use std::future::Future; 13 | use std::pin::Pin; 14 | use std::task::Context; 15 | use std::task::Poll; 16 | 17 | use futures_channel::{mpsc, oneshot}; 18 | use futures_sink::Sink; 19 | use futures_util::stream::{Fuse, StreamExt}; 20 | 21 | use crate::{ 22 | client::connect::RespConnection, 23 | error::{self, ConnectionReason}, 24 | resp::{self, FromResp}, 25 | }; 26 | 27 | use super::{PubsubEvent, PubsubSink}; 28 | 29 | /// A spawned future that handles a Pub/Sub connection and routes messages to streams for 30 | /// downstream consumption 31 | pub(crate) struct PubsubConnectionInner { 32 | /// The actual Redis connection 33 | connection: RespConnection, 34 | /// A stream onto which subscription/unsubscription requests are read 35 | out_rx: Fuse>, 36 | /// Current subscriptions 37 | subscriptions: BTreeMap, 38 | psubscriptions: BTreeMap, 39 | /// Subscriptions that have not yet been confirmed 40 | pending_subs: BTreeMap)>, 41 | pending_psubs: BTreeMap)>, 42 | /// Any incomplete messages to be sent... 43 | send_pending: Option, 44 | } 45 | 46 | impl PubsubConnectionInner { 47 | pub(crate) fn new(con: RespConnection, out_rx: mpsc::UnboundedReceiver) -> Self { 48 | PubsubConnectionInner { 49 | connection: con, 50 | out_rx: out_rx.fuse(), 51 | subscriptions: BTreeMap::new(), 52 | psubscriptions: BTreeMap::new(), 53 | pending_subs: BTreeMap::new(), 54 | pending_psubs: BTreeMap::new(), 55 | send_pending: None, 56 | } 57 | } 58 | 59 | /// If an unrecoverable error occurs in the inner connection, then we call this to notify all 60 | /// subscribers. 61 | /// This function sends the error to all subscribers then returns the error itself, as an `Err` 62 | /// to enable ergonomic use in a `?` operator. 
63 | fn fail_all(&self, err: error::Error) -> Result<(), error::Error> { 64 | for sender in self.subscriptions.values() { 65 | let _ = sender.unbounded_send(Err(err.clone())); 66 | } 67 | for sender in self.psubscriptions.values() { 68 | let _ = sender.unbounded_send(Err(err.clone())); 69 | } 70 | Err(err) 71 | } 72 | 73 | /// Returns `true` if data sent, or `false` if stream not ready... 74 | fn do_send(&mut self, cx: &mut Context, msg: resp::RespValue) -> Result { 75 | match Pin::new(&mut self.connection).poll_ready(cx) { 76 | Poll::Ready(_) => { 77 | Pin::new(&mut self.connection).start_send(msg)?; 78 | Ok(true) 79 | } 80 | Poll::Pending => { 81 | self.send_pending = Some(msg); 82 | Ok(false) 83 | } 84 | } 85 | } 86 | 87 | fn do_flush(&mut self, cx: &mut Context) -> Result<(), error::Error> { 88 | match Pin::new(&mut self.connection).poll_flush(cx) { 89 | Poll::Ready(r) => r.map_err(|e| e.into()), 90 | Poll::Pending => Ok(()), 91 | } 92 | } 93 | 94 | // Returns true = flushing required. false = no flushing required 95 | fn handle_new_subs(&mut self, cx: &mut Context) -> Result<(), error::Error> { 96 | if let Some(msg) = self.send_pending.take() { 97 | if !self.do_send(cx, msg)? { 98 | return Ok(()); 99 | } 100 | } 101 | loop { 102 | match self.out_rx.poll_next_unpin(cx) { 103 | Poll::Pending => return Ok(()), 104 | Poll::Ready(None) => return Ok(()), 105 | Poll::Ready(Some(pubsub_event)) => { 106 | let message = match pubsub_event { 107 | PubsubEvent::Subscribe(topic, sender, signal) => { 108 | self.pending_subs.insert(topic.clone(), (sender, signal)); 109 | resp_array!["SUBSCRIBE", topic] 110 | } 111 | PubsubEvent::Psubscribe(topic, sender, signal) => { 112 | self.pending_psubs.insert(topic.clone(), (sender, signal)); 113 | resp_array!["PSUBSCRIBE", topic] 114 | } 115 | PubsubEvent::Unsubscribe(topic) => resp_array!["UNSUBSCRIBE", topic], 116 | PubsubEvent::Punsubscribe(topic) => resp_array!["PUNSUBSCRIBE", topic], 117 | }; 118 | if !self.do_send(cx, message)? 
{ 119 | return Ok(()); 120 | } 121 | } 122 | } 123 | } 124 | } 125 | 126 | fn handle_message(&mut self, msg: resp::RespValue) -> Result<(), error::Error> { 127 | let (message_type, topic, msg) = match msg { 128 | resp::RespValue::Array(mut messages) => match ( 129 | messages.pop(), 130 | messages.pop(), 131 | messages.pop(), 132 | messages.pop(), 133 | ) { 134 | (Some(msg), Some(topic), Some(message_type), None) => { 135 | match (msg, String::from_resp(topic), message_type) { 136 | (msg, Ok(topic), resp::RespValue::BulkString(bytes)) => (bytes, topic, msg), 137 | _ => return Err(error::unexpected("Incorrect format of a PUBSUB message")), 138 | } 139 | } 140 | (Some(msg), Some(_), Some(topic), Some(message_type)) => { 141 | match (msg, String::from_resp(topic), message_type) { 142 | (msg, Ok(topic), resp::RespValue::BulkString(bytes)) => (bytes, topic, msg), 143 | _ => return Err(error::unexpected("Incorrect format of a PUBSUB message")), 144 | } 145 | } 146 | _ => { 147 | return Err(error::unexpected( 148 | "Wrong number of parts for a PUBSUB message", 149 | )); 150 | } 151 | }, 152 | resp::RespValue::Error(msg) => { 153 | return Err(error::unexpected(format!("Error from server: {}", msg))); 154 | } 155 | other => { 156 | return Err(error::unexpected(format!( 157 | "PUBSUB message should be encoded as an array, actual: {other:?}", 158 | ))); 159 | } 160 | }; 161 | 162 | match message_type.as_slice() { 163 | b"subscribe" => match self.pending_subs.remove(&topic) { 164 | Some((sender, signal)) => { 165 | self.subscriptions.insert(topic, sender); 166 | signal 167 | .send(()) 168 | .map_err(|()| error::internal("Error confirming subscription"))? 
169 | } 170 | None => { 171 | return Err(error::internal(format!( 172 | "Received unexpected subscribe notification for topic: {}", 173 | topic 174 | ))); 175 | } 176 | }, 177 | b"psubscribe" => match self.pending_psubs.remove(&topic) { 178 | Some((sender, signal)) => { 179 | self.psubscriptions.insert(topic, sender); 180 | signal 181 | .send(()) 182 | .map_err(|()| error::internal("Error confirming subscription"))? 183 | } 184 | None => { 185 | return Err(error::internal(format!( 186 | "Received unexpected subscribe notification for topic: {}", 187 | topic 188 | ))); 189 | } 190 | }, 191 | b"unsubscribe" => { 192 | if self.subscriptions.remove(&topic).is_none() { 193 | log::warn!("Received unexpected unsubscribe message: {}", topic) 194 | } 195 | } 196 | b"punsubscribe" => { 197 | if self.psubscriptions.remove(&topic).is_none() { 198 | log::warn!("Received unexpected unsubscribe message: {}", topic) 199 | } 200 | } 201 | b"message" => match self.subscriptions.get(&topic) { 202 | Some(sender) => { 203 | if let Err(error) = sender.unbounded_send(Ok(msg)) { 204 | if !error.is_disconnected() { 205 | return Err(error::internal(format!("Cannot send message: {}", error))); 206 | } 207 | } 208 | } 209 | None => { 210 | return Err(error::internal(format!( 211 | "Unexpected message on topic: {}", 212 | topic 213 | ))); 214 | } 215 | }, 216 | b"pmessage" => match self.psubscriptions.get(&topic) { 217 | Some(sender) => { 218 | if let Err(error) = sender.unbounded_send(Ok(msg)) { 219 | if !error.is_disconnected() { 220 | return Err(error::internal(format!("Cannot send message: {}", error))); 221 | } 222 | } 223 | } 224 | None => { 225 | return Err(error::internal(format!( 226 | "Unexpected message on topic: {}", 227 | topic 228 | ))); 229 | } 230 | }, 231 | t => { 232 | return Err(error::internal(format!( 233 | "Unexpected data on Pub/Sub connection: {}", 234 | String::from_utf8_lossy(t) 235 | ))); 236 | } 237 | } 238 | 239 | Ok(()) 240 | } 241 | 242 | /// Checks whether the 
conditions are met such that this task should end. 243 | /// The task should end when all three conditions are true: 244 | /// 1. There are no active or pending subscriptions 245 | /// 2. There are no active or pending psubscriptions 246 | /// 3. The channel where new subscriptions come from is closed 247 | fn should_end(&self) -> bool { 248 | self.subscriptions.is_empty() 249 | && self.psubscriptions.is_empty() 250 | && self.pending_subs.is_empty() 251 | && self.pending_psubs.is_empty() 252 | && self.out_rx.is_done() 253 | } 254 | 255 | /// Returns `Ok(true)` if there are still valid subscriptions at the end, or `Ok(false)` if not (i.e. the whole thing can be dropped); returns an error if a message cannot be handled, or if the connection errors or closes while subscriptions are active. 256 | fn handle_messages(&mut self, cx: &mut Context) -> Result { 257 | loop { 258 | match self.connection.poll_next_unpin(cx) { 259 | Poll::Pending => { 260 | // Nothing to do, so let's carry on 261 | return Ok(true); 262 | } 263 | Poll::Ready(None) => { 264 | // The Redis connection has closed, so we either: 265 | if self.subscriptions.is_empty() && self.psubscriptions.is_empty() { 266 | // There are no subscriptions, so we stop without failure. 267 | return Ok(false); 268 | } else { 269 | // There are active subscriptions, so we send an error to each of them 270 | // to let them know that the connection has closed. 271 | return Err(error::Error::Connection(ConnectionReason::NotConnected)); 272 | } 273 | } 274 | Poll::Ready(Some(Ok(message))) => { 275 | // A valid message has been received, so let's handle it... 276 | self.handle_message(message)?; 277 | 278 | // After handling a message, there may no longer be any valid subscriptions, so we check 279 | // all the ending criteria 280 | if self.should_end() { 281 | return Ok(false); 282 | } 283 | } 284 | Poll::Ready(Some(Err(e))) => { 285 | // An error occurred from the Redis connection, so we send an error to each of the 286 | // subscriptions to let them know that the connection has errored.
287 | return Err(e); 288 | } 289 | } 290 | } 291 | } 292 | } 293 | 294 | impl Future for PubsubConnectionInner { 295 | type Output = Result<(), error::Error>; 296 | 297 | fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { 298 | let this_self = self.get_mut(); 299 | 300 | // Check the incoming channel for new subscriptions 301 | if let Err(e) = this_self.handle_new_subs(cx) { 302 | return Poll::Ready(this_self.fail_all(e)); 303 | } 304 | 305 | if this_self.should_end() { 306 | // There are no current subscriptions, and the channel via which new subscriptions 307 | // arrive has closed, so this can now end. 308 | return Poll::Ready(Ok(())); 309 | } 310 | 311 | // The following is only valid if the result to `should_end` is false. 312 | if let Err(e) = this_self.do_flush(cx) { 313 | return Poll::Ready(this_self.fail_all(e)); 314 | } 315 | 316 | let cont = match this_self.handle_messages(cx) { 317 | Ok(cont) => cont, 318 | Err(e) => return Poll::Ready(this_self.fail_all(e)), 319 | }; 320 | 321 | if cont { 322 | Poll::Pending 323 | } else { 324 | Poll::Ready(Ok(())) 325 | } 326 | } 327 | } 328 | -------------------------------------------------------------------------------- /src/client/pubsub/mod.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2024 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 
9 | */ 10 | 11 | mod inner; 12 | 13 | use std::future::Future; 14 | use std::pin::Pin; 15 | use std::sync::Arc; 16 | use std::task::{Context, Poll}; 17 | use std::time::Duration; 18 | 19 | use futures_channel::{mpsc, oneshot}; 20 | use futures_util::{ 21 | future::TryFutureExt, 22 | stream::{Stream, StreamExt}, 23 | }; 24 | 25 | use super::{connect::connect_with_auth, ConnectionBuilder}; 26 | 27 | use crate::{ 28 | error, 29 | reconnect::{reconnect, Reconnect}, 30 | resp, 31 | }; 32 | 33 | use self::inner::PubsubConnectionInner; 34 | 35 | #[derive(Debug)] 36 | pub(crate) enum PubsubEvent { 37 | /// The: topic, sink to send messages through, and a oneshot to signal subscription has 38 | /// occurred. 39 | Subscribe(String, PubsubSink, oneshot::Sender<()>), 40 | Psubscribe(String, PubsubSink, oneshot::Sender<()>), 41 | /// The name of the topic to unsubscribe from. Unsubscription will be signaled by the stream 42 | /// closing without error. 43 | Unsubscribe(String), 44 | Punsubscribe(String), 45 | } 46 | 47 | type PubsubStreamInner = mpsc::UnboundedReceiver>; 48 | type PubsubSink = mpsc::UnboundedSender>; 49 | 50 | /// A shareable reference to subscribe to PUBSUB topics 51 | #[derive(Debug, Clone)] 52 | pub struct PubsubConnection { 53 | out_tx_c: Arc>>, 54 | } 55 | 56 | async fn inner_conn_fn( 57 | // Needs to be a String for lifetime reasons 58 | host: String, 59 | port: u16, 60 | username: Option>, 61 | password: Option>, 62 | tls: bool, 63 | socket_keepalive: Option, 64 | socket_timeout: Option, 65 | ) -> Result, error::Error> { 66 | let username = username.as_deref(); 67 | let password = password.as_deref(); 68 | 69 | let connection = connect_with_auth( 70 | &host, 71 | port, 72 | username, 73 | password, 74 | tls, 75 | socket_keepalive, 76 | socket_timeout, 77 | ) 78 | .await?; 79 | let (out_tx, out_rx) = mpsc::unbounded(); 80 | tokio::spawn(async { 81 | match PubsubConnectionInner::new(connection, out_rx).await { 82 | Ok(_) => (), 83 | Err(e) => 
log::error!("Pub/Sub error: {:?}", e), 84 | } 85 | }); 86 | Ok(out_tx) 87 | } 88 | 89 | impl ConnectionBuilder { 90 | pub fn pubsub_connect(&self) -> impl Future> { 91 | let username = self.username.clone(); 92 | let password = self.password.clone(); 93 | 94 | #[cfg(feature = "tls")] 95 | let tls = self.tls; 96 | #[cfg(not(feature = "tls"))] 97 | let tls = false; 98 | 99 | let host = self.host.clone(); 100 | let port = self.port; 101 | 102 | let socket_keepalive = self.socket_keepalive; 103 | let socket_timeout = self.socket_timeout; 104 | 105 | let reconnecting_f = reconnect( 106 | |con: &mpsc::UnboundedSender, act| { 107 | con.unbounded_send(act).map_err(|e| e.into()) 108 | }, 109 | move || { 110 | let con_f = inner_conn_fn( 111 | host.clone(), 112 | port, 113 | username.clone(), 114 | password.clone(), 115 | tls, 116 | socket_keepalive, 117 | socket_timeout, 118 | ); 119 | Box::pin(con_f) 120 | }, 121 | ); 122 | reconnecting_f.map_ok(|con| PubsubConnection { 123 | out_tx_c: Arc::new(con), 124 | }) 125 | } 126 | } 127 | 128 | /// Used for Redis's PUBSUB functionality. 129 | /// 130 | /// Returns a future that resolves to a `PubsubConnection`. The future will only resolve once the 131 | /// connection is established; after the initial establishment, if the connection drops for any 132 | /// reason (e.g. Redis server being restarted), the connection will attempt to reconnect; however, 133 | /// any subscriptions will need to be re-subscribed. 134 | pub async fn pubsub_connect( 135 | host: impl Into, 136 | port: u16, 137 | ) -> Result { 138 | ConnectionBuilder::new(host, port)?.pubsub_connect().await 139 | } 140 | 141 | impl PubsubConnection { 142 | /// Subscribes to a particular PUBSUB topic. 143 | /// 144 | /// Returns a future that resolves to a `Stream` that contains all the messages published on 145 | /// that particular topic.
146 | /// 147 | /// The resolved stream will end with an error (e.g. `redis_async::error::Error::Connection`) if the 148 | /// underlying connection is lost for unexpected reasons. In this situation, clients should call 149 | /// `subscribe` to re-subscribe; the underlying connection will automatically reconnect. However, 150 | /// clients should be aware that resubscriptions will only succeed once the underlying connection 151 | /// has been re-established, so multiple calls to `subscribe` may be required. 152 | pub async fn subscribe(&self, topic: &str) -> Result { 153 | let (tx, rx) = mpsc::unbounded(); 154 | let (signal_t, signal_r) = oneshot::channel(); 155 | self.out_tx_c 156 | .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?; 157 | 158 | match signal_r.await { 159 | Ok(_) => Ok(PubsubStream { 160 | topic: topic.to_owned(), 161 | underlying: rx, 162 | con: self.clone(), 163 | is_pattern: false, 164 | }), 165 | Err(_) => Err(error::internal("Subscription failed, try again later...")), 166 | } 167 | } 168 | 169 | pub async fn psubscribe(&self, topic: &str) -> Result { 170 | let (tx, rx) = mpsc::unbounded(); 171 | let (signal_t, signal_r) = oneshot::channel(); 172 | self.out_tx_c 173 | .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?; 174 | 175 | match signal_r.await { 176 | Ok(_) => Ok(PubsubStream { 177 | topic: topic.to_owned(), 178 | underlying: rx, 179 | con: self.clone(), 180 | is_pattern: true, 181 | }), 182 | Err(_) => Err(error::internal("Subscription failed, try again later...")), 183 | } 184 | } 185 | 186 | /// Tells the client to unsubscribe from a particular topic. This will return immediately; the 187 | /// actual unsubscription will be confirmed when the stream returned from `subscribe` ends.
188 | pub fn unsubscribe>(&self, topic: T) { 189 | // Ignoring any results, as any errors communicating with Redis would de-facto unsubscribe 190 | // anyway, and would be reported/logged elsewhere 191 | let _ = self 192 | .out_tx_c 193 | .do_work(PubsubEvent::Unsubscribe(topic.into())); 194 | } 195 | 196 | pub fn punsubscribe>(&self, topic: T) { 197 | // Ignoring any results, as any errors communicating with Redis would de-facto unsubscribe 198 | // anyway, and would be reported/logged elsewhere 199 | let _ = self 200 | .out_tx_c 201 | .do_work(PubsubEvent::Punsubscribe(topic.into())); 202 | } 203 | } 204 | 205 | #[derive(Debug)] 206 | pub struct PubsubStream { 207 | topic: String, 208 | underlying: PubsubStreamInner, 209 | con: PubsubConnection, 210 | is_pattern: bool, 211 | } 212 | 213 | impl Stream for PubsubStream { 214 | type Item = Result; 215 | 216 | #[inline] 217 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 218 | self.get_mut().underlying.poll_next_unpin(cx) 219 | } 220 | } 221 | 222 | impl Drop for PubsubStream { 223 | fn drop(&mut self) { 224 | let topic: &str = self.topic.as_ref(); 225 | if self.is_pattern { 226 | self.con.punsubscribe(topic); 227 | } else { 228 | self.con.unsubscribe(topic); 229 | } 230 | } 231 | } 232 | 233 | #[cfg(test)] 234 | mod test { 235 | use std::mem; 236 | 237 | use futures::{try_join, StreamExt, TryStreamExt}; 238 | 239 | use crate::{client, resp}; 240 | 241 | /* IMPORTANT: The tests run in parallel, so the topic names used must be exclusive to each test */ 242 | static SUBSCRIBE_TEST_TOPIC: &str = "test-topic"; 243 | static SUBSCRIBE_TEST_NON_TOPIC: &str = "test-not-topic"; 244 | 245 | static UNSUBSCRIBE_TOPIC_1: &str = "test-topic-1"; 246 | static UNSUBSCRIBE_TOPIC_2: &str = "test-topic-2"; 247 | static UNSUBSCRIBE_TOPIC_3: &str = "test-topic-3"; 248 | 249 | static RESUBSCRIBE_TOPIC: &str = "test-topic-resubscribe"; 250 | 251 | static DROP_CONNECTION_TOPIC: &str = "test-topic-drop-connection"; 252 | 253 
| static PSUBSCRIBE_PATTERN: &str = "ptest.*"; 254 | static PSUBSCRIBE_TOPIC_1: &str = "ptest.1"; 255 | static PSUBSCRIBE_TOPIC_2: &str = "ptest.2"; 256 | static PSUBSCRIBE_TOPIC_3: &str = "ptest.3"; 257 | 258 | static UNSUBSCRIBE_TWICE_TOPIC_1: &str = "test-topic-1-twice"; 259 | static UNSUBSCRIBE_TWICE_TOPIC_2: &str = "test-topic-2-twice"; 260 | 261 | #[tokio::test] 262 | async fn subscribe_test() { 263 | let paired_c = client::paired_connect("127.0.0.1", 6379); 264 | let pubsub_c = super::pubsub_connect("127.0.0.1", 6379); 265 | let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); 266 | 267 | let topic_messages = pubsub 268 | .subscribe(SUBSCRIBE_TEST_TOPIC) 269 | .await 270 | .expect("Cannot subscribe to topic"); 271 | 272 | paired.send_and_forget(resp_array!["PUBLISH", SUBSCRIBE_TEST_TOPIC, "test-message"]); 273 | paired.send_and_forget(resp_array![ 274 | "PUBLISH", 275 | SUBSCRIBE_TEST_NON_TOPIC, 276 | "test-message-1.5" 277 | ]); 278 | let _: resp::RespValue = paired 279 | .send(resp_array![ 280 | "PUBLISH", 281 | SUBSCRIBE_TEST_TOPIC, 282 | "test-message2" 283 | ]) 284 | .await 285 | .expect("Cannot send to topic"); 286 | 287 | let result: Vec<_> = topic_messages 288 | .take(2) 289 | .try_collect() 290 | .await 291 | .expect("Cannot collect two values"); 292 | 293 | assert_eq!(result.len(), 2); 294 | assert_eq!(result[0], "test-message".into()); 295 | assert_eq!(result[1], "test-message2".into()); 296 | } 297 | 298 | /// A test to examine the edge-case where a client subscribes to a topic, then the subscription is specifically unsubscribed, 299 | /// vs. where the subscription is automatically unsubscribed. 
300 | #[tokio::test] 301 | async fn unsubscribe_test() { 302 | let paired_c = client::paired_connect("127.0.0.1", 6379); 303 | let pubsub_c = super::pubsub_connect("127.0.0.1", 6379); 304 | let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); 305 | 306 | let mut topic_1 = pubsub 307 | .subscribe(UNSUBSCRIBE_TOPIC_1) 308 | .await 309 | .expect("Cannot subscribe to topic"); 310 | let mut topic_2 = pubsub 311 | .subscribe(UNSUBSCRIBE_TOPIC_2) 312 | .await 313 | .expect("Cannot subscribe to topic"); 314 | let mut topic_3 = pubsub 315 | .subscribe(UNSUBSCRIBE_TOPIC_3) 316 | .await 317 | .expect("Cannot subscribe to topic"); 318 | 319 | paired.send_and_forget(resp_array![ 320 | "PUBLISH", 321 | UNSUBSCRIBE_TOPIC_1, 322 | "test-message-1" 323 | ]); 324 | paired.send_and_forget(resp_array![ 325 | "PUBLISH", 326 | UNSUBSCRIBE_TOPIC_2, 327 | "test-message-2" 328 | ]); 329 | paired.send_and_forget(resp_array![ 330 | "PUBLISH", 331 | UNSUBSCRIBE_TOPIC_3, 332 | "test-message-3" 333 | ]); 334 | 335 | let result1 = topic_1 336 | .next() 337 | .await 338 | .expect("Cannot get next value") 339 | .expect("Cannot get next value"); 340 | assert_eq!(result1, "test-message-1".into()); 341 | 342 | let result2 = topic_2 343 | .next() 344 | .await 345 | .expect("Cannot get next value") 346 | .expect("Cannot get next value"); 347 | assert_eq!(result2, "test-message-2".into()); 348 | 349 | let result3 = topic_3 350 | .next() 351 | .await 352 | .expect("Cannot get next value") 353 | .expect("Cannot get next value"); 354 | assert_eq!(result3, "test-message-3".into()); 355 | 356 | // Unsubscribe from topic 2 357 | pubsub.unsubscribe(UNSUBSCRIBE_TOPIC_2); 358 | 359 | // Drop the subscription for topic 3 360 | mem::drop(topic_3); 361 | 362 | // Send some more messages 363 | paired.send_and_forget(resp_array![ 364 | "PUBLISH", 365 | UNSUBSCRIBE_TOPIC_1, 366 | "test-message-1.5" 367 | ]); 368 | paired.send_and_forget(resp_array![ 369 | "PUBLISH", 370 | 
UNSUBSCRIBE_TOPIC_2, 371 | "test-message-2.5" 372 | ]); 373 | paired.send_and_forget(resp_array![ 374 | "PUBLISH", 375 | UNSUBSCRIBE_TOPIC_3, 376 | "test-message-3.5" 377 | ]); 378 | 379 | // Get the next message for topic 1 380 | let result1 = topic_1 381 | .next() 382 | .await 383 | .expect("Cannot get next value") 384 | .expect("Cannot get next value"); 385 | assert_eq!(result1, "test-message-1.5".into()); 386 | 387 | // Get the next message for topic 2 388 | let result2 = topic_2.next().await; 389 | assert!(result2.is_none()); 390 | } 391 | 392 | /// Test that we can subscribe, unsubscribe, and resubscribe to a topic. 393 | #[tokio::test] 394 | async fn resubscribe_test() { 395 | let paired_c = client::paired_connect("127.0.0.1", 6379); 396 | let pubsub_c = super::pubsub_connect("127.0.0.1", 6379); 397 | let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); 398 | 399 | let mut topic_1 = pubsub 400 | .subscribe(RESUBSCRIBE_TOPIC) 401 | .await 402 | .expect("Cannot subscribe to topic"); 403 | 404 | paired.send_and_forget(resp_array!["PUBLISH", RESUBSCRIBE_TOPIC, "test-message-1"]); 405 | 406 | let result1 = topic_1 407 | .next() 408 | .await 409 | .expect("Cannot get next value") 410 | .expect("Cannot get next value"); 411 | assert_eq!(result1, "test-message-1".into()); 412 | 413 | // Unsubscribe from topic 1 414 | pubsub.unsubscribe(RESUBSCRIBE_TOPIC); 415 | 416 | // Send some more messages 417 | paired.send_and_forget(resp_array![ 418 | "PUBLISH", 419 | RESUBSCRIBE_TOPIC, 420 | "test-message-1.5" 421 | ]); 422 | 423 | // Get the next message for topic 1 424 | let result1 = topic_1.next().await; 425 | assert!(result1.is_none()); 426 | 427 | // Resubscribe to topic 1 428 | let mut topic_1 = pubsub 429 | .subscribe(RESUBSCRIBE_TOPIC) 430 | .await 431 | .expect("Cannot subscribe to topic"); 432 | 433 | // Send some more messages 434 | paired.send_and_forget(resp_array![ 435 | "PUBLISH", 436 | RESUBSCRIBE_TOPIC, 437 | 
"test-message-1.75" 438 | ]); 439 | 440 | // Get the next message for topic 1 441 | let result1 = topic_1 442 | .next() 443 | .await 444 | .expect("Cannot get next value") 445 | .expect("Cannot get next value"); 446 | assert_eq!(result1, "test-message-1.75".into()); 447 | } 448 | 449 | /// Test that dropping the connection doesn't stop the subscriptions. Not initially anyway. 450 | #[tokio::test] 451 | async fn drop_connection_test() { 452 | let paired_c = client::paired_connect("127.0.0.1", 6379); 453 | let pubsub_c = super::pubsub_connect("127.0.0.1", 6379); 454 | let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); 455 | 456 | let mut topic_1 = pubsub 457 | .subscribe(DROP_CONNECTION_TOPIC) 458 | .await 459 | .expect("Cannot subscribe to topic"); 460 | 461 | mem::drop(pubsub); 462 | 463 | paired.send_and_forget(resp_array![ 464 | "PUBLISH", 465 | DROP_CONNECTION_TOPIC, 466 | "test-message-1" 467 | ]); 468 | 469 | let result1 = topic_1 470 | .next() 471 | .await 472 | .expect("Cannot get next value") 473 | .expect("Cannot get next value"); 474 | assert_eq!(result1, "test-message-1".into()); 475 | 476 | mem::drop(topic_1); 477 | } 478 | 479 | #[tokio::test] 480 | async fn psubscribe_test() { 481 | let paired_c = client::paired_connect("127.0.0.1", 6379); 482 | let pubsub_c = super::pubsub_connect("127.0.0.1", 6379); 483 | let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); 484 | 485 | let topic_messages = pubsub 486 | .psubscribe(PSUBSCRIBE_PATTERN) 487 | .await 488 | .expect("Cannot subscribe to topic"); 489 | 490 | paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_1, "test-message-1"]); 491 | paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_2, "test-message-2"]); 492 | let _: resp::RespValue = paired 493 | .send(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_3, "test-message-3"]) 494 | .await 495 | .expect("Cannot send to topic"); 496 | 497 | let result: Vec<_> = 
topic_messages 498 | .take(3) 499 | .try_collect() 500 | .await 501 | .expect("Cannot collect two values"); 502 | 503 | assert_eq!(result.len(), 3); 504 | assert_eq!(result[0], "test-message-1".into()); 505 | assert_eq!(result[1], "test-message-2".into()); 506 | assert_eq!(result[2], "test-message-3".into()); 507 | } 508 | 509 | /// Allow unsubscribe to be called twice 510 | #[tokio::test] 511 | async fn unsubscribe_twice_test() { 512 | let paired_c = client::paired_connect("127.0.0.1", 6379); 513 | let pubsub_c = super::pubsub_connect("127.0.0.1", 6379); 514 | let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); 515 | 516 | let mut topic_1 = pubsub 517 | .subscribe(UNSUBSCRIBE_TWICE_TOPIC_1) 518 | .await 519 | .expect("Cannot subscribe to topic"); 520 | let mut topic_2 = pubsub 521 | .subscribe(UNSUBSCRIBE_TWICE_TOPIC_2) 522 | .await 523 | .expect("Cannot subscribe to topic"); 524 | 525 | paired.send_and_forget(resp_array![ 526 | "PUBLISH", 527 | UNSUBSCRIBE_TWICE_TOPIC_1, 528 | "test-message-1" 529 | ]); 530 | paired.send_and_forget(resp_array![ 531 | "PUBLISH", 532 | UNSUBSCRIBE_TWICE_TOPIC_2, 533 | "test-message-2" 534 | ]); 535 | 536 | pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2); 537 | pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2); 538 | 539 | paired.send_and_forget(resp_array![ 540 | "PUBLISH", 541 | UNSUBSCRIBE_TWICE_TOPIC_1, 542 | "test-message-1.5" 543 | ]); 544 | 545 | pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_1); 546 | 547 | let result1 = topic_1 548 | .next() 549 | .await 550 | .expect("Cannot get next value") 551 | .expect("Cannot get next value"); 552 | assert_eq!(result1, "test-message-1".into()); 553 | 554 | let result1 = topic_1 555 | .next() 556 | .await 557 | .expect("Cannot get next value") 558 | .expect("Cannot get next value"); 559 | assert_eq!(result1, "test-message-1.5".into()); 560 | 561 | let result2 = topic_2 562 | .next() 563 | .await 564 | .expect("Cannot get next value") 565 | .expect("Cannot get 
next value"); 566 | assert_eq!(result2, "test-message-2".into()); 567 | 568 | let result1 = topic_1.next().await; 569 | assert!(result1.is_none()); 570 | 571 | let result2 = topic_2.next().await; 572 | assert!(result2.is_none()); 573 | } 574 | } 575 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2023 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | //! Error handling 12 | 13 | use std::error; 14 | use std::fmt; 15 | use std::io; 16 | use std::sync::Arc; 17 | 18 | use futures_channel::mpsc; 19 | 20 | use crate::resp; 21 | 22 | #[derive(Debug, Clone)] 23 | pub enum Error { 24 | /// A non-specific internal error that prevented an operation from completing 25 | Internal(String), 26 | 27 | /// An IO error occurred 28 | IO(Arc), 29 | 30 | /// A RESP parsing/serialising error occurred 31 | Resp(String, Option), 32 | 33 | /// A remote error 34 | Remote(String), 35 | 36 | /// Error creating a connection, or an error with a connection being closed unexpectedly 37 | Connection(ConnectionReason), 38 | 39 | /// An unexpected error. In this context "unexpected" means 40 | /// "unexpected because we check ahead of time", it used to maintain the type signature of 41 | /// chains of futures; but it occurring at runtime should be considered a catastrophic 42 | /// failure. 43 | /// 44 | /// If any error is propagated this way that needs to be handled, then it should be made into 45 | /// a proper option. 
46 | Unexpected(String), 47 | 48 | #[cfg(feature = "with-rustls")] 49 | InvalidDnsName, 50 | 51 | #[cfg(feature = "with-native-tls")] 52 | Tls(Arc), 53 | } 54 | 55 | pub(crate) fn internal(msg: impl Into) -> Error { 56 | Error::Internal(msg.into()) 57 | } 58 | 59 | pub(crate) fn unexpected(msg: impl Into) -> Error { 60 | Error::Unexpected(msg.into()) 61 | } 62 | 63 | pub(crate) fn resp(msg: impl Into, resp: resp::RespValue) -> Error { 64 | Error::Resp(msg.into(), Some(resp)) 65 | } 66 | 67 | impl From for Error { 68 | fn from(err: io::Error) -> Error { 69 | Error::IO(Arc::new(err)) 70 | } 71 | } 72 | 73 | impl From> for Error { 74 | fn from(err: mpsc::TrySendError) -> Error { 75 | Error::Unexpected(format!("Cannot write to channel: {}", err)) 76 | } 77 | } 78 | 79 | impl error::Error for Error { 80 | fn source(&self) -> Option<&(dyn error::Error + 'static)> { 81 | match self { 82 | Error::IO(err) => Some(err), 83 | #[cfg(feature = "with-native-tls")] 84 | Error::Tls(err) => Some(err), 85 | _ => None, 86 | } 87 | } 88 | } 89 | 90 | #[cfg(feature = "with-native-tls")] 91 | impl From for Error { 92 | fn from(err: native_tls::Error) -> Error { 93 | Error::Tls(Arc::new(err)) 94 | } 95 | } 96 | 97 | impl fmt::Display for Error { 98 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 99 | match self { 100 | Error::Internal(s) => write!(f, "{}", s), 101 | Error::IO(err) => write!(f, "{}", err), 102 | Error::Resp(s, resp) => write!(f, "{}: {:?}", s, resp), 103 | Error::Remote(s) => write!(f, "{}", s), 104 | Error::Connection(ConnectionReason::Connected) => { 105 | write!(f, "Connection already established") 106 | } 107 | Error::Connection(ConnectionReason::Connecting) => write!(f, "Connection in progress"), 108 | Error::Connection(ConnectionReason::ConnectionFailed) => { 109 | write!(f, "The last attempt to establish a connection failed") 110 | } 111 | Error::Connection(ConnectionReason::NotConnected) => { 112 | write!(f, "Connection has been closed") 113 | } 114 | 
#[cfg(feature = "with-rustls")] 115 | Error::InvalidDnsName => { 116 | write!(f, "Invalid dns name") 117 | } 118 | #[cfg(feature = "with-native-tls")] 119 | Error::Tls(err) => write!(f, "{}", err), 120 | Error::Unexpected(err) => write!(f, "{}", err), 121 | } 122 | } 123 | } 124 | 125 | /// Details of a `ConnectionError` 126 | #[derive(Debug, Copy, Clone)] 127 | pub enum ConnectionReason { 128 | /// An attempt to use a connection while it is in the "connecting" state, clients should try 129 | /// again 130 | Connecting, 131 | /// An attempt was made to reconnect after a connection was established, clients should try 132 | /// again 133 | Connected, 134 | /// Connection failed - this can be returned from a call to reconnect, the actual error will be 135 | /// sent to the client at the next call 136 | ConnectionFailed, 137 | /// The connection is not currently connected, the connection will reconnect asynchronously, 138 | /// clients should try again 139 | NotConnected, 140 | } 141 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2022 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | //! A client for Redis using Tokio and Futures. 12 | //! 13 | //! Three interfaces are provided: one low-level, that makes no assumptions about how Redis is used; a high-level client, 14 | //! suitable for the vast majority of use-cases; a PUBSUB client specifically for Redis's PUBSUB functionality. 15 | //! 16 | //! ## Low-level 17 | //! 18 | //! [`client::connect`](client/connect/fn.connect.html) returns a pair of `Sink` and `Stream` (see [futures](https://github.com/alexcrichton/futures-rs)) which 19 | //! 
both transport [`resp::RespValue`](resp/enum.RespValue.html)s between client and Redis; these work independently of one another 20 | //! to allow pipelining. It is the responsibility of the caller to match responses to requests. It is also the 21 | //! responsibility of the client to convert application data into instances of [`resp::RespValue`](resp/enum.RespValue.html) and 22 | //! back (there are conversion traits available for common examples). 23 | //! 24 | //! This is a very low-level API compared to most Redis clients, but is done so intentionally, for two reasons: 1) it is 25 | //! the common denominator between a functional Redis client (i.e. is able to support all types of requests, including those 26 | //! that block and have streaming responses), and 2) it results in clean `Sink`s and `Stream`s which will be composable 27 | //! with other Tokio-based libraries. 28 | //! 29 | //! For most practical purposes this low-level interface will not be used, the only exception possibly being the 30 | //! [`MONITOR`](https://redis.io/commands/monitor) command. 31 | //! 32 | //! ## High-level 33 | //! 34 | //! [`client::paired_connect`](client/paired/fn.paired_connect.html) is used for most Redis commands (those for which one command 35 | //! returns one response, it's not suitable for PUBSUB, `MONITOR` or other similar commands). It allows a Redis command to 36 | //! be sent and a Future returned for each command. 37 | //! 38 | //! Commands will be sent in the order that [`send`](client/paired/struct.PairedConnection.html#method.send) is called, regardless 39 | //! of how the future is realised. This is to allow us to take advantage of Redis's features by implicitly pipelining 40 | //! commands where appropriate. One side-effect of this is that for many commands, e.g. `SET`, we don't need to realise the 41 | //! future at all, it can be assumed to be fire-and-forget; but, the final future of the final command does need to be 42 | //!
realised (at least) to ensure that the correct behaviour is observed. 43 | //! 44 | //! ## PUBSUB 45 | //! 46 | //! PUBSUB in Redis works differently. A connection will subscribe to one or more topics, then receive all messages that 47 | //! are published to that topic. As such the single-request/single-response model of 48 | //! [`paired_connect`](client/paired/fn.paired_connect.html) will not work. A specific 49 | //! [`client::pubsub_connect`](client/pubsub/fn.pubsub_connect.html) is provided for this purpose. 50 | //! 51 | //! It returns a future which resolves to a [`PubsubConnection`](client/pubsub/struct.PubsubConnection.html), this provides a 52 | //! [`subscribe`](client/pubsub/struct.PubsubConnection.html#method.subscribe) function that takes a topic as a parameter and 53 | //! returns a future which, once the subscription is confirmed, resolves to a stream that contains all messages published 54 | //! to that topic. 55 | 56 | #[macro_use] 57 | pub mod resp; 58 | 59 | #[macro_use] 60 | pub mod client; 61 | 62 | pub mod error; 63 | 64 | pub(crate) mod reconnect; 65 | 66 | // Ensure that exclusive features cannot be selected together. 67 | #[cfg(all(feature = "with-rustls", feature = "with-native-tls"))] 68 | compile_error!("Only one TLS backend can be selected at a time"); 69 | -------------------------------------------------------------------------------- /src/reconnect.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2020 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 
9 | */ 10 | 11 | use std::fmt; 12 | use std::future::Future; 13 | use std::mem; 14 | use std::pin::Pin; 15 | use std::sync::{Arc, Mutex, MutexGuard}; 16 | use std::time::Duration; 17 | 18 | use futures_util::{ 19 | future::{self, Either}, 20 | TryFutureExt, 21 | }; 22 | 23 | use tokio::time::timeout; 24 | 25 | use crate::error::{self, ConnectionReason}; 26 | 27 | type WorkFn = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync; 28 | type ConnFn = 29 | dyn Fn() -> Pin> + Send + Sync>> + Send + Sync; 30 | 31 | struct ReconnectInner { 32 | state: Mutex>, 33 | work_fn: Box>, 34 | conn_fn: Box>, 35 | } 36 | 37 | impl fmt::Debug for ReconnectInner { 38 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 39 | let struct_name = format!( 40 | "ReconnectInner<{}, {}>", 41 | std::any::type_name::(), 42 | std::any::type_name::() 43 | ); 44 | 45 | let work_fn_d = format!("@{:p}", self.work_fn.as_ref()); 46 | let conn_fn_d = format!("@{:p}", self.conn_fn.as_ref()); 47 | 48 | f.debug_struct(&struct_name) 49 | .field("state", &self.state) 50 | .field("work_fn", &work_fn_d) 51 | .field("conn_fn", &conn_fn_d) 52 | .finish() 53 | } 54 | } 55 | 56 | #[derive(Debug)] 57 | pub(crate) struct Reconnect(Arc>); 58 | 59 | impl Clone for Reconnect { 60 | fn clone(&self) -> Self { 61 | Reconnect(self.0.clone()) 62 | } 63 | } 64 | 65 | pub(crate) async fn reconnect(w: W, c: C) -> Result, error::Error> 66 | where 67 | A: Send + 'static, 68 | W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static, 69 | C: Fn() -> Pin> + Send + Sync>> 70 | + Send 71 | + Sync 72 | + 'static, 73 | T: Clone + Send + Sync + 'static, 74 | { 75 | let r = Reconnect(Arc::new(ReconnectInner { 76 | state: Mutex::new(ReconnectState::NotConnected), 77 | 78 | work_fn: Box::new(w), 79 | conn_fn: Box::new(c), 80 | })); 81 | let rf = { 82 | let state = r.0.state.lock().expect("Poisoned lock"); 83 | r.reconnect(state) 84 | }; 85 | rf.await?; 86 | Ok(r) 87 | } 88 | 89 | enum ReconnectState { 90 | NotConnected, 
91 | Connected(T), 92 | ConnectionFailed(Mutex>), 93 | Connecting, 94 | } 95 | 96 | impl fmt::Debug for ReconnectState { 97 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 98 | write!(f, "ReconnectState::")?; 99 | match self { 100 | NotConnected => write!(f, "NotConnected"), 101 | Connected(_) => write!(f, "Connected"), 102 | ConnectionFailed(_) => write!(f, "ConnectionFailed"), 103 | Connecting => write!(f, "Connecting"), 104 | } 105 | } 106 | } 107 | 108 | use self::ReconnectState::*; 109 | 110 | const CONNECTION_TIMEOUT_SECONDS: u64 = 10; 111 | const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS); 112 | 113 | impl Reconnect 114 | where 115 | A: Send + 'static, 116 | T: Clone + Send + Sync + 'static, 117 | { 118 | fn call_work(&self, t: &T, a: A) -> Result { 119 | if let Err(e) = (self.0.work_fn)(t, a) { 120 | match e { 121 | error::Error::IO(_) | error::Error::Unexpected(_) => { 122 | log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e); 123 | return Ok(false); 124 | } 125 | _ => (), 126 | } 127 | Err(e) 128 | } else { 129 | Ok(true) 130 | } 131 | } 132 | 133 | pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> { 134 | let mut state = self.0.state.lock().expect("Cannot obtain read lock"); 135 | match *state { 136 | NotConnected => { 137 | self.reconnect_spawn(state); 138 | Err(error::Error::Connection(ConnectionReason::NotConnected)) 139 | } 140 | Connected(ref t) => { 141 | let success = self.call_work(t, a)?; 142 | if !success { 143 | *state = NotConnected; 144 | self.reconnect_spawn(state); 145 | } 146 | Ok(()) 147 | } 148 | ConnectionFailed(ref e) => { 149 | let mut lock = e.lock().expect("Poisioned lock"); 150 | let e = match lock.take() { 151 | Some(e) => e, 152 | None => error::Error::Connection(ConnectionReason::NotConnected), 153 | }; 154 | mem::drop(lock); 155 | 156 | *state = NotConnected; 157 | self.reconnect_spawn(state); 158 | 
Err(e) 159 | } 160 | Connecting => Err(error::Error::Connection(ConnectionReason::Connecting)), 161 | } 162 | } 163 | 164 | /// Returns a future that completes when the connection is established or failed to establish 165 | /// used only for timing. 166 | fn reconnect( 167 | &self, 168 | mut state: MutexGuard>, 169 | ) -> impl Future> + Send { 170 | log::info!("Attempting to reconnect, current state: {:?}", *state); 171 | 172 | match *state { 173 | Connected(_) => { 174 | return Either::Right(future::err(error::Error::Connection( 175 | ConnectionReason::Connected, 176 | ))); 177 | } 178 | Connecting => { 179 | return Either::Right(future::err(error::Error::Connection( 180 | ConnectionReason::Connecting, 181 | ))); 182 | } 183 | NotConnected | ConnectionFailed(_) => (), 184 | } 185 | *state = ReconnectState::Connecting; 186 | 187 | mem::drop(state); 188 | 189 | let reconnect = self.clone(); 190 | 191 | let connection_f = async move { 192 | let connection = match timeout(CONNECTION_TIMEOUT, (reconnect.0.conn_fn)()).await { 193 | Ok(con_r) => con_r, 194 | Err(_) => Err(error::internal(format!( 195 | "Connection timed-out after {} seconds", 196 | CONNECTION_TIMEOUT_SECONDS 197 | ))), 198 | }; 199 | 200 | let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock"); 201 | 202 | match *state { 203 | NotConnected | Connecting => match connection { 204 | Ok(t) => { 205 | log::info!("Connection established"); 206 | *state = Connected(t); 207 | Ok(()) 208 | } 209 | Err(e) => { 210 | log::error!("Connection cannot be established: {}", e); 211 | *state = ConnectionFailed(Mutex::new(Some(e))); 212 | Err(error::Error::Connection(ConnectionReason::ConnectionFailed)) 213 | } 214 | }, 215 | ConnectionFailed(_) => { 216 | panic!("The connection state wasn't reset before connecting") 217 | } 218 | Connected(_) => panic!("A connected state shouldn't be attempting to reconnect"), 219 | } 220 | }; 221 | 222 | Either::Left(connection_f) 223 | } 224 | 225 | fn 
reconnect_spawn(&self, state: MutexGuard>) { 226 | let reconnect_f = self 227 | .reconnect(state) 228 | .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e)); 229 | 230 | tokio::spawn(reconnect_f); 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /src/resp.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017-2024 Ben Ashford 3 | * 4 | * Licensed under the Apache License, Version 2.0 or the MIT license 6 | * , at your 7 | * option. This file may not be copied, modified, or distributed 8 | * except according to those terms. 9 | */ 10 | 11 | //! An implementation of the RESP protocol 12 | 13 | use std::collections::HashMap; 14 | use std::fmt; 15 | use std::hash::{BuildHasher, Hash}; 16 | use std::io; 17 | use std::str; 18 | use std::sync::Arc; 19 | 20 | use bytes::{Buf, BufMut, BytesMut}; 21 | 22 | use tokio_util::codec::{Decoder, Encoder}; 23 | 24 | use super::error::{self, Error}; 25 | 26 | /// A single RESP value, this owns the data that is read/to-be written to Redis. 27 | /// 28 | /// It is cloneable to allow multiple copies to be delivered in certain circumstances, e.g. multiple 29 | /// subscribers to the same topic. 30 | #[derive(Clone, Eq, PartialEq)] 31 | pub enum RespValue { 32 | Nil, 33 | 34 | /// Zero, one or more other `RespValue`s. 35 | Array(Vec), 36 | 37 | /// A bulk string. In Redis terminology a string is a byte-array, so this is stored as a 38 | /// vector of `u8`s to allow clients to interpret the bytes as appropriate. 
39 | BulkString(Vec), 40 | 41 | /// An error from the Redis server 42 | Error(String), 43 | 44 | /// Redis documentation defines an integer as being a signed 64-bit integer: 45 | /// https://redis.io/topics/protocol#resp-integers 46 | Integer(i64), 47 | 48 | SimpleString(String), 49 | } 50 | 51 | impl RespValue { 52 | #[inline] 53 | fn into_result(self) -> Result { 54 | match self { 55 | RespValue::Error(string) => Err(Error::Remote(string)), 56 | x => Ok(x), 57 | } 58 | } 59 | 60 | /// Convenience function for building dynamic Redis commands with variable numbers of 61 | /// arguments, e.g. RPUSH 62 | /// 63 | /// This will panic if called for anything other than arrays 64 | pub fn append(mut self, other: impl IntoIterator) -> Self 65 | where 66 | T: Into, 67 | { 68 | match self { 69 | RespValue::Array(ref mut vals) => { 70 | vals.extend(other.into_iter().map(|t| t.into())); 71 | } 72 | _ => panic!("Can only append to arrays"), 73 | } 74 | self 75 | } 76 | 77 | /// Push item to Resp array 78 | /// 79 | /// This will panic if called for anything other than arrays 80 | pub fn push>(&mut self, item: T) { 81 | match self { 82 | RespValue::Array(ref mut vals) => { 83 | vals.push(item.into()); 84 | } 85 | _ => panic!("Can only push to arrays"), 86 | } 87 | } 88 | } 89 | 90 | impl fmt::Debug for RespValue { 91 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 92 | match self { 93 | RespValue::Nil => write!(f, "Nil"), 94 | RespValue::Array(vals) => write!(f, "Array({:?})", vals), 95 | RespValue::BulkString(bytes) => { 96 | // For BulkString we try and be clever and show the utf-8 version 97 | // if it's valid utf-8 98 | if let Ok(string) = str::from_utf8(bytes) { 99 | write!(f, "BulkString({:?})", string) 100 | } else { 101 | write!(f, "BulkString({:?})", bytes) 102 | } 103 | } 104 | RespValue::Error(string) => write!(f, "Error({:?})", string), 105 | RespValue::Integer(int) => write!(f, "Integer({:?})", int), 106 | RespValue::SimpleString(string) => write!(f, 
"SimpleString({:?})", string), 107 | } 108 | } 109 | } 110 | 111 | /// A trait to be implemented for every time which can be read from a RESP value. 112 | /// 113 | /// Implementing this trait on a type means that type becomes a valid return type for calls such as `send` on 114 | /// `client::PairedConnection` 115 | pub trait FromResp: Sized { 116 | /// Return a `Result` containing either `Self` or `Error`. Errors can occur due to either: a) the particular 117 | /// `RespValue` being incompatible with the required type, or b) a remote Redis error occuring. 118 | #[inline] 119 | fn from_resp(resp: RespValue) -> Result { 120 | Self::from_resp_int(resp.into_result()?) 121 | } 122 | 123 | fn from_resp_int(resp: RespValue) -> Result; 124 | } 125 | 126 | impl FromResp for RespValue { 127 | #[inline] 128 | fn from_resp_int(resp: RespValue) -> Result { 129 | Ok(resp) 130 | } 131 | } 132 | 133 | impl FromResp for String { 134 | #[inline] 135 | fn from_resp_int(resp: RespValue) -> Result { 136 | match resp { 137 | RespValue::BulkString(ref bytes) => Ok(String::from_utf8_lossy(bytes).into_owned()), 138 | RespValue::Integer(i) => Ok(i.to_string()), 139 | RespValue::SimpleString(string) => Ok(string), 140 | _ => Err(error::resp("Cannot convert into a string", resp)), 141 | } 142 | } 143 | } 144 | 145 | impl FromResp for Arc { 146 | #[inline] 147 | fn from_resp_int(resp: RespValue) -> Result, Error> { 148 | match resp { 149 | RespValue::BulkString(ref bytes) => Ok(String::from_utf8_lossy(bytes).into()), 150 | _ => Err(error::resp("Cannot convert into a Arc", resp)), 151 | } 152 | } 153 | } 154 | 155 | impl FromResp for Vec { 156 | #[inline] 157 | fn from_resp_int(resp: RespValue) -> Result, Error> { 158 | match resp { 159 | RespValue::BulkString(bytes) => Ok(bytes), 160 | _ => Err(error::resp("Not a bulk string", resp)), 161 | } 162 | } 163 | } 164 | 165 | impl FromResp for i64 { 166 | #[inline] 167 | fn from_resp_int(resp: RespValue) -> Result { 168 | match resp { 169 | 
RespValue::Integer(i) => Ok(i), 170 | _ => Err(error::resp("Cannot be converted into an i64", resp)), 171 | } 172 | } 173 | } 174 | 175 | macro_rules! impl_fromresp_integers { 176 | ($($int_ty:ident),* $(,)*) => { 177 | $( 178 | #[allow(clippy::cast_lossless)] 179 | impl FromResp for $int_ty { 180 | #[inline] 181 | fn from_resp_int(resp: RespValue) -> Result { 182 | i64::from_resp_int(resp).and_then(|x| { 183 | // $int_ty::max_value() as i64 > 0 should be optimized out. It tests if 184 | // the target integer type needs an "upper bounds" check 185 | if x < ($int_ty::min_value() as i64) 186 | || ($int_ty::max_value() as i64 > 0 187 | && x > ($int_ty::max_value() as i64)) 188 | { 189 | Err(error::resp( 190 | concat!( 191 | "i64 value cannot be represented as {}", 192 | stringify!($int_ty), 193 | ), 194 | RespValue::Integer(x), 195 | )) 196 | } else { 197 | Ok(x as $int_ty) 198 | } 199 | }) 200 | } 201 | } 202 | )* 203 | }; 204 | } 205 | 206 | impl_fromresp_integers!(isize, usize, i32, u32, u64); 207 | 208 | impl FromResp for bool { 209 | #[inline] 210 | fn from_resp_int(resp: RespValue) -> Result { 211 | i64::from_resp_int(resp).and_then(|x| match x { 212 | 0 => Ok(false), 213 | 1 => Ok(true), 214 | _ => Err(error::resp( 215 | "i64 value cannot be represented as bool", 216 | RespValue::Integer(x), 217 | )), 218 | }) 219 | } 220 | } 221 | 222 | impl FromResp for Option { 223 | #[inline] 224 | fn from_resp_int(resp: RespValue) -> Result, Error> { 225 | match resp { 226 | RespValue::Nil => Ok(None), 227 | x => Ok(Some(T::from_resp_int(x)?)), 228 | } 229 | } 230 | } 231 | 232 | impl FromResp for Vec { 233 | #[inline] 234 | fn from_resp_int(resp: RespValue) -> Result, Error> { 235 | match resp { 236 | RespValue::Array(ary) => { 237 | let mut ar = Vec::with_capacity(ary.len()); 238 | for value in ary { 239 | ar.push(T::from_resp(value)?); 240 | } 241 | Ok(ar) 242 | } 243 | _ => Err(error::resp("Cannot be converted into a vector", resp)), 244 | } 245 | } 246 | } 247 | 248 | 
impl FromResp for HashMap { 249 | fn from_resp_int(resp: RespValue) -> Result, Error> { 250 | match resp { 251 | RespValue::Array(ary) => { 252 | let mut map = HashMap::with_capacity_and_hasher(ary.len(), S::default()); 253 | let mut items = ary.into_iter(); 254 | 255 | while let Some(k) = items.next() { 256 | let key = K::from_resp(k)?; 257 | let value = T::from_resp(items.next().ok_or_else(|| { 258 | error::resp( 259 | "Cannot convert an odd number of elements into a hashmap", 260 | "".into(), 261 | ) 262 | })?)?; 263 | 264 | map.insert(key, value); 265 | } 266 | 267 | Ok(map) 268 | } 269 | _ => Err(error::resp("Cannot be converted into a hashmap", resp)), 270 | } 271 | } 272 | } 273 | 274 | impl FromResp for () { 275 | #[inline] 276 | fn from_resp_int(resp: RespValue) -> Result<(), Error> { 277 | match resp { 278 | RespValue::SimpleString(string) => match string.as_ref() { 279 | "OK" => Ok(()), 280 | _ => Err(Error::Resp( 281 | format!("Unexpected value within SimpleString: {}", string), 282 | None, 283 | )), 284 | }, 285 | _ => Err(error::resp( 286 | "Unexpected value, should be encoded as a SimpleString", 287 | resp, 288 | )), 289 | } 290 | } 291 | } 292 | 293 | impl FromResp for (A, B) 294 | where 295 | A: FromResp, 296 | B: FromResp, 297 | { 298 | #[inline] 299 | fn from_resp_int(resp: RespValue) -> Result<(A, B), Error> { 300 | match resp { 301 | RespValue::Array(ary) => { 302 | if ary.len() == 2 { 303 | let mut ary_iter = ary.into_iter(); 304 | Ok(( 305 | A::from_resp(ary_iter.next().expect("No value"))?, 306 | B::from_resp(ary_iter.next().expect("No value"))?, 307 | )) 308 | } else { 309 | Err(Error::Resp( 310 | format!("Array needs to be 2 elements, is: {}", ary.len()), 311 | None, 312 | )) 313 | } 314 | } 315 | _ => Err(error::resp( 316 | "Unexpected value, should be encoded as an array", 317 | resp, 318 | )), 319 | } 320 | } 321 | } 322 | 323 | impl FromResp for (A, B, C) 324 | where 325 | A: FromResp, 326 | B: FromResp, 327 | C: FromResp, 328 | { 329 
| #[inline] 330 | fn from_resp_int(resp: RespValue) -> Result<(A, B, C), Error> { 331 | match resp { 332 | RespValue::Array(ary) => { 333 | if ary.len() == 3 { 334 | let mut ary_iter = ary.into_iter(); 335 | Ok(( 336 | A::from_resp(ary_iter.next().expect("No value"))?, 337 | B::from_resp(ary_iter.next().expect("No value"))?, 338 | C::from_resp(ary_iter.next().expect("No value"))?, 339 | )) 340 | } else { 341 | Err(Error::Resp( 342 | format!("Array needs to be 3 elements, is: {}", ary.len()), 343 | None, 344 | )) 345 | } 346 | } 347 | _ => Err(error::resp( 348 | "Unexpected value, should be encoded as an array", 349 | resp, 350 | )), 351 | } 352 | } 353 | } 354 | 355 | /// Macro to create a RESP array, useful for preparing commands to send. Elements can be any type, or a mixture 356 | /// of types, that satisfy `Into`. 357 | /// 358 | /// As a general rule, if a value is moved, the data can be deconstructed (if appropriate, e.g. String) and the raw 359 | /// data moved into the corresponding `RespValue`. If a reference is provided, the data will be copied instead. 360 | /// 361 | /// # Examples 362 | /// 363 | /// ``` 364 | /// #[macro_use] 365 | /// extern crate redis_async; 366 | /// 367 | /// fn main() { 368 | /// let value = format!("something_{}", 123); 369 | /// resp_array!["SET", "key_name", value]; 370 | /// } 371 | /// ``` 372 | /// 373 | /// For variable length Redis commands: 374 | /// 375 | /// ``` 376 | /// #[macro_use] 377 | /// extern crate redis_async; 378 | /// 379 | /// fn main() { 380 | /// let data = vec!["data", "from", "somewhere", "else"]; 381 | /// let command = resp_array!["RPUSH", "mykey"].append(data); 382 | /// } 383 | /// ``` 384 | #[macro_export] 385 | macro_rules! resp_array { 386 | ($($e:expr),* $(,)?) => { 387 | { 388 | $crate::resp::RespValue::Array(vec![ 389 | $( 390 | $e.into(), 391 | )* 392 | ]) 393 | } 394 | } 395 | } 396 | 397 | macro_rules! 
into_resp { 398 | ($t:ty, $f:ident) => { 399 | impl<'a> From<$t> for RespValue { 400 | #[inline] 401 | fn from(from: $t) -> RespValue { 402 | from.$f() 403 | } 404 | } 405 | }; 406 | } 407 | 408 | /// A specific trait to convert into a `RespValue::BulkString` 409 | pub trait IntoRespString { 410 | fn into_resp_string(self) -> RespValue; 411 | } 412 | 413 | macro_rules! string_into_resp { 414 | ($t:ty) => { 415 | into_resp!($t, into_resp_string); 416 | }; 417 | } 418 | 419 | impl IntoRespString for String { 420 | #[inline] 421 | fn into_resp_string(self) -> RespValue { 422 | RespValue::BulkString(self.into_bytes()) 423 | } 424 | } 425 | string_into_resp!(String); 426 | 427 | impl<'a> IntoRespString for &'a String { 428 | #[inline] 429 | fn into_resp_string(self) -> RespValue { 430 | RespValue::BulkString(self.as_bytes().into()) 431 | } 432 | } 433 | string_into_resp!(&'a String); 434 | 435 | impl<'a> IntoRespString for &'a str { 436 | #[inline] 437 | fn into_resp_string(self) -> RespValue { 438 | RespValue::BulkString(self.as_bytes().into()) 439 | } 440 | } 441 | string_into_resp!(&'a str); 442 | 443 | impl<'a> IntoRespString for &'a [u8] { 444 | #[inline] 445 | fn into_resp_string(self) -> RespValue { 446 | RespValue::BulkString(self.to_vec()) 447 | } 448 | } 449 | string_into_resp!(&'a [u8]); 450 | 451 | impl IntoRespString for Vec { 452 | #[inline] 453 | fn into_resp_string(self) -> RespValue { 454 | RespValue::BulkString(self) 455 | } 456 | } 457 | string_into_resp!(Vec); 458 | 459 | impl IntoRespString for Arc { 460 | #[inline] 461 | fn into_resp_string(self) -> RespValue { 462 | RespValue::BulkString(self.as_bytes().into()) 463 | } 464 | } 465 | string_into_resp!(Arc); 466 | 467 | pub trait IntoRespInteger { 468 | fn into_resp_integer(self) -> RespValue; 469 | } 470 | 471 | macro_rules! 
integer_into_resp { 472 | ($t:ty) => { 473 | into_resp!($t, into_resp_integer); 474 | }; 475 | } 476 | 477 | impl IntoRespInteger for usize { 478 | #[inline] 479 | fn into_resp_integer(self) -> RespValue { 480 | RespValue::Integer(self as i64) 481 | } 482 | } 483 | integer_into_resp!(usize); 484 | 485 | /// Codec to read frames 486 | pub struct RespCodec; 487 | 488 | fn write_rn(buf: &mut BytesMut) { 489 | buf.put_u8(b'\r'); 490 | buf.put_u8(b'\n'); 491 | } 492 | 493 | fn check_and_reserve(buf: &mut BytesMut, amt: usize) { 494 | let remaining_bytes = buf.remaining_mut(); 495 | if remaining_bytes < amt { 496 | buf.reserve(amt); 497 | } 498 | } 499 | 500 | fn write_header(symb: u8, len: i64, buf: &mut BytesMut) { 501 | let len_as_string = len.to_string(); 502 | let len_as_bytes = len_as_string.as_bytes(); 503 | let header_bytes = 1 + len_as_bytes.len() + 2; 504 | check_and_reserve(buf, header_bytes); 505 | buf.put_u8(symb); 506 | buf.extend(len_as_bytes); 507 | write_rn(buf); 508 | } 509 | 510 | fn write_simple_string(symb: u8, string: &str, buf: &mut BytesMut) { 511 | let bytes = string.as_bytes(); 512 | let size = 1 + bytes.len() + 2; 513 | check_and_reserve(buf, size); 514 | buf.put_u8(symb); 515 | buf.extend(bytes); 516 | write_rn(buf); 517 | } 518 | 519 | impl Encoder for RespCodec { 520 | type Error = io::Error; 521 | 522 | fn encode(&mut self, msg: RespValue, buf: &mut BytesMut) -> Result<(), Self::Error> { 523 | match msg { 524 | RespValue::Nil => { 525 | write_header(b'$', -1, buf); 526 | } 527 | RespValue::Array(ary) => { 528 | write_header(b'*', ary.len() as i64, buf); 529 | for v in ary { 530 | self.encode(v, buf)?; 531 | } 532 | } 533 | RespValue::BulkString(bstr) => { 534 | let len = bstr.len(); 535 | write_header(b'$', len as i64, buf); 536 | check_and_reserve(buf, len + 2); 537 | buf.extend(bstr); 538 | write_rn(buf); 539 | } 540 | RespValue::Error(ref string) => { 541 | write_simple_string(b'-', string, buf); 542 | } 543 | RespValue::Integer(val) => 
{ 544 | // Simple integer are just the header 545 | write_header(b':', val, buf); 546 | } 547 | RespValue::SimpleString(ref string) => { 548 | write_simple_string(b'+', string, buf); 549 | } 550 | } 551 | Ok(()) 552 | } 553 | } 554 | 555 | #[inline] 556 | fn parse_error(message: String) -> Error { 557 | Error::Resp(message, None) 558 | } 559 | 560 | /// Many RESP types have their length (which is either bytes or "number of elements", depending on context) 561 | /// encoded as a string, terminated by "\r\n", this looks for them. 562 | /// 563 | /// Only return the string if the whole sequence is complete, including the terminator bytes (but those final 564 | /// two bytes will not be returned) 565 | /// 566 | /// TODO - rename this function potentially, it's used for simple integers too 567 | fn scan_integer(buf: &mut BytesMut, idx: usize) -> Result, Error> { 568 | let length = buf.len(); 569 | let mut at_end = false; 570 | let mut pos = idx; 571 | loop { 572 | if length <= pos { 573 | return Ok(None); 574 | } 575 | match (at_end, buf[pos]) { 576 | (true, b'\n') => return Ok(Some((pos + 1, &buf[idx..pos - 1]))), 577 | (false, b'\r') => at_end = true, 578 | (false, b'0'..=b'9') => (), 579 | (false, b'-') => (), 580 | (_, val) => { 581 | return Err(parse_error(format!( 582 | "Unexpected byte in size_string: {}", 583 | val 584 | ))); 585 | } 586 | } 587 | pos += 1; 588 | } 589 | } 590 | 591 | fn scan_string(buf: &mut BytesMut, idx: usize) -> Option<(usize, String)> { 592 | let length = buf.len(); 593 | let mut at_end = false; 594 | let mut pos = idx; 595 | loop { 596 | if length <= pos { 597 | return None; 598 | } 599 | match (at_end, buf[pos]) { 600 | (true, b'\n') => { 601 | let value = String::from_utf8_lossy(&buf[idx..pos - 1]).into_owned(); 602 | return Some((pos + 1, value)); 603 | } 604 | (true, _) => at_end = false, 605 | (false, b'\r') => at_end = true, 606 | (false, _) => (), 607 | } 608 | pos += 1; 609 | } 610 | } 611 | 612 | fn decode_raw_integer(buf: &mut 
BytesMut, idx: usize) -> Result, Error> { 613 | match scan_integer(buf, idx) { 614 | Ok(None) => Ok(None), 615 | Ok(Some((pos, int_str))) => { 616 | // Redis integers are transmitted as strings, so we first convert the raw bytes into a string... 617 | match str::from_utf8(int_str) { 618 | Ok(string) => { 619 | // ...and then parse the string. 620 | match string.parse() { 621 | Ok(int) => Ok(Some((pos, int))), 622 | Err(_) => Err(parse_error(format!("Not an integer: {}", string))), 623 | } 624 | } 625 | Err(_) => Err(parse_error(format!("Not a valid string: {:?}", int_str))), 626 | } 627 | } 628 | Err(e) => Err(e), 629 | } 630 | } 631 | 632 | type DecodeResult = Result, Error>; 633 | 634 | fn decode_bulk_string(buf: &mut BytesMut, idx: usize) -> DecodeResult { 635 | match decode_raw_integer(buf, idx) { 636 | Ok(None) => Ok(None), 637 | Ok(Some((pos, -1))) => Ok(Some((pos, RespValue::Nil))), 638 | Ok(Some((pos, size))) if size >= 0 => { 639 | let size = size as usize; 640 | let remaining = buf.len() - pos; 641 | let required_bytes = size + 2; 642 | 643 | if remaining < required_bytes { 644 | return Ok(None); 645 | } 646 | 647 | let bulk_string = RespValue::BulkString(buf[pos..(pos + size)].to_vec()); 648 | Ok(Some((pos + required_bytes, bulk_string))) 649 | } 650 | Ok(Some((_, size))) => Err(parse_error(format!("Invalid string size: {}", size))), 651 | Err(e) => Err(e), 652 | } 653 | } 654 | 655 | fn decode_array(buf: &mut BytesMut, idx: usize) -> DecodeResult { 656 | match decode_raw_integer(buf, idx) { 657 | Ok(None) => Ok(None), 658 | Ok(Some((pos, -1))) => Ok(Some((pos, RespValue::Nil))), 659 | Ok(Some((pos, size))) if size >= 0 => { 660 | let size = size as usize; 661 | let mut pos = pos; 662 | let mut values = Vec::with_capacity(size); 663 | for _ in 0..size { 664 | match decode(buf, pos) { 665 | Ok(None) => return Ok(None), 666 | Ok(Some((new_pos, value))) => { 667 | values.push(value); 668 | pos = new_pos; 669 | } 670 | Err(e) => return Err(e), 671 | } 672 | 
} 673 | Ok(Some((pos, RespValue::Array(values)))) 674 | } 675 | Ok(Some((_, size))) => Err(parse_error(format!("Invalid array size: {}", size))), 676 | Err(e) => Err(e), 677 | } 678 | } 679 | 680 | fn decode_integer(buf: &mut BytesMut, idx: usize) -> DecodeResult { 681 | match decode_raw_integer(buf, idx) { 682 | Ok(None) => Ok(None), 683 | Ok(Some((pos, int))) => Ok(Some((pos, RespValue::Integer(int)))), 684 | Err(e) => Err(e), 685 | } 686 | } 687 | 688 | /// A simple string is any series of bytes that ends with `\r\n` 689 | #[allow(clippy::unnecessary_wraps)] 690 | fn decode_simple_string(buf: &mut BytesMut, idx: usize) -> DecodeResult { 691 | match scan_string(buf, idx) { 692 | None => Ok(None), 693 | Some((pos, string)) => Ok(Some((pos, RespValue::SimpleString(string)))), 694 | } 695 | } 696 | 697 | #[allow(clippy::unnecessary_wraps)] 698 | fn decode_error(buf: &mut BytesMut, idx: usize) -> DecodeResult { 699 | match scan_string(buf, idx) { 700 | None => Ok(None), 701 | Some((pos, string)) => Ok(Some((pos, RespValue::Error(string)))), 702 | } 703 | } 704 | 705 | fn decode(buf: &mut BytesMut, idx: usize) -> DecodeResult { 706 | let length = buf.len(); 707 | if length <= idx { 708 | return Ok(None); 709 | } 710 | 711 | let first_byte = buf[idx]; 712 | match first_byte { 713 | b'$' => decode_bulk_string(buf, idx + 1), 714 | b'*' => decode_array(buf, idx + 1), 715 | b':' => decode_integer(buf, idx + 1), 716 | b'+' => decode_simple_string(buf, idx + 1), 717 | b'-' => decode_error(buf, idx + 1), 718 | _ => Err(parse_error(format!("Unexpected byte: {}", first_byte))), 719 | } 720 | } 721 | 722 | impl Decoder for RespCodec { 723 | type Item = RespValue; 724 | type Error = Error; 725 | 726 | fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { 727 | match decode(buf, 0) { 728 | Ok(None) => Ok(None), 729 | Ok(Some((pos, item))) => { 730 | buf.advance(pos); 731 | Ok(Some(item)) 732 | } 733 | Err(e) => Err(e), 734 | } 735 | } 736 | } 737 | 738 | #[cfg(test)] 
739 | mod tests { 740 | use std::collections::HashMap; 741 | 742 | use bytes::BytesMut; 743 | 744 | use tokio_util::codec::{Decoder, Encoder}; 745 | 746 | use super::{Error, FromResp, RespCodec, RespValue}; 747 | 748 | fn obj_to_bytes(obj: RespValue) -> Vec { 749 | let mut bytes = BytesMut::new(); 750 | let mut codec = RespCodec; 751 | codec.encode(obj, &mut bytes).unwrap(); 752 | bytes.to_vec() 753 | } 754 | 755 | #[test] 756 | fn test_resp_array_macro() { 757 | let resp_object = resp_array!["SET", "x"]; 758 | let bytes = obj_to_bytes(resp_object); 759 | assert_eq!(b"*2\r\n$3\r\nSET\r\n$1\r\nx\r\n", bytes.as_slice()); 760 | 761 | let resp_object = resp_array!["RPUSH", "wyz"].append(vec!["a", "b"]); 762 | let bytes = obj_to_bytes(resp_object); 763 | assert_eq!( 764 | &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nwyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..], 765 | bytes.as_slice() 766 | ); 767 | 768 | let vals = vec![String::from("a"), String::from("b")]; 769 | #[allow(clippy::needless_borrow)] 770 | let resp_object = resp_array!["RPUSH", "xyz"].append(vals); 771 | let bytes = obj_to_bytes(resp_object); 772 | assert_eq!( 773 | &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nxyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..], 774 | bytes.as_slice() 775 | ); 776 | } 777 | 778 | #[test] 779 | fn test_bulk_string() { 780 | let resp_object = RespValue::BulkString(b"THISISATEST".to_vec()); 781 | let mut bytes = BytesMut::new(); 782 | let mut codec = RespCodec; 783 | codec.encode(resp_object.clone(), &mut bytes).unwrap(); 784 | assert_eq!(b"$11\r\nTHISISATEST\r\n".to_vec(), bytes.to_vec()); 785 | 786 | let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); 787 | assert_eq!(deserialized, resp_object); 788 | } 789 | 790 | #[test] 791 | fn test_array() { 792 | let resp_object = RespValue::Array(vec!["TEST1".into(), "TEST2".into()]); 793 | let mut bytes = BytesMut::new(); 794 | let mut codec = RespCodec; 795 | codec.encode(resp_object.clone(), &mut bytes).unwrap(); 796 | assert_eq!( 797 | 
b"*2\r\n$5\r\nTEST1\r\n$5\r\nTEST2\r\n".to_vec(), 798 | bytes.to_vec() 799 | ); 800 | 801 | let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); 802 | assert_eq!(deserialized, resp_object); 803 | } 804 | 805 | #[test] 806 | fn test_nil_string() { 807 | let mut bytes = BytesMut::new(); 808 | bytes.extend_from_slice(&b"$-1\r\n"[..]); 809 | 810 | let mut codec = RespCodec; 811 | let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); 812 | assert_eq!(deserialized, RespValue::Nil); 813 | } 814 | 815 | #[test] 816 | fn test_integer_overflow() { 817 | let resp_object = RespValue::Integer(i64::max_value()); 818 | let res = i32::from_resp(resp_object); 819 | assert!(res.is_err()); 820 | } 821 | 822 | #[test] 823 | fn test_integer_underflow() { 824 | let resp_object = RespValue::Integer(-2); 825 | let res = u64::from_resp(resp_object); 826 | assert!(res.is_err()); 827 | } 828 | 829 | #[test] 830 | fn test_integer_convesion() { 831 | let resp_object = RespValue::Integer(50); 832 | assert_eq!(u32::from_resp(resp_object).unwrap(), 50); 833 | } 834 | 835 | #[test] 836 | fn test_hashmap_conversion() { 837 | let mut expected = HashMap::new(); 838 | expected.insert("KEY1".to_string(), "VALUE1".to_string()); 839 | expected.insert("KEY2".to_string(), "VALUE2".to_string()); 840 | 841 | let resp_object = RespValue::Array(vec![ 842 | "KEY1".into(), 843 | "VALUE1".into(), 844 | "KEY2".into(), 845 | "VALUE2".into(), 846 | ]); 847 | assert_eq!( 848 | HashMap::::from_resp(resp_object).unwrap(), 849 | expected 850 | ); 851 | } 852 | 853 | #[test] 854 | fn test_hashmap_conversion_fails_with_odd_length_array() { 855 | let resp_object = RespValue::Array(vec![ 856 | "KEY1".into(), 857 | "VALUE1".into(), 858 | "KEY2".into(), 859 | "VALUE2".into(), 860 | "KEY3".into(), 861 | ]); 862 | let res = HashMap::::from_resp(resp_object); 863 | 864 | match res { 865 | Err(Error::Resp(_, _)) => {} 866 | _ => panic!("Should not be able to convert an odd number of elements to a hashmap"), 867 
| } 868 | } 869 | } 870 | --------------------------------------------------------------------------------