├── .github ├── dependabot.yml └── workflows │ └── rust.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── Cargo.toml ├── README.md ├── async-sqlx-session │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── counter │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── regenerate │ ├── Cargo.toml │ └── src │ │ └── main.rs └── signin │ ├── Cargo.toml │ └── src │ └── main.rs ├── rustfmt.toml └── src ├── extractors.rs ├── lib.rs └── session.rs /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "cargo" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | check: 14 | runs-on: ubuntu-latest 15 | 16 | strategy: 17 | matrix: 18 | pwd: 19 | - . 
20 | - examples 21 | 22 | steps: 23 | - uses: actions/checkout@master 24 | - uses: actions-rs/toolchain@v1 25 | with: 26 | toolchain: nightly 27 | override: true 28 | profile: minimal 29 | components: clippy, rustfmt 30 | - uses: Swatinem/rust-cache@v2 31 | with: 32 | key: ${{ matrix.pwd }} 33 | workspaces: ${{ matrix.pwd }} 34 | - name: clippy 35 | working-directory: ${{ matrix.pwd }} 36 | run: | 37 | cargo clippy --all --all-targets --all-features 38 | - name: rustfmt 39 | working-directory: ${{ matrix.pwd }} 40 | run: | 41 | cargo fmt --all -- --check 42 | 43 | check-docs: 44 | runs-on: ubuntu-latest 45 | 46 | steps: 47 | - uses: actions/checkout@master 48 | - uses: actions-rs/toolchain@v1 49 | with: 50 | toolchain: stable 51 | override: true 52 | profile: minimal 53 | - uses: Swatinem/rust-cache@v2 54 | - name: cargo doc 55 | env: 56 | RUSTDOCFLAGS: "-D rustdoc::broken-intra-doc-links" 57 | run: cargo doc --all-features --no-deps 58 | 59 | test: 60 | needs: check 61 | 62 | runs-on: ubuntu-latest 63 | 64 | steps: 65 | - uses: actions/checkout@master 66 | - uses: actions-rs/toolchain@v1 67 | with: 68 | toolchain: stable 69 | override: true 70 | profile: minimal 71 | - name: Install Tarpaulin 72 | uses: actions-rs/install@v0.1 73 | with: 74 | crate: cargo-tarpaulin 75 | version: 0.22.0 76 | use-tool-cache: true 77 | - uses: Swatinem/rust-cache@v2 78 | - name: Run tests 79 | uses: actions-rs/cargo@v1 80 | with: 81 | command: test 82 | args: --all --all-features --all-targets 83 | - name: Coverage 84 | run: cargo tarpaulin -o Lcov --output-dir ./coverage 85 | - name: Coveralls 86 | uses: coverallsapp/github-action@master 87 | with: 88 | github-token: ${{ secrets.GITHUB_TOKEN }} 89 | 90 | test-docs: 91 | needs: check 92 | 93 | runs-on: ubuntu-latest 94 | 95 | steps: 96 | - uses: actions/checkout@master 97 | - uses: actions-rs/toolchain@v1 98 | with: 99 | toolchain: stable 100 | override: true 101 | profile: minimal 102 | - uses: Swatinem/rust-cache@v2 103 | - name: 
Run doc tests 104 | uses: actions-rs/cargo@v1 105 | with: 106 | command: test 107 | args: --all-features --doc 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.rs.bk 2 | **/target 3 | **/Cargo.lock 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.5.0 2 | 3 | **BREAKING CHANGES**: 4 | 5 | - Resist session name fingerprinting [PR #36](https://github.com/maxcountryman/axum-sessions/pull/36) 6 | 7 | **OTHER CHANGES**: 8 | 9 | - Allow setting HttpOnly of cookie [PR #30](https://github.com/maxcountryman/axum-sessions/pull/30) 10 | 11 | # 0.4.1 12 | 13 | - Update axum to v0.6.0 14 | 15 | # 0.4.0 16 | 17 | - Avoid storing cookie when not required [PR #15](https://github.com/maxcountryman/axum-sessions/pull/15) 18 | 19 | # 0.3.2 20 | 21 | - Search every Cookie header for session cookie [PR #14](https://github.com/maxcountryman/axum-sessions/pull/14) 22 | 23 | # 0.3.1 24 | 25 | - Derive `Debug` for `WritableSession` and `ReadableSession` ensuring consistency with `Session` 26 | 27 | # 0.3.0 28 | 29 | - Session regeneration support [PR #6](https://github.com/maxcountryman/axum-sessions/pull/6) 30 | 31 | # 0.2.0 32 | 33 | - On session destroy, unset cookie on client [PR #4](https://github.com/maxcountryman/axum-sessions/pull/4) 34 | 35 | # 0.1.1 36 | 37 | - Handle multiple cookie values [PR #2](https://github.com/maxcountryman/axum-sessions/pull/2) 38 | 39 | # 0.1.0 40 | 41 | - Initial release :tada: 42 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "axum-sessions" 3 | version = "0.6.1" 4 | description = "🥠 Cookie-based sessions for Axum via 
async-session." 5 | edition = "2021" 6 | homepage = "https://github.com/maxcountryman/axum-sessions" 7 | license = "MIT" 8 | keywords = ["axum", "session", "sessions", "cookie", "async-session"] 9 | categories = ["asynchronous", "network-programming", "web-programming"] 10 | repository = "https://github.com/maxcountryman/axum-sessions" 11 | documentation = "https://docs.rs/axum-sessions" 12 | 13 | [dependencies] 14 | async-session = "3.0.0" 15 | futures = "0.3.21" 16 | http-body = "0.4.5" 17 | tower = "0.4.12" 18 | tracing = "0.1" 19 | 20 | [dependencies.axum] 21 | version = "0.6.0" 22 | features = ["headers"] 23 | 24 | [dependencies.axum-extra] 25 | version = "0.7.1" 26 | features = ["cookie-signed"] 27 | 28 | [dependencies.tokio] 29 | version = "1.20.1" 30 | default-features = false 31 | features = ["sync"] 32 | 33 | [dev-dependencies] 34 | http = "0.2.8" 35 | hyper = "0.14.19" 36 | serde = "1.0.147" 37 | 38 | [dev-dependencies.rand] 39 | version = "0.8.5" 40 | features = ["min_const_gen"] 41 | 42 | [dev-dependencies.tokio] 43 | version = "1.20.1" 44 | default-features = false 45 | features = ["macros", "rt-multi-thread"] 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Max Countryman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > [!IMPORTANT] 2 | > 3 | > # **Migration to `tower-sessions`** 4 | > 5 | > **Development of this crate has moved to [`tower-sessions`](https://github.com/maxcountryman/tower-sessions).** Please consider migrating. 6 | > 7 | > Numerous bugs and a significant design flaw with `axum-sessions` are addressed with `tower-sessions`. 8 | -------------------------------------------------------------------------------- /examples/Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["*"] 3 | exclude = ["target"] 4 | resolver = "2" 5 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains examples showing how to use `axum-sessions`. Each example is set up as its own crate so that its dependencies are clear. 4 | 5 | ## Running examples 6 | 7 | Examples are intended to be run from the command line. To do so, change to the `examples/` directory and then use Cargo: 8 | 9 | ```sh 10 | $ cargo run -p example-signin 11 | ``` 12 | 13 | Change `signin` to the name of the example you'd like to run.
14 | -------------------------------------------------------------------------------- /examples/async-sqlx-session/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example-async-sqlx-session" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | axum = "0.6.0" 9 | axum-sessions = { path = "../../" } 10 | 11 | [dependencies.async-sqlx-session] 12 | version = "0.4.0" 13 | default-features = false 14 | features = ["sqlite"] 15 | 16 | [dependencies.rand] 17 | version = "0.8.5" 18 | features = ["min_const_gen"] 19 | 20 | [dependencies.sqlx] 21 | version = "0.5.13" 22 | default-features = false 23 | features = ["runtime-tokio-rustls", "sqlite"] 24 | 25 | [dependencies.tokio] 26 | version = "1.0" 27 | features = ["full"] 28 | -------------------------------------------------------------------------------- /examples/async-sqlx-session/src/main.rs: -------------------------------------------------------------------------------- 1 | //! Run with 2 | //! 3 | //! ```not_rust 4 | //! cd examples && cargo run -p example-async-sqlx-session 5 | //! 
``` 6 | 7 | use async_sqlx_session::SqliteSessionStore; 8 | use axum::{routing::get, Router}; 9 | use axum_sessions::{ 10 | extractors::{ReadableSession, WritableSession}, 11 | SessionLayer, 12 | }; 13 | use rand::Rng; 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | let store = SqliteSessionStore::new("sqlite::memory:") 18 | .await 19 | .expect("Could not connect to SQLite."); 20 | store 21 | .migrate() 22 | .await 23 | .expect("Could not migrate session store."); 24 | let secret = rand::thread_rng().gen::<[u8; 128]>(); 25 | let session_layer = SessionLayer::new(store, &secret); 26 | 27 | async fn increment_count_handler(mut session: WritableSession) { 28 | let previous: usize = session.get("counter").unwrap_or_default(); 29 | session 30 | .insert("counter", previous + 1) 31 | .expect("Could not store counter."); 32 | } 33 | 34 | async fn handler(session: ReadableSession) -> String { 35 | format!( 36 | "Counter: {}", 37 | session.get::("counter").unwrap_or_default() 38 | ) 39 | } 40 | 41 | let app = Router::new() 42 | .route("/increment", get(increment_count_handler)) 43 | .route("/", get(handler)) 44 | .layer(session_layer); 45 | 46 | axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) 47 | .serve(app.into_make_service()) 48 | .await 49 | .unwrap(); 50 | } 51 | -------------------------------------------------------------------------------- /examples/counter/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example-counter" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | axum = "0.6.0" 9 | axum-sessions = { path = "../../" } 10 | 11 | [dependencies.rand] 12 | version = "0.8.5" 13 | features = ["min_const_gen"] 14 | 15 | [dependencies.tokio] 16 | version = "1.0" 17 | features = ["full"] 18 | -------------------------------------------------------------------------------- /examples/counter/src/main.rs: 
--------------------------------------------------------------------------------
//! Run with
//!
//! ```not_rust
//! cd examples && cargo run -p example-counter
//! ```

use axum::{response::IntoResponse, routing::get, Router};
use axum_sessions::{
    async_session::MemoryStore,
    extractors::{ReadableSession, WritableSession},
    SessionLayer,
};
use rand::Rng;

/// `GET /`: shows the stored count, treating a missing value as zero.
async fn display_handler(session: ReadableSession) -> impl IntoResponse {
    let current = session.get::<i32>("count").unwrap_or(0);
    format!(
        "Count is: {}; visit /inc to increment and /reset to reset",
        current
    )
}

/// `GET /inc`: bumps the stored count (the first visit stores 1) and echoes it.
async fn increment_handler(mut session: WritableSession) -> impl IntoResponse {
    let next = session.get::<i32>("count").map_or(1, |n| n + 1);
    session.insert("count", next).unwrap();
    format!("Count is: {}", next)
}

/// `GET /reset`: destroys the session.
async fn reset_handler(mut session: WritableSession) -> impl IntoResponse {
    session.destroy();
    "Count reset"
}

#[tokio::main]
async fn main() {
    let secret: [u8; 128] = rand::thread_rng().gen();
    // Served over plain HTTP, hence the cookie's `Secure` attribute is disabled.
    let session_layer = SessionLayer::new(MemoryStore::new(), &secret).with_secure(false);

    let app = Router::new()
        .route("/", get(display_handler))
        .route("/inc", get(increment_handler))
        .route("/reset", get(reset_handler))
        .layer(session_layer);

    let addr: std::net::SocketAddr = "0.0.0.0:3000".parse().unwrap();
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}
-------------------------------------------------------------------------------- /examples/regenerate/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example-regenerate" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | axum = "0.6.0" 9 |
axum-sessions = { path = "../../" } 10 | 11 | [dependencies.rand] 12 | version = "0.8.5" 13 | features = ["min_const_gen"] 14 | 15 | [dependencies.tokio] 16 | version = "1.0" 17 | features = ["full"] 18 | -------------------------------------------------------------------------------- /examples/regenerate/src/main.rs: -------------------------------------------------------------------------------- 1 | //! Run with 2 | //! 3 | //! ```not_rust 4 | //! cd examples && cargo run -p example-regenerate 5 | //! ``` 6 | 7 | use axum::{routing::get, Router}; 8 | use axum_sessions::{ 9 | async_session::MemoryStore, 10 | extractors::{ReadableSession, WritableSession}, 11 | SessionLayer, 12 | }; 13 | use rand::Rng; 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | let store = MemoryStore::new(); 18 | let secret = rand::thread_rng().gen::<[u8; 128]>(); 19 | let session_layer = SessionLayer::new(store, &secret); 20 | 21 | async fn regenerate_handler(mut session: WritableSession) { 22 | // NB: This DOES NOT update the store, meaning that both sessions will still be 23 | // found. 
24 | session.regenerate(); 25 | } 26 | 27 | async fn insert_handler(mut session: WritableSession) { 28 | session 29 | .insert("foo", 42) 30 | .expect("Could not store the answer."); 31 | } 32 | 33 | async fn handler(session: ReadableSession) -> String { 34 | session 35 | .get::("foo") 36 | .map(|answer| format!("{}", answer)) 37 | .unwrap_or_else(|| "Nothing in session yet; try /insert.".to_string()) 38 | } 39 | 40 | let app = Router::new() 41 | .route("/regenerate", get(regenerate_handler)) 42 | .route("/insert", get(insert_handler)) 43 | .route("/", get(handler)) 44 | .layer(session_layer); 45 | 46 | axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) 47 | .serve(app.into_make_service()) 48 | .await 49 | .unwrap(); 50 | } 51 | -------------------------------------------------------------------------------- /examples/signin/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example-signin" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | axum = "0.6.0" 9 | axum-sessions = { path = "../../" } 10 | 11 | [dependencies.rand] 12 | version = "0.8.5" 13 | features = ["min_const_gen"] 14 | 15 | [dependencies.tokio] 16 | version = "1.0" 17 | features = ["full"] 18 | -------------------------------------------------------------------------------- /examples/signin/src/main.rs: -------------------------------------------------------------------------------- 1 | //! Run with 2 | //! 3 | //! ```not_rust 4 | //! cd examples && cargo run -p example-signin 5 | //! 
``` 6 | 7 | use axum::{routing::get, Router}; 8 | use axum_sessions::{ 9 | async_session::MemoryStore, 10 | extractors::{ReadableSession, WritableSession}, 11 | SessionLayer, 12 | }; 13 | use rand::Rng; 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | let store = MemoryStore::new(); 18 | let secret = rand::thread_rng().gen::<[u8; 128]>(); 19 | let session_layer = SessionLayer::new(store, &secret); 20 | 21 | async fn signin_handler(mut session: WritableSession) { 22 | session 23 | .insert("signed_in", true) 24 | .expect("Could not sign in."); 25 | } 26 | 27 | async fn signout_handler(mut session: WritableSession) { 28 | session.destroy(); 29 | } 30 | 31 | async fn protected_handler(session: ReadableSession) -> &'static str { 32 | if session.get::("signed_in").unwrap_or(false) { 33 | "Shh, it's secret!" 34 | } else { 35 | "Nothing to see here." 36 | } 37 | } 38 | 39 | let app = Router::new() 40 | .route("/signin", get(signin_handler)) 41 | .route("/signout", get(signout_handler)) 42 | .route("/protected", get(protected_handler)) 43 | .layer(session_layer); 44 | 45 | axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) 46 | .serve(app.into_make_service()) 47 | .await 48 | .unwrap(); 49 | } 50 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | format_code_in_doc_comments = true 2 | format_strings = true 3 | imports_granularity = "Crate" 4 | group_imports = "StdExternalCrate" 5 | wrap_comments = true 6 | -------------------------------------------------------------------------------- /src/extractors.rs: -------------------------------------------------------------------------------- 1 | //! Extractors for sessions. 
2 | 3 | use std::ops::{Deref, DerefMut}; 4 | 5 | use axum::{async_trait, extract::FromRequestParts, http::request::Parts, Extension}; 6 | use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard}; 7 | 8 | use crate::SessionHandle; 9 | 10 | /// An extractor which provides a readable session. Sessions may have many 11 | /// readers. 12 | #[derive(Debug)] 13 | pub struct ReadableSession { 14 | session: OwnedRwLockReadGuard, 15 | } 16 | 17 | impl Deref for ReadableSession { 18 | type Target = OwnedRwLockReadGuard; 19 | 20 | fn deref(&self) -> &Self::Target { 21 | &self.session 22 | } 23 | } 24 | 25 | #[async_trait] 26 | impl FromRequestParts for ReadableSession 27 | where 28 | S: Send + Sync, 29 | { 30 | type Rejection = std::convert::Infallible; 31 | 32 | async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { 33 | let Extension(session_handle): Extension = 34 | Extension::from_request_parts(parts, state) 35 | .await 36 | .expect("Session extension missing. Is the session layer installed?"); 37 | let session = session_handle.read_owned().await; 38 | 39 | Ok(Self { session }) 40 | } 41 | } 42 | 43 | /// An extractor which provides a writable session. Sessions may have only one 44 | /// writer. 
45 | #[derive(Debug)] 46 | pub struct WritableSession { 47 | session: OwnedRwLockWriteGuard, 48 | } 49 | 50 | impl Deref for WritableSession { 51 | type Target = OwnedRwLockWriteGuard; 52 | 53 | fn deref(&self) -> &Self::Target { 54 | &self.session 55 | } 56 | } 57 | 58 | impl DerefMut for WritableSession { 59 | fn deref_mut(&mut self) -> &mut Self::Target { 60 | &mut self.session 61 | } 62 | } 63 | 64 | #[async_trait] 65 | impl FromRequestParts for WritableSession 66 | where 67 | S: Send + Sync, 68 | { 69 | type Rejection = std::convert::Infallible; 70 | 71 | async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { 72 | let Extension(session_handle): Extension = 73 | Extension::from_request_parts(parts, state) 74 | .await 75 | .expect("Session extension missing. Is the session layer installed?"); 76 | let session = session_handle.write_owned().await; 77 | 78 | Ok(Self { session }) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # **Migration to `tower-sessions`** 2 | //! 3 | //! **Development of this crate has moved to 4 | //! [`tower-sessions`](https://github.com/maxcountryman/tower-sessions).** Please consider 5 | //! migrating. 6 | //! 7 | //! Numerous bugs and a significant design flaw with `axum-sessions` are 8 | //! addressed with `tower-sessions`. 9 | 10 | #![deny(missing_docs)] 11 | 12 | pub mod extractors; 13 | mod session; 14 | 15 | pub use async_session; 16 | pub use axum_extra::extract::cookie::SameSite; 17 | 18 | pub use self::session::{PersistencePolicy, Session, SessionHandle, SessionLayer}; 19 | -------------------------------------------------------------------------------- /src/session.rs: -------------------------------------------------------------------------------- 1 | // Much of this code is lifted directly from 2 | // `tide::sessions::middleware::SessionMiddleware`. 
See: https://github.com/http-rs/tide/blob/20fe435a9544c10f64245e883847fc3cd1d50538/src/sessions/middleware.rs 3 | 4 | use std::{ 5 | sync::Arc, 6 | task::{Context, Poll}, 7 | time::Duration, 8 | }; 9 | 10 | use async_session::{ 11 | base64, 12 | hmac::{Hmac, Mac, NewMac}, 13 | sha2::Sha256, 14 | SessionStore, 15 | }; 16 | use axum::{ 17 | http::{ 18 | header::{HeaderValue, COOKIE, SET_COOKIE}, 19 | Request, StatusCode, 20 | }, 21 | response::Response, 22 | }; 23 | use axum_extra::extract::cookie::{Cookie, Key, SameSite}; 24 | use futures::future::BoxFuture; 25 | use tokio::sync::RwLock; 26 | use tower::{Layer, Service}; 27 | 28 | const BASE64_DIGEST_LEN: usize = 44; 29 | 30 | /// A type alias which provides a handle to the underlying session. 31 | /// 32 | /// This is provided via [`http::Extensions`](axum::http::Extensions). Most 33 | /// applications will use the 34 | /// [`ReadableSession`](crate::extractors::ReadableSession) and 35 | /// [`WritableSession`](crate::extractors::WritableSession) extractors rather 36 | /// than using the handle directly. A notable exception is when using this 37 | /// library as a generic Tower middleware: such use cases will consume the 38 | /// handle directly. 39 | pub type SessionHandle = Arc>; 40 | 41 | /// Controls how the session data is persisted and created. 42 | #[derive(Clone)] 43 | pub enum PersistencePolicy { 44 | /// Always ping the storage layer and store empty "guest" sessions. 45 | Always, 46 | /// Do not store empty "guest" sessions, only ping the storage layer if 47 | /// the session data changed. 48 | ChangedOnly, 49 | /// Do not store empty "guest" sessions, always ping the storage layer for 50 | /// existing sessions. 51 | ExistingOnly, 52 | } 53 | 54 | /// Layer that provides cookie-based sessions. 
55 | #[derive(Clone)] 56 | pub struct SessionLayer { 57 | store: Store, 58 | cookie_path: String, 59 | cookie_name: String, 60 | cookie_domain: Option, 61 | persistence_policy: PersistencePolicy, 62 | session_ttl: Option, 63 | same_site_policy: SameSite, 64 | http_only: bool, 65 | secure: bool, 66 | key: Key, 67 | } 68 | 69 | impl SessionLayer { 70 | /// Creates a layer which will attach a [`SessionHandle`] to requests via an 71 | /// extension. This session is derived from a cryptographically signed 72 | /// cookie. When the client sends a valid, known cookie then the session is 73 | /// hydrated from this. Otherwise a new cookie is created and returned in 74 | /// the response. 75 | /// 76 | /// The default behaviour is to enable "guest" sessions with 77 | /// [`PersistencePolicy::Always`]. 78 | /// 79 | /// # Panics 80 | /// 81 | /// `SessionLayer::new` will panic if the secret is less than 64 bytes. 82 | /// 83 | /// # Customization 84 | /// 85 | /// The configuration of the session may be adjusted according to the needs 86 | /// of your application: 87 | /// 88 | /// ```rust 89 | /// # use axum_sessions::{PersistencePolicy, SessionLayer, async_session::MemoryStore, SameSite}; 90 | /// # use std::time::Duration; 91 | /// SessionLayer::new( 92 | /// MemoryStore::new(), 93 | /// b"please do not hardcode your secret; instead use a 94 | /// cryptographically secure value", 95 | /// ) 96 | /// .with_cookie_name("your.cookie.name") 97 | /// .with_cookie_path("/some/path") 98 | /// .with_cookie_domain("www.example.com") 99 | /// .with_same_site_policy(SameSite::Lax) 100 | /// .with_session_ttl(Some(Duration::from_secs(60 * 5))) 101 | /// .with_persistence_policy(PersistencePolicy::Always) 102 | /// .with_http_only(true) 103 | /// .with_secure(true); 104 | /// ``` 105 | #[deprecated( 106 | since = "0.6.0", 107 | note = "Development of axum-sessions has moved to the tower-sessions crate. Please \ 108 | consider migrating." 
109 | )] 110 | pub fn new(store: Store, secret: &[u8]) -> Self { 111 | if secret.len() < 64 { 112 | panic!("`secret` must be at least 64 bytes.") 113 | } 114 | 115 | Self { 116 | store, 117 | persistence_policy: PersistencePolicy::Always, 118 | cookie_path: "/".into(), 119 | cookie_name: "sid".into(), 120 | cookie_domain: None, 121 | same_site_policy: SameSite::Strict, 122 | session_ttl: Some(Duration::from_secs(24 * 60 * 60)), 123 | http_only: true, 124 | secure: true, 125 | key: Key::from(secret), 126 | } 127 | } 128 | 129 | /// When `true`, a session cookie will always be set. When `false` the 130 | /// session data must be modified in order for it to be set. Defaults to 131 | /// `true`. 132 | pub fn with_persistence_policy(mut self, policy: PersistencePolicy) -> Self { 133 | self.persistence_policy = policy; 134 | self 135 | } 136 | 137 | /// Sets a cookie for the session. Defaults to `"/"`. 138 | pub fn with_cookie_path(mut self, cookie_path: impl AsRef) -> Self { 139 | self.cookie_path = cookie_path.as_ref().to_owned(); 140 | self 141 | } 142 | 143 | /// Sets a cookie name for the session. Defaults to `"sid"`. 144 | pub fn with_cookie_name(mut self, cookie_name: impl AsRef) -> Self { 145 | self.cookie_name = cookie_name.as_ref().to_owned(); 146 | self 147 | } 148 | 149 | /// Sets a cookie domain for the session. Defaults to `None`. 150 | pub fn with_cookie_domain(mut self, cookie_domain: impl AsRef) -> Self { 151 | self.cookie_domain = Some(cookie_domain.as_ref().to_owned()); 152 | self 153 | } 154 | 155 | /// Decide if session is presented to the storage layer 156 | fn should_store(&self, cookie_value: &Option, session_data_changed: bool) -> bool { 157 | session_data_changed 158 | || matches!(self.persistence_policy, PersistencePolicy::Always) 159 | || (matches!(self.persistence_policy, PersistencePolicy::ExistingOnly) 160 | && cookie_value.is_some()) 161 | } 162 | 163 | /// Sets a cookie same site policy for the session. 
Defaults to 164 | /// `SameSite::Strict`. 165 | pub fn with_same_site_policy(mut self, policy: SameSite) -> Self { 166 | self.same_site_policy = policy; 167 | self 168 | } 169 | 170 | /// Sets a cookie time-to-live (ttl) for the session. Defaults to 171 | /// `Duration::from_secs(60 * 60 * 24)`; one day. 172 | pub fn with_session_ttl(mut self, session_ttl: Option) -> Self { 173 | self.session_ttl = session_ttl; 174 | self 175 | } 176 | 177 | /// Sets a cookie `HttpOnly` attribute for the session. Defaults to `true`. 178 | pub fn with_http_only(mut self, http_only: bool) -> Self { 179 | self.http_only = http_only; 180 | self 181 | } 182 | 183 | /// Sets a cookie secure attribute for the session. Defaults to `true`. 184 | pub fn with_secure(mut self, secure: bool) -> Self { 185 | self.secure = secure; 186 | self 187 | } 188 | 189 | async fn load_or_create(&self, cookie_value: Option) -> SessionHandle { 190 | let session = match cookie_value { 191 | Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(), 192 | None => None, 193 | }; 194 | 195 | Arc::new(RwLock::new( 196 | session 197 | .and_then(async_session::Session::validate) 198 | .unwrap_or_default(), 199 | )) 200 | } 201 | 202 | fn build_cookie(&self, cookie_value: String) -> Cookie<'static> { 203 | let mut cookie = Cookie::build(self.cookie_name.clone(), cookie_value) 204 | .http_only(self.http_only) 205 | .same_site(self.same_site_policy) 206 | .secure(self.secure) 207 | .path(self.cookie_path.clone()) 208 | .finish(); 209 | 210 | if let Some(ttl) = self.session_ttl { 211 | cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into())); 212 | } 213 | 214 | if let Some(cookie_domain) = self.cookie_domain.clone() { 215 | cookie.set_domain(cookie_domain) 216 | } 217 | 218 | self.sign_cookie(&mut cookie); 219 | 220 | cookie 221 | } 222 | 223 | fn build_removal_cookie(&self) -> Cookie<'static> { 224 | let cookie = Cookie::build(self.cookie_name.clone(), "") 225 | .http_only(true) 
.path(self.cookie_path.clone());

        let mut cookie = if let Some(cookie_domain) = self.cookie_domain.clone() {
            cookie.domain(cookie_domain)
        } else {
            cookie
        }
        .finish();

        cookie.make_removal();

        self.sign_cookie(&mut cookie);

        cookie
    }

    // the following is reused verbatim from
    // https://github.com/SergioBenitez/cookie-rs/blob/master/src/secure/signed.rs#L33-L43
    /// Signs the cookie's value providing integrity and authenticity.
    ///
    /// The new value is `base64(HMAC-SHA256(value)) ++ value`, i.e. the digest is
    /// prepended to the original value.
    fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
        // Compute HMAC-SHA256 of the cookie's value.
        let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
        mac.update(cookie.value().as_bytes());

        // Cookie's new value is [MAC | original-value].
        let mut new_value = base64::encode(mac.finalize().into_bytes());
        new_value.push_str(cookie.value());
        cookie.set_value(new_value);
    }

    // Adapted from
    // https://github.com/SergioBenitez/cookie-rs/blob/master/src/secure/signed.rs#L45-L63
    // with a char-boundary guard so malformed client input cannot panic.
    /// Given a signed value `str` where the signature is prepended to `value`,
    /// verifies the signed value and returns it. If there's a problem, returns
    /// an `Err` with a string describing the issue.
    fn verify_signature(&self, cookie_value: &str) -> Result<String, &'static str> {
        // `split_at` panics unless its index is on a UTF-8 char boundary. The
        // caller passes values produced by `Cookie::parse_encoded`, which
        // percent-decodes them, so a client can smuggle multi-byte characters in.
        // `is_char_boundary` is also false when the value is shorter than the
        // digest, so this single check replaces the old length check.
        if !cookie_value.is_char_boundary(BASE64_DIGEST_LEN) {
            return Err("missing or invalid digest");
        }

        // Split [MAC | original-value] into its two parts.
        let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
        let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?;

        // Perform the verification.
        let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
        mac.update(value.as_bytes());
        mac.verify(&digest)
            .map(|_| value.to_string())
            .map_err(|_| "value did not verify")
    }
}

impl<Inner, Store: SessionStore> Layer<Inner> for SessionLayer<Store> {
    type Service = Session<Inner, Store>;

    /// Wraps `inner` in a `Session` service carrying a clone of this layer's
    /// configuration.
    fn layer(&self, inner: Inner) -> Self::Service {
        Session {
            inner,
            layer: self.clone(),
        }
    }
}

/// Session service container.
///
/// Holds the wrapped inner service together with the `SessionLayer`
/// configuration used to load, verify and persist sessions.
#[derive(Clone)]
pub struct Session<Inner, Store: SessionStore> {
    inner: Inner,
    layer: SessionLayer<Store>,
}

impl<Inner, ReqBody, ResBody, Store: SessionStore> Service<Request<ReqBody>>
    for Session<Inner, Store>
where
    Inner: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    ResBody: Send + 'static,
    ReqBody: Send + 'static,
    Inner::Future: Send + 'static,
{
    type Response = Inner::Response;
    type Error = Inner::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Loads (or creates) the session for this request, exposes it to the inner
    /// service through the request extensions, then destroys or stores it and
    /// emits the matching `Set-Cookie` header.
    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        let session_layer = self.layer.clone();

        // Multiple cookies may be all concatenated into a single Cookie header
        // separated with semicolons (HTTP/1.1 behaviour) or into multiple separate
        // Cookie headers (HTTP/2 behaviour). Search for the session cookie from
        // all Cookie headers, assuming both forms are possible
        let cookie_value = request
            .headers()
            .get_all(COOKIE)
            .iter()
            .filter_map(|cookie_header| cookie_header.to_str().ok())
            .flat_map(|cookie_header| cookie_header.split(';'))
            .filter_map(|cookie_header| Cookie::parse_encoded(cookie_header.trim()).ok())
            .filter(|cookie| cookie.name() == session_layer.cookie_name)
            .find_map(|cookie| session_layer.verify_signature(cookie.value()).ok());

        // Keep a fresh clone in `self` and move the instance that `poll_ready`
        // drove ready into the future, so that instance handles this request.
        let inner = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let session_handle = session_layer.load_or_create(cookie_value.clone()).await;

            let mut session = session_handle.write().await;
            if let Some(ttl) = session_layer.session_ttl {
                session.expire_in(ttl);
            }
            drop(session);

            request.extensions_mut().insert(session_handle.clone());
            let mut response = inner.call(request).await?;

            let session = session_handle.read().await;
            let (session_is_destroyed, session_data_changed) =
                (session.is_destroyed(), session.data_changed());
            drop(session);

            // Pull out the session so we can pass it to the store without `Clone` blowing
            // away the `cookie_value`.
            let session = RwLock::into_inner(
                Arc::try_unwrap(session_handle).expect("Session handle still has owners."),
            );
            if session_is_destroyed {
                if let Err(e) = session_layer.store.destroy_session(session).await {
                    tracing::error!("Failed to destroy session: {:?}", e);
                    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                }

                let removal_cookie = session_layer.build_removal_cookie();
                append_set_cookie(&mut response, &removal_cookie);
            } else if session_layer.should_store(&cookie_value, session_data_changed) {
                // Store if
                // - We have guest sessions
                // - We received a valid cookie and we use the `ExistingOnly`
                //   policy.
                // - If we use the `ChangedOnly` policy, only
                //   `session.data_changed()` should trigger this branch.
                match session_layer.store.store_session(session).await {
                    Ok(Some(cookie_value)) => {
                        let cookie = session_layer.build_cookie(cookie_value);
                        append_set_cookie(&mut response, &cookie);
                    }

                    Ok(None) => {}

                    Err(e) => {
                        tracing::error!("Failed to reach session storage: {:?}", e);
                        *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                    }
                }
            }

            Ok(response)
        })
    }
}

/// Appends `cookie` to `response` as a `Set-Cookie` header.
///
/// If the serialized cookie is not a valid header value (e.g. a configured path
/// or domain containing bytes forbidden in headers), the failure is logged and
/// the response becomes `500 Internal Server Error` instead of panicking inside
/// the response future. This mirrors how session-store failures are reported.
fn append_set_cookie<ResBody>(response: &mut Response<ResBody>, cookie: &Cookie<'_>) {
    match HeaderValue::from_str(&cookie.to_string()) {
        Ok(value) => {
            response.headers_mut().append(SET_COOKIE, value);
        }
        Err(e) => {
            tracing::error!("Failed to encode session cookie as a header value: {:?}", e);
            *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
        }
    }
}

#[cfg(test)]
mod tests {
    use async_session::{
        serde::{Deserialize, Serialize},
        serde_json,
    };
    use axum::http::{Request, Response};
    use http::{
        header::{COOKIE, SET_COOKIE},
        HeaderValue, StatusCode,
    };
    use hyper::Body;
    use rand::Rng;
    use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

    use super::PersistencePolicy;
    use crate::{async_session::MemoryStore, SessionHandle, SessionLayer};

#[derive(Deserialize, Serialize, PartialEq, Debug)]
struct Counter {
    counter: i32,
}

/// Whether a `Set-Cookie` header is expected on a given response.
enum ExpectedResult {
    Some,
    None,
}

/// Returns a fresh random signing secret.
///
/// The secret is 64 bytes long; `SessionLayer::new` panics on an empty secret
/// (see `too_short_secret`).
fn random_secret() -> [u8; 64] {
    rand::thread_rng().gen::<[u8; 64]>()
}

/// Reads the whole response body and deserializes the `Counter` JSON that
/// `increment` writes into it.
async fn read_counter(res: Response<Body>) -> Counter {
    let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap();
    serde_json::from_slice(&bytes).unwrap()
}

#[tokio::test]
async fn sets_session_cookie() {
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer = SessionLayer::new(store, &secret);
    let mut service = ServiceBuilder::new().layer(session_layer).service_fn(echo);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    assert!(res
        .headers()
        .get(SET_COOKIE)
        .unwrap()
        .to_str()
        .unwrap()
        .starts_with("sid="))
}

#[tokio::test]
async fn uses_valid_session() {
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer = SessionLayer::new(store, &secret);
    let mut service = ServiceBuilder::new()
        .layer(session_layer)
        .service_fn(increment);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(read_counter(res).await, Counter { counter: 0 });

    let mut request = Request::get("/").body(Body::empty()).unwrap();
    request.headers_mut().insert(COOKIE, session_cookie);
    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(read_counter(res).await, Counter { counter: 1 });
}

#[tokio::test]
async fn multiple_cookies_in_single_header() {
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer = SessionLayer::new(store, &secret);
    let mut service = ServiceBuilder::new()
        .layer(session_layer)
        .service_fn(increment);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

    // build a Cookie header that contains two cookies: an unrelated dummy cookie,
    // and the given session cookie
    let request_cookie =
        HeaderValue::from_str(&format!("key=value; {}", session_cookie.to_str().unwrap()))
            .unwrap();

    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(read_counter(res).await, Counter { counter: 0 });

    let mut request = Request::get("/").body(Body::empty()).unwrap();
    request.headers_mut().insert(COOKIE, request_cookie);
    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(read_counter(res).await, Counter { counter: 1 });
}

#[tokio::test]
async fn multiple_cookie_headers() {
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer = SessionLayer::new(store, &secret);
    let mut service = ServiceBuilder::new()
        .layer(session_layer)
        .service_fn(increment);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();
    let dummy_cookie = HeaderValue::from_str("key=value").unwrap();

    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(read_counter(res).await, Counter { counter: 0 });

    let mut request = Request::get("/").body(Body::empty()).unwrap();
    request.headers_mut().append(COOKIE, dummy_cookie);
    request.headers_mut().append(COOKIE, session_cookie);
    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(read_counter(res).await, Counter { counter: 1 });
}

#[tokio::test]
async fn no_cookie_stored_when_no_session_is_required() {
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer = SessionLayer::new(store, &secret)
        .with_persistence_policy(PersistencePolicy::ChangedOnly);
    let mut service = ServiceBuilder::new().layer(session_layer).service_fn(echo);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    assert!(res.headers().get(SET_COOKIE).is_none());
}

/// Sends one request without a cookie, then one with an unverifiable `sid`
/// cookie (optionally mutating the session). Asserts whether each response
/// carries `Set-Cookie` under `persistence_policy`.
async fn invalid_session_check_cookie_result(
    persistence_policy: PersistencePolicy,
    change_data: bool,
    expect_cookie_header: (ExpectedResult, ExpectedResult),
) {
    let (expect_cookie_header_first, expect_cookie_header_second) = expect_cookie_header;
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer =
        SessionLayer::new(store, &secret).with_persistence_policy(persistence_policy);
    let mut service = ServiceBuilder::new()
        .layer(&session_layer)
        .service_fn(echo_read_session);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    match expect_cookie_header_first {
        ExpectedResult::Some => assert!(
            res.headers().get(SET_COOKIE).is_some(),
            "Set-Cookie must be present for first response"
        ),
        ExpectedResult::None => assert!(
            res.headers().get(SET_COOKIE).is_none(),
            "Set-Cookie must not be present for first response"
        ),
    }

    let mut service =
        ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(move |req| async move {
                if change_data {
                    echo_with_session_change(req).await
                } else {
                    echo_read_session(req).await
                }
            });
    let mut request = Request::get("/").body(Body::empty()).unwrap();
    request
        .headers_mut()
        .insert(COOKIE, "sid=aW52YWxpZC1zZXNzaW9uLWlk".parse().unwrap());
    let res = service.ready().await.unwrap().call(request).await.unwrap();
    match expect_cookie_header_second {
        ExpectedResult::Some => assert!(
            res.headers().get(SET_COOKIE).is_some(),
            "Set-Cookie must be present for second response"
        ),
        ExpectedResult::None => assert!(
            res.headers().get(SET_COOKIE).is_none(),
            "Set-Cookie must not be present for second response"
        ),
    }
}

#[tokio::test]
async fn invalid_session_always_sets_guest_cookie() {
    invalid_session_check_cookie_result(
        PersistencePolicy::Always,
        false,
        (ExpectedResult::Some, ExpectedResult::Some),
    )
    .await;
}

#[tokio::test]
async fn invalid_session_sets_new_session_cookie_when_data_changes() {
    invalid_session_check_cookie_result(
        PersistencePolicy::ExistingOnly,
        true,
        (ExpectedResult::None, ExpectedResult::Some),
    )
    .await;
}

#[tokio::test]
async fn invalid_session_sets_no_cookie_when_no_data_changes() {
    invalid_session_check_cookie_result(
        PersistencePolicy::ExistingOnly,
        false,
        (ExpectedResult::None, ExpectedResult::None),
    )
    .await;
}

#[tokio::test]
async fn invalid_session_changedonly_sets_cookie_when_changed() {
    invalid_session_check_cookie_result(
        PersistencePolicy::ChangedOnly,
        true,
        (ExpectedResult::None, ExpectedResult::Some),
    )
    .await;
}

#[tokio::test]
async fn destroyed_sessions_sets_removal_cookie() {
    let secret = random_secret();
    let store = MemoryStore::new();
    let session_layer = SessionLayer::new(store, &secret);
    let mut service = ServiceBuilder::new()
        .layer(session_layer)
        .service_fn(destroy);

    let request = Request::get("/").body(Body::empty()).unwrap();

    let res = service.ready().await.unwrap().call(request).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);

    let session_cookie = res
        .headers()
        .get(SET_COOKIE)
        .unwrap()
        .to_str()
        .unwrap()
        .to_string();
    let mut request = Request::get("/destroy").body(Body::empty()).unwrap();
    request
        .headers_mut()
        .insert(COOKIE, session_cookie.parse().unwrap());
    let res = service.ready().await.unwrap().call(request).await.unwrap();

    // Check the properties that make this a removal cookie (same name, expires
    // immediately) rather than its exact serialized length.
    let removal_cookie = res.headers().get(SET_COOKIE).unwrap().to_str().unwrap();
    assert!(
        removal_cookie.starts_with("sid="),
        "removal cookie must target the session cookie: {}",
        removal_cookie
    );
    assert!(
        removal_cookie.contains("Max-Age=0"),
        "removal cookie must expire immediately: {}",
        removal_cookie
    );
}

#[test]
#[should_panic]
fn too_short_secret() {
    let store = MemoryStore::new();
    SessionLayer::new(store, b"");
}

/// Returns the request body unchanged without touching the session.
async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
    Ok(Response::new(req.into_body()))
}

/// Reads (but never modifies) the session, then echoes the body.
async fn echo_read_session(req: Request<Body>) -> Result<Response<Body>, BoxError> {
    {
        let session_handle = req.extensions().get::<SessionHandle>().unwrap();
        // Only reading, so a shared lock is sufficient.
        let session = session_handle.read().await;
        let _ = session.get::<bool>("signed_in").unwrap_or_default();
    }
    Ok(Response::new(req.into_body()))
}

/// Writes `signed_in = true` into the session, then echoes the body.
async fn echo_with_session_change(req: Request<Body>) -> Result<Response<Body>, BoxError> {
    {
        let session_handle = req.extensions().get::<SessionHandle>().unwrap();
        let mut session = session_handle.write().await;
        session.insert("signed_in", true).unwrap();
    }
    Ok(Response::new(req.into_body()))
}

async fn destroy(req: Request<Body>) -> Result<Response<Body>, BoxError> {
    // Destroy the session only when the request carried a Cookie header.
    if req.headers().get(COOKIE).is_some() {
        let session_handle = req.extensions().get::<SessionHandle>().unwrap();
        let mut session = session_handle.write().await;
        session.destroy();
    }

    Ok(Response::new(req.into_body()))
}

/// Increments the session's `counter` (starting at 0 when absent) and
/// responds with the new value as `Counter` JSON.
async fn increment(mut req: Request<Body>) -> Result<Response<Body>, BoxError> {
    let counter = {
        let session_handle = req.extensions().get::<SessionHandle>().unwrap();
        let mut session = session_handle.write().await;
        let counter = session
            .get::<i32>("counter")
            .map_or(0, |count| count + 1);
        session.insert("counter", counter).unwrap();
        counter
    };

    let body = serde_json::to_string(&Counter { counter }).unwrap();
    *req.body_mut() = Body::from(body);

    Ok(Response::new(req.into_body()))
}
}