├── .gitignore ├── src ├── mpsc │ ├── mod.rs │ ├── semaphore.rs │ ├── unbounded.rs │ ├── bounded.rs │ ├── block.rs │ └── chan.rs ├── lib.rs ├── wake_list.rs ├── once_cell.rs ├── linked_list.rs ├── oneshot.rs ├── broadcast.rs └── semaphore.rs ├── README.md ├── Cargo.toml ├── LICENSE-MIT ├── LICENSE-THIRD-PARTY └── LICENSE-APACHE /.gitignore: -------------------------------------------------------------------------------- 1 | # Rust 2 | /target 3 | Cargo.lock 4 | 5 | # IDE 6 | .vscode 7 | .idea 8 | -------------------------------------------------------------------------------- /src/mpsc/mod.rs: -------------------------------------------------------------------------------- 1 | mod block; 2 | mod chan; 3 | mod semaphore; 4 | 5 | pub mod bounded; 6 | pub mod unbounded; 7 | 8 | pub use chan::{SendError, TryRecvError}; 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # local-sync 2 | Local-sync is a crate providing data structures for synchronization within a single thread. 3 | 4 | ## mpsc 5 | Mpsc includes bounded and unbounded channels. 6 | 7 | ## Once Cell 8 | A once cell, similar to `sync.Once` in Go. 9 | 10 | ## Oneshot 11 | A oneshot channel that can send and receive a value only once. It can also be used as a notification mechanism. 12 | 13 | ## Semaphore 14 | With Semaphore you can asynchronously wait for permits and add permits. 15 | 16 | ## Licenses 17 | Local-sync is licensed under the MIT license or the Apache license. 18 | 19 | During development we referenced Tokio extensively. We would like to thank the authors of that project. 20 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Local Sync is a crate that provides non-thread-safe data structures useful 2 | //! for async programming. 3 | //!
If you use a runtime with a thread-per-core model (for example, Monoio), you 4 | //! may use this crate to avoid the cost of communicating across threads. 5 | 6 | // shared basic data structure 7 | mod linked_list; 8 | mod wake_list; 9 | 10 | // Semaphore 11 | pub mod semaphore; 12 | // BoundedChannel and UnboundedChannel 13 | pub mod mpsc; 14 | 15 | // OneshotChannel 16 | pub mod oneshot; 17 | 18 | // OnceCell 19 | mod once_cell; 20 | pub use once_cell::{OnceCell, SetError}; 21 | 22 | // Broadcast 23 | pub mod broadcast; 24 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "local-sync" 3 | version = "0.1.1" 4 | description = "Non-threadsafe data structure for async usage" 5 | authors = ["ihciah "] 6 | repository = "https://github.com/monoio-rs/local-sync" 7 | keywords = ["sync", "channel", "local", "futures", "local-channel"] 8 | license = "MIT OR Apache-2.0" 9 | edition = "2021" 10 | 11 | [dependencies] 12 | futures-core = { version = "0.3", default-features = false } 13 | futures-sink = { version = "0.3", default-features = false } 14 | futures-util = { version = "0.3", default-features = false } 15 | 16 | [dev-dependencies] 17 | monoio = { version = "0.1.0", features = ["macros"] } 18 | -------------------------------------------------------------------------------- /src/mpsc/semaphore.rs: -------------------------------------------------------------------------------- 1 | use std::cell::Cell; 2 | 3 | /// Capacity and close tracking shared by the mpsc channels (`Inner` backs bounded channels, `Unlimited` backs unbounded ones). 4 | pub trait Semaphore { 5 | fn add_permits(&self, n: usize); 6 | fn close(&self); 7 | fn is_closed(&self) -> bool; 8 | } 9 | 10 | impl Semaphore for crate::semaphore::Inner { 11 | fn add_permits(&self, n: usize) { 12 | self.release(n); 13 | } 14 | 15 | fn close(&self) { 16 | crate::semaphore::Inner::close(self); 17 | } 18 | 19 | fn is_closed(&self) -> bool { 20 | crate::semaphore::Inner::is_closed(self) 21 | }
22 | } 23 | 24 | /// Semaphore for unbounded channels: permits are never limited, only the closed flag is tracked (a safe `Cell` suffices because the channels are single-threaded and `Rc`-based). 25 | pub struct Unlimited { 26 | closed: Cell<bool>, 27 | } 28 | 29 | impl Unlimited { 30 | pub fn new() -> Self { 31 | Self { 32 | closed: Cell::new(false), 33 | } 34 | } 35 | } 36 | 37 | impl Semaphore for Unlimited { 38 | fn add_permits(&self, _: usize) {} 39 | 40 | fn close(&self) { 41 | self.closed.set(true); 42 | } 43 | 44 | fn is_closed(&self) -> bool { 45 | self.closed.get() 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 local-sync and Monoio Contributors 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE.
26 | -------------------------------------------------------------------------------- /LICENSE-THIRD-PARTY: -------------------------------------------------------------------------------- 1 | Third party project code used by this project: Tokio. 2 | 3 | =============================================================================== 4 | 5 | Tokio 6 | https://github.com/tokio-rs/tokio/blob/master/LICENSE 7 | 8 | Copyright (c) 2021 Tokio Contributors 9 | 10 | Permission is hereby granted, free of charge, to any 11 | person obtaining a copy of this software and associated 12 | documentation files (the "Software"), to deal in the 13 | Software without restriction, including without 14 | limitation the rights to use, copy, modify, merge, 15 | publish, distribute, sublicense, and/or sell copies of 16 | the Software, and to permit persons to whom the Software 17 | is furnished to do so, subject to the following 18 | conditions: 19 | 20 | The above copyright notice and this permission notice 21 | shall be included in all copies or substantial portions 22 | of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 25 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 26 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 27 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 28 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 29 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 30 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 31 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 32 | DEALINGS IN THE SOFTWARE. 33 | -------------------------------------------------------------------------------- /src/wake_list.rs: -------------------------------------------------------------------------------- 1 | //! Wake list borrowed from tokio. 
2 | 3 | use core::mem::MaybeUninit; 4 | use core::ptr; 5 | use std::task::Waker; 6 | 7 | const NUM_WAKERS: usize = 32; 8 | 9 | pub(crate) struct WakeList { 10 | inner: [MaybeUninit; NUM_WAKERS], 11 | curr: usize, 12 | } 13 | 14 | impl WakeList { 15 | pub(crate) fn new() -> Self { 16 | Self { 17 | inner: unsafe { 18 | // safety: Create an uninitialized array of `MaybeUninit`. The 19 | // `assume_init` is safe because the type we are claiming to 20 | // have initialized here is a bunch of `MaybeUninit`s, which do 21 | // not require initialization. 22 | MaybeUninit::uninit().assume_init() 23 | }, 24 | curr: 0, 25 | } 26 | } 27 | 28 | #[inline] 29 | pub(crate) fn can_push(&self) -> bool { 30 | self.curr < NUM_WAKERS 31 | } 32 | 33 | pub(crate) fn push(&mut self, val: Waker) { 34 | debug_assert!(self.can_push()); 35 | 36 | self.inner[self.curr] = MaybeUninit::new(val); 37 | self.curr += 1; 38 | } 39 | 40 | pub(crate) fn wake_all(&mut self) { 41 | assert!(self.curr <= NUM_WAKERS); 42 | while self.curr > 0 { 43 | self.curr -= 1; 44 | let waker = unsafe { ptr::read(self.inner[self.curr].as_mut_ptr()) }; 45 | waker.wake(); 46 | } 47 | } 48 | } 49 | 50 | impl Drop for WakeList { 51 | fn drop(&mut self) { 52 | let slice = ptr::slice_from_raw_parts_mut(self.inner.as_mut_ptr() as *mut Waker, self.curr); 53 | unsafe { ptr::drop_in_place(slice) }; 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/mpsc/unbounded.rs: -------------------------------------------------------------------------------- 1 | use super::{ 2 | chan::{self, SendError, TryRecvError}, 3 | semaphore::Unlimited, 4 | }; 5 | use futures_core::Stream; 6 | use futures_util::future::poll_fn; 7 | use std::pin::Pin; 8 | use std::task::{Context, Poll}; 9 | 10 | pub struct Tx(chan::Tx); 11 | 12 | pub struct Rx(chan::Rx); 13 | 14 | pub fn channel() -> (Tx, Rx) { 15 | let semaphore = Unlimited::new(); 16 | let (tx, rx) = chan::channel(semaphore); 17 | (Tx(tx), Rx(rx))
18 | } 19 | 20 | impl Tx { 21 | pub fn send(&self, value: T) -> Result<(), SendError> { 22 | self.0.send(value) 23 | } 24 | 25 | pub fn is_closed(&self) -> bool { 26 | self.0.is_closed() 27 | } 28 | 29 | pub fn same_channel(&self, other: &Self) -> bool { 30 | self.0.same_channel(&other.0) 31 | } 32 | } 33 | 34 | impl Clone for Tx { 35 | fn clone(&self) -> Self { 36 | Self(self.0.clone()) 37 | } 38 | } 39 | 40 | impl Rx { 41 | pub async fn recv(&mut self) -> Option { 42 | poll_fn(|cx| self.poll_recv(cx)).await 43 | } 44 | 45 | pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { 46 | self.0.recv(cx) 47 | } 48 | 49 | pub fn try_recv(&mut self) -> Result { 50 | self.0.try_recv() 51 | } 52 | 53 | pub fn close(&mut self) { 54 | self.0.close() 55 | } 56 | } 57 | 58 | impl Stream for Rx { 59 | type Item = T; 60 | 61 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 62 | self.poll_recv(cx) 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod tests { 68 | use super::channel; 69 | 70 | #[monoio::test] 71 | async fn test_unbounded_channel() { 72 | let (tx, mut rx) = channel(); 73 | tx.send(1).unwrap(); 74 | assert_eq!(rx.recv().await.unwrap(), 1); 75 | 76 | drop(tx); 77 | assert_eq!(rx.recv().await, None); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/mpsc/bounded.rs: -------------------------------------------------------------------------------- 1 | use super::chan::{self, SendError, TryRecvError}; 2 | use crate::semaphore::Inner; 3 | use futures_core::Stream; 4 | use futures_util::future::poll_fn; 5 | use std::pin::Pin; 6 | use std::task::{Context, Poll}; 7 | 8 | pub struct Tx(chan::Tx); 9 | 10 | pub struct Rx(chan::Rx); 11 | 12 | pub fn channel(buffer: usize) -> (Tx, Rx) { 13 | let semaphore = Inner::new(buffer); 14 | let (tx, rx) = chan::channel(semaphore); 15 | (Tx(tx), Rx(rx)) 16 | } 17 | 18 | impl Tx { 19 | pub async fn send(&self, value: T) -> Result<(), SendError> { 20 | // acquire
semaphore first 21 | self.0 22 | .chan 23 | .semaphore 24 | .acquire(1) 25 | .await 26 | .map_err(|_| SendError::RxClosed)?; 27 | self.0.send(value) 28 | } 29 | 30 | pub fn is_closed(&self) -> bool { 31 | self.0.is_closed() 32 | } 33 | 34 | pub fn same_channel(&self, other: &Self) -> bool { 35 | self.0.same_channel(&other.0) 36 | } 37 | } 38 | 39 | impl Clone for Tx { 40 | fn clone(&self) -> Self { 41 | Self(self.0.clone()) 42 | } 43 | } 44 | 45 | impl Rx { 46 | pub async fn recv(&mut self) -> Option { 47 | poll_fn(|cx| self.poll_recv(cx)).await 48 | } 49 | 50 | pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { 51 | self.0.recv(cx) 52 | } 53 | 54 | pub fn try_recv(&mut self) -> Result { 55 | self.0.try_recv() 56 | } 57 | 58 | pub fn close(&mut self) { 59 | self.0.close() 60 | } 61 | } 62 | 63 | impl Stream for Rx { 64 | type Item = T; 65 | 66 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 67 | self.poll_recv(cx) 68 | } 69 | } 70 | 71 | #[cfg(test)] 72 | mod tests { 73 | use super::channel; 74 | 75 | #[monoio::test] 76 | async fn test_bounded_channel() { 77 | let (tx, mut rx) = channel(1); 78 | tx.send(1).await.unwrap(); 79 | assert_eq!(rx.recv().await.unwrap(), 1); 80 | 81 | drop(tx); 82 | assert_eq!(rx.recv().await, None); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/mpsc/block.rs: -------------------------------------------------------------------------------- 1 | use std::cell::UnsafeCell; 2 | use std::mem::MaybeUninit; 3 | use std::ptr::NonNull; 4 | 5 | const BLOCK_CAP: usize = 32; 6 | pub(crate) struct Block { 7 | /// The next block in the linked list. 8 | next: Option>>, 9 | 10 | /// Array containing values pushed into the block. 11 | values: UnsafeCell<[MaybeUninit; BLOCK_CAP]>, 12 | 13 | /// Head index. 14 | begin: usize, 15 | 16 | /// Tail index.
17 | end: usize, 18 | } 19 | 20 | impl Block { 21 | pub(crate) fn new() -> Self { 22 | Self { 23 | next: None, 24 | values: UnsafeCell::new(unsafe { MaybeUninit::uninit().assume_init() }), 25 | begin: 0, 26 | end: 0, 27 | } 28 | } 29 | 30 | #[allow(unused)] 31 | pub(crate) fn len(&self) -> usize { 32 | self.end - self.begin 33 | } 34 | 35 | #[allow(unused)] 36 | pub(crate) fn is_empty(&self) -> bool { 37 | self.end == self.begin 38 | } 39 | 40 | pub(crate) fn can_write(&self) -> bool { 41 | self.end < BLOCK_CAP 42 | } 43 | 44 | pub(crate) unsafe fn reset(&mut self) { 45 | self.next = None; 46 | self.begin = 0; 47 | self.end = 0; 48 | } 49 | } 50 | 51 | pub(crate) struct Queue { 52 | /// The block to read data from. 53 | head: NonNull>, 54 | /// The block to write data to. It must be a valid block that has space. 55 | tail: NonNull>, 56 | /// Data length 57 | len: usize, 58 | } 59 | 60 | impl Queue { 61 | pub(crate) fn new() -> Self { 62 | let block = Box::new(Block::new()); 63 | let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(block)) }; 64 | Self { 65 | head: ptr, 66 | tail: ptr, 67 | len: 0, 68 | } 69 | } 70 | 71 | pub(crate) fn len(&self) -> usize { 72 | self.len 73 | } 74 | 75 | pub(crate) fn is_empty(&self) -> bool { 76 | self.len == 0 77 | } 78 | 79 | /// Push data into queue. 80 | /// # Safety: Make sure the current capacity is allowed.
81 | pub(crate) unsafe fn push_unchecked(&mut self, value: T) { 82 | // Write data and update block end index 83 | let blk = self.tail.as_mut(); 84 | let offset = blk.end; 85 | blk.end += 1; 86 | (*blk.values.get())[offset] = MaybeUninit::new(value); 87 | 88 | // Update queue length and make sure tail point to a valid block(not full) 89 | self.len += 1; 90 | if !blk.can_write() { 91 | if let Some(ptr) = blk.next { 92 | // just move the tail ptr 93 | self.tail = ptr; 94 | } else { 95 | // alloc a new block 96 | let block = Box::new(Block::new()); 97 | let ptr = NonNull::new_unchecked(Box::into_raw(block)); 98 | blk.next = Some(ptr); 99 | // move the tail ptr 100 | self.tail = ptr; 101 | } 102 | } 103 | } 104 | 105 | /// Pop data out. 106 | /// # Safety: Make sure there is still some data inside. 107 | pub(crate) unsafe fn pop_unchecked(&mut self) -> T { 108 | // Read data and update block read index 109 | let blk = self.head.as_mut(); 110 | debug_assert!(!blk.is_empty(), "head block is empty while pop_unchecked"); 111 | let offset = blk.begin; 112 | blk.begin += 1; 113 | let value = std::mem::replace(&mut (*blk.values.get())[offset], MaybeUninit::uninit()); 114 | 115 | // Update queue length and try to recycle the head block if it's empty. 116 | self.len -= 1; 117 | if blk.begin == BLOCK_CAP { 118 | // Update head of queue. 119 | self.head = blk.next.expect("no next block while pop_unchecked"); 120 | // Move block to the tail and reset it. 121 | let tail = self.tail.as_mut(); 122 | let free_blocks = tail.next; 123 | blk.reset(); 124 | blk.next = free_blocks; 125 | tail.next = Some(NonNull::new_unchecked(blk)); 126 | } 127 | value.assume_init() 128 | } 129 | 130 | /// Free all blocks. 131 | /// # Safety: Free blocks and drop. Must make sure you drop all elements first.
132 | pub(crate) unsafe fn free_blocks(&mut self) { 133 | debug_assert_ne!(self.head, NonNull::dangling()); 134 | let mut cur = Some(self.head); 135 | 136 | #[cfg(debug_assertions)] 137 | { 138 | // to trigger the debug assert above so as to catch that we 139 | // don't call `free_blocks` more than once. 140 | self.head = NonNull::dangling(); 141 | } 142 | 143 | while let Some(block) = cur { 144 | cur = block.as_ref().next; 145 | drop(Box::from_raw(block.as_ptr())); 146 | } 147 | } 148 | } 149 | 150 | #[cfg(test)] 151 | mod tests { 152 | use super::Queue; 153 | 154 | #[test] 155 | fn test_simple_push_pop() { 156 | let mut queue = Queue::new(); 157 | unsafe { 158 | queue.push_unchecked(1); 159 | queue.push_unchecked(2); 160 | queue.push_unchecked(3); 161 | assert_eq!(queue.len(), 3); 162 | assert_eq!(queue.pop_unchecked(), 1); 163 | assert_eq!(queue.pop_unchecked(), 2); 164 | assert_eq!(queue.pop_unchecked(), 3); 165 | assert_eq!(queue.len(), 0); 166 | } 167 | } 168 | 169 | #[test] 170 | fn test_across_block_push_pop() { 171 | let mut queue = Queue::new(); 172 | unsafe { 173 | for _ in 0..4 { 174 | for idx in 0..1024 { 175 | queue.push_unchecked(idx); 176 | assert_eq!(queue.len(), idx + 1); 177 | } 178 | for idx in 0..1024 { 179 | assert_eq!(queue.pop_unchecked(), idx); 180 | assert_eq!(queue.len(), 1023 - idx); 181 | } 182 | } 183 | } 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /src/mpsc/chan.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | cell::{Cell, RefCell}, 3 | error::Error, 4 | fmt, 5 | rc::Rc, 6 | task::{Context, Poll, Waker}, 7 | }; 8 | 9 | use super::{block::Queue, semaphore::Semaphore}; 10 | 11 | pub(crate) fn channel(semaphore: S) -> (Tx, Rx) 12 | where 13 | S: Semaphore, 14 | { 15 | let chan = Rc::new(Chan::new(semaphore)); 16 | let tx = Tx::new(chan.clone()); 17 | let rx = Rx::new(chan); 18 | (tx, rx) 19 | } 20 | 21 | pub(crate) struct
Chan { 22 | queue: RefCell>, 23 | pub(crate) semaphore: S, 24 | rx_waker: RefCell>, 25 | tx_count: Cell, 26 | } 27 | 28 | /// Error returned by `try_recv`. 29 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] 30 | pub enum TryRecvError { 31 | /// This **channel** is currently empty, but the **Sender**(s) have not yet 32 | /// disconnected, so data may yet become available. 33 | Empty, 34 | /// The **channel** is empty and closed (every **Sender** dropped, or the receiver 35 | /// closed it), so there will never be any more data received on it. 36 | Disconnected, 37 | } 38 | 39 | impl fmt::Display for TryRecvError { 40 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 41 | match *self { 42 | TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), 43 | TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), 44 | } 45 | } 46 | } 47 | 48 | impl Error for TryRecvError {} 49 | 50 | impl Chan 51 | where 52 | S: Semaphore, 53 | { 54 | pub(crate) fn new(semaphore: S) -> Self { 55 | let queue = RefCell::new(Queue::new()); 56 | Self { 57 | queue, 58 | semaphore, 59 | rx_waker: RefCell::new(None), 60 | tx_count: Cell::new(0), 61 | } 62 | } 63 | } 64 | 65 | impl Drop for Chan 66 | where 67 | S: Semaphore, 68 | { 69 | fn drop(&mut self) { 70 | // consume all elements: 71 | // we cleared all elements on Rx drop, but there may still be some 72 | // values sent after permits added.
73 | let mut queue = self.queue.borrow_mut(); 74 | while !queue.is_empty() { 75 | drop(unsafe { queue.pop_unchecked() }); 76 | } 77 | // drop all blocks of queue 78 | unsafe { queue.free_blocks() } 79 | } 80 | } 81 | 82 | pub(crate) struct Tx 83 | where 84 | S: Semaphore, 85 | { 86 | pub(crate) chan: Rc>, 87 | } 88 | 89 | #[derive(PartialEq, Eq, Clone, Copy, Debug)] 90 | pub enum SendError { 91 | RxClosed, 92 | } 93 | 94 | pub(crate) struct Rx 95 | where 96 | S: Semaphore, 97 | { 98 | chan: Rc>, 99 | } 100 | 101 | impl Tx 102 | where 103 | S: Semaphore, 104 | { 105 | pub(crate) fn new(chan: Rc>) -> Self { 106 | chan.tx_count.set(chan.tx_count.get() + 1); 107 | Self { chan } 108 | } 109 | 110 | // caller must make sure the chan has capacity (bounded senders acquire a permit first) 111 | pub(crate) fn send(&self, value: T) -> Result<(), SendError> { 112 | // check if the semaphore is closed 113 | if self.chan.semaphore.is_closed() { 114 | return Err(SendError::RxClosed); 115 | } 116 | 117 | // put data into the queue 118 | unsafe { 119 | self.chan.queue.borrow_mut().push_unchecked(value); 120 | } 121 | // if rx waker is set, wake it 122 | if let Some(w) = self.chan.rx_waker.replace(None) { 123 | w.wake(); 124 | } 125 | Ok(()) 126 | } 127 | 128 | pub fn is_closed(&self) -> bool { 129 | self.chan.semaphore.is_closed() 130 | } 131 | 132 | /// Returns `true` if senders belong to the same channel.
133 | pub(crate) fn same_channel(&self, other: &Self) -> bool { 134 | Rc::ptr_eq(&self.chan, &other.chan) 135 | } 136 | } 137 | 138 | impl Clone for Tx 139 | where 140 | S: Semaphore, 141 | { 142 | fn clone(&self) -> Self { 143 | self.chan.tx_count.set(self.chan.tx_count.get() + 1); 144 | Self { 145 | chan: self.chan.clone(), 146 | } 147 | } 148 | } 149 | 150 | impl Drop for Tx 151 | where 152 | S: Semaphore, 153 | { 154 | fn drop(&mut self) { 155 | let cnt = self.chan.tx_count.get(); 156 | self.chan.tx_count.set(cnt - 1); 157 | 158 | if cnt == 1 { 159 | self.chan.semaphore.close(); 160 | if let Some(rx_waker) = self.chan.rx_waker.take() { 161 | rx_waker.wake(); 162 | } 163 | } 164 | } 165 | } 166 | 167 | impl Rx 168 | where 169 | S: Semaphore, 170 | { 171 | pub(crate) fn new(chan: Rc>) -> Self { 172 | Self { chan } 173 | } 174 | 175 | pub(crate) fn try_recv(&mut self) -> Result { 176 | let mut queue = self.chan.queue.borrow_mut(); 177 | if !queue.is_empty() { 178 | let val = unsafe { queue.pop_unchecked() }; 179 | self.chan.semaphore.add_permits(1); 180 | return Ok(val); 181 | } 182 | if self.chan.tx_count.get() == 0 || self.chan.semaphore.is_closed() { // drained and closed (senders gone or rx closed): send always fails now 183 | Err(TryRecvError::Disconnected) 184 | } else { 185 | Err(TryRecvError::Empty) 186 | } 187 | } 188 | 189 | pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll> { 190 | let mut queue = self.chan.queue.borrow_mut(); 191 | if !queue.is_empty() { 192 | let val = unsafe { queue.pop_unchecked() }; 193 | self.chan.semaphore.add_permits(1); 194 | return Poll::Ready(Some(val)); 195 | } 196 | if self.chan.tx_count.get() == 0 || self.chan.semaphore.is_closed() { // drained and closed: nothing can arrive, so pending here would never be woken 197 | return Poll::Ready(None); 198 | } 199 | let mut borrowed = self.chan.rx_waker.borrow_mut(); 200 | match borrowed.as_mut() { 201 | Some(inner) => { 202 | if !inner.will_wake(cx.waker()) { 203 | *inner = cx.waker().clone(); 204 | } 205 | } 206 | None => { 207 | *borrowed = Some(cx.waker().clone()); 208 | } 209 | } 210 | Poll::Pending 211 | } 212 | 213 | pub(crate) fn close(&mut self) { 214 | self.chan.semaphore.close();
215 | } 216 | } 217 | 218 | impl Drop for Rx 219 | where 220 | S: Semaphore, 221 | { 222 | fn drop(&mut self) { 223 | // close semaphore on close, this will make tx send await return. 224 | self.chan.semaphore.close(); 225 | // consume all elements 226 | let mut queue = self.chan.queue.borrow_mut(); 227 | let len = queue.len(); 228 | while !queue.is_empty() { 229 | drop(unsafe { queue.pop_unchecked() }); 230 | } 231 | self.chan.semaphore.add_permits(len); 232 | } 233 | } 234 | 235 | #[cfg(test)] 236 | mod tests { 237 | use super::{channel, TryRecvError}; 238 | use crate::semaphore::Inner; 239 | use futures_util::future::poll_fn; 240 | 241 | #[monoio::test] 242 | async fn test_chan() { 243 | let semaphore = Inner::new(1); 244 | let (tx, mut rx) = channel::(semaphore); 245 | assert!(tx.send(1).is_ok()); 246 | assert_eq!(poll_fn(|cx| rx.recv(cx)).await, Some(1)); 247 | 248 | // close rx: once drained, the receiver must report the channel as closed even with tx alive 249 | rx.close(); 250 | assert!(tx.is_closed()); 251 | assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected)); 252 | assert_eq!(poll_fn(|cx| rx.recv(cx)).await, None); 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 local-sync and Monoio Contributors 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity.
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /src/once_cell.rs: -------------------------------------------------------------------------------- 1 | //! Once Cell impl borrowed from tokio. 2 | 3 | use std::cell::{RefCell, UnsafeCell}; 4 | use std::error::Error; 5 | use std::fmt; 6 | use std::future::Future; 7 | use std::mem::MaybeUninit; 8 | use std::ops::Drop; 9 | use std::ptr; 10 | 11 | use crate::semaphore::{Semaphore, SemaphorePermit, TryAcquireError}; 12 | 13 | // This file contains an implementation of an OnceCell. The principle 14 | // behind the safety of the cell is that any task with an `&OnceCell` may 15 | // access the `value` field according to the following rules: 16 | // 17 | // 1. When `value_set` is false, the `value` field may be modified by the 18 | // task holding the permit on the semaphore. 19 | // 2. When `value_set` is true, the `value` field may be accessed immutably by 20 | // any task. 21 | // 22 | // It is an invariant that if the semaphore is closed, then `value_set` is true.
23 | // The reverse does not necessarily hold — but if not, the semaphore may not 24 | // have any available permits. 25 | // 26 | // Code holding a `&mut OnceCell` may modify the value in any way it wants as 27 | // long as the invariants are upheld. 28 | 29 | /// A single-threaded cell that can be written to only once. 30 | /// 31 | /// A `OnceCell` is typically used for global variables that need to be 32 | /// initialized once on first use, but need no further changes. This `OnceCell` 33 | /// allows the initialization procedure to be asynchronous. 34 | /// 35 | /// # Examples 36 | /// 37 | /// ``` 38 | /// use local_sync::OnceCell; 39 | /// 40 | /// async fn some_computation() -> u32 { 41 | /// 1 + 1 42 | /// } 43 | /// 44 | /// thread_local! { 45 | /// static ONCE: OnceCell = OnceCell::new(); 46 | /// } 47 | /// 48 | /// #[monoio::main] 49 | /// async fn main() { 50 | /// let once = ONCE.with(|once| unsafe { 51 | /// std::ptr::NonNull::new_unchecked(once as *const _ as *mut OnceCell).as_ref() 52 | /// }); 53 | /// let result = once.get_or_init(some_computation).await; 54 | /// assert_eq!(*result, 2); 55 | /// } 56 | /// ``` 57 | /// 58 | /// It is often useful to write a wrapper method for accessing the value. 59 | /// 60 | /// ``` 61 | /// use local_sync::OnceCell; 62 | /// 63 | /// thread_local!
{ 64 | /// static ONCE: OnceCell = OnceCell::new(); 65 | /// } 66 | /// 67 | /// async fn get_global_integer() -> &'static u32 { 68 | /// let once = ONCE.with(|once| unsafe { 69 | /// std::ptr::NonNull::new_unchecked(once as *const _ as *mut OnceCell).as_ref() 70 | /// }); 71 | /// once.get_or_init(|| async { 72 | /// 1 + 1 73 | /// }).await 74 | /// } 75 | /// 76 | /// #[monoio::main] 77 | /// async fn main() { 78 | /// let result = get_global_integer().await; 79 | /// assert_eq!(*result, 2); 80 | /// } 81 | /// ``` 82 | pub struct OnceCell { 83 | value_set: RefCell, 84 | value: UnsafeCell>, 85 | semaphore: Semaphore, 86 | } 87 | 88 | impl Default for OnceCell { 89 | fn default() -> OnceCell { 90 | OnceCell::new() 91 | } 92 | } 93 | 94 | impl fmt::Debug for OnceCell { 95 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 96 | fmt.debug_struct("OnceCell") 97 | .field("value", &self.get()) 98 | .finish() 99 | } 100 | } 101 | 102 | impl Clone for OnceCell { 103 | fn clone(&self) -> OnceCell { 104 | OnceCell::new_with(self.get().cloned()) 105 | } 106 | } 107 | 108 | impl PartialEq for OnceCell { 109 | fn eq(&self, other: &OnceCell) -> bool { 110 | self.get() == other.get() 111 | } 112 | } 113 | 114 | impl Eq for OnceCell {} 115 | 116 | impl Drop for OnceCell { 117 | fn drop(&mut self) { 118 | if self.initialized_mut() { 119 | unsafe { 120 | let ptr = self.value.get(); 121 | ptr::drop_in_place((&mut *ptr).as_mut_ptr()); 122 | }; 123 | } 124 | } 125 | } 126 | 127 | impl From for OnceCell { 128 | fn from(value: T) -> Self { 129 | let semaphore = Semaphore::new(0); 130 | semaphore.close(); 131 | OnceCell { 132 | value_set: RefCell::new(true), 133 | value: UnsafeCell::new(MaybeUninit::new(value)), 134 | semaphore, 135 | } 136 | } 137 | } 138 | 139 | impl OnceCell { 140 | /// Creates a new empty `OnceCell` instance. 
141 | pub const fn new() -> Self { 142 | OnceCell { 143 | value_set: RefCell::new(false), 144 | value: UnsafeCell::new(MaybeUninit::uninit()), 145 | semaphore: Semaphore::new(1), 146 | } 147 | } 148 | 149 | /// Creates a new `OnceCell` that contains the provided value, if any. 150 | /// 151 | /// If the `Option` is `None`, this is equivalent to [`OnceCell::new`]. 152 | /// 153 | /// [`OnceCell::new`]: crate::OnceCell::new 154 | pub fn new_with(value: Option) -> Self { 155 | if let Some(v) = value { 156 | OnceCell::from(v) 157 | } else { 158 | OnceCell::new() 159 | } 160 | } 161 | 162 | /// Returns `true` if the `OnceCell` currently contains a value, and `false` 163 | /// otherwise. 164 | pub fn initialized(&self) -> bool { 165 | // `value_set` is only `true` once a value has been written into `value` 166 | // (see `set_value` and `From`), so a `true` here means it is readable. 167 | *self.value_set.borrow() 168 | } 169 | 170 | /// Returns `true` if the `OnceCell` currently contains a value, and `false` 171 | /// otherwise. 172 | fn initialized_mut(&mut self) -> bool { 173 | *self.value_set.get_mut() 174 | } 175 | 176 | // SAFETY: The OnceCell must not be empty. 177 | unsafe fn get_unchecked(&self) -> &T { 178 | let ptr = self.value.get(); 179 | &*(*ptr).as_ptr() 180 | } 181 | 182 | // SAFETY: The OnceCell must not be empty. 183 | unsafe fn get_unchecked_mut(&mut self) -> &mut T { 184 | let ptr = self.value.get(); 185 | &mut *(*ptr).as_mut_ptr() 186 | } 187 | 188 | fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { 189 | // SAFETY: We are holding the only permit on the semaphore. 190 | unsafe { 191 | let ptr = self.value.get(); 192 | (*ptr).as_mut_ptr().write(value); 193 | } 194 | 195 | // Flip the flag only after the value has been written, so anyone who 196 | // observes `value_set == true` reads a fully initialized value. 197 | *self.value_set.borrow_mut() = true; 198 | self.semaphore.close(); 199 | permit.forget(); 200 | 201 | // SAFETY: We just initialized the cell.
202 | unsafe { self.get_unchecked() } 203 | } 204 | 205 | /// Returns a reference to the value currently stored in the `OnceCell`, or 206 | /// `None` if the `OnceCell` is empty. 207 | pub fn get(&self) -> Option<&T> { 208 | if self.initialized() { 209 | Some(unsafe { self.get_unchecked() }) 210 | } else { 211 | None 212 | } 213 | } 214 | 215 | /// Returns a mutable reference to the value currently stored in the 216 | /// `OnceCell`, or `None` if the `OnceCell` is empty. 217 | /// 218 | /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the 219 | /// value inside the `OnceCell` — the mutable borrow statically guarantees 220 | /// no other references exist. 221 | pub fn get_mut(&mut self) -> Option<&mut T> { 222 | if self.initialized_mut() { 223 | Some(unsafe { self.get_unchecked_mut() }) 224 | } else { 225 | None 226 | } 227 | } 228 | 229 | /// Set the value of the `OnceCell` to the given value if the `OnceCell` is 230 | /// empty. 231 | /// 232 | /// If the `OnceCell` already has a value, this call will fail with an 233 | /// [`SetError::AlreadyInitializedError`]. 234 | /// 235 | /// If the `OnceCell` is empty, but some other task is currently trying to 236 | /// set the value, this call will fail with [`SetError::InitializingError`]. 237 | /// 238 | /// [`SetError::AlreadyInitializedError`]: crate::SetError::AlreadyInitializedError 239 | /// [`SetError::InitializingError`]: crate::SetError::InitializingError 240 | pub fn set(&self, value: T) -> Result<(), SetError> { 241 | if self.initialized() { 242 | return Err(SetError::AlreadyInitializedError(value)); 243 | } 244 | 245 | // Another task might be initializing the cell, in which case 246 | // `try_acquire` will return an error. If we manage to acquire the 247 | // permit, then we can set the value.
248 | match self.semaphore.try_acquire() { 249 | Ok(permit) => { 250 | debug_assert!(!self.initialized()); 251 | self.set_value(value, permit); 252 | Ok(()) 253 | } 254 | Err(TryAcquireError::NoPermits) => { 255 | // Some other task is holding the permit. That task is 256 | // currently trying to initialize the value. 257 | Err(SetError::InitializingError(value)) 258 | } 259 | Err(TryAcquireError::Closed) => { 260 | // The semaphore was closed. Some other task has initialized 261 | // the value. 262 | Err(SetError::AlreadyInitializedError(value)) 263 | } 264 | } 265 | } 266 | 267 | /// Get the value currently in the `OnceCell`, or initialize it with the 268 | /// given asynchronous operation. 269 | /// 270 | /// If some other task is currently working on initializing the `OnceCell`, 271 | /// this call will wait for that other task to finish, then return the value 272 | /// that the other task produced. 273 | /// 274 | /// If the provided operation is cancelled or panics, the initialization 275 | /// attempt is cancelled. If there are other tasks waiting for the value to 276 | /// be initialized, one of them will start another attempt at initializing 277 | /// the value. 278 | /// 279 | /// This will deadlock if `f` tries to initialize the cell recursively. 280 | pub async fn get_or_init(&self, f: F) -> &T 281 | where 282 | F: FnOnce() -> Fut, 283 | Fut: Future, 284 | { 285 | if self.initialized() { 286 | // SAFETY: The OnceCell has been fully initialized. 287 | unsafe { self.get_unchecked() } 288 | } else { 289 | // Here we try to acquire the semaphore permit. Holding the permit 290 | // will allow us to set the value of the OnceCell, and prevents 291 | // other tasks from initializing the OnceCell while we are holding 292 | // it. 
293 | match self.semaphore.acquire().await { 294 | Ok(permit) => { 295 | debug_assert!(!self.initialized()); 296 | 297 | // If `f()` panics or `select!` is called, this 298 | // `get_or_init` call is aborted and the semaphore permit is 299 | // dropped. 300 | let value = f().await; 301 | 302 | self.set_value(value, permit) 303 | } 304 | Err(_) => { 305 | debug_assert!(self.initialized()); 306 | 307 | // SAFETY: The semaphore has been closed. This only happens 308 | // when the OnceCell is fully initialized. 309 | unsafe { self.get_unchecked() } 310 | } 311 | } 312 | } 313 | } 314 | 315 | /// Get the value currently in the `OnceCell`, or initialize it with the 316 | /// given asynchronous operation. 317 | /// 318 | /// If some other task is currently working on initializing the `OnceCell`, 319 | /// this call will wait for that other task to finish, then return the value 320 | /// that the other task produced. 321 | /// 322 | /// If the provided operation returns an error, is cancelled or panics, the 323 | /// initialization attempt is cancelled. If there are other tasks waiting 324 | /// for the value to be initialized, one of them will start another attempt 325 | /// at initializing the value. 326 | /// 327 | /// This will deadlock if `f` tries to initialize the cell recursively. 328 | pub async fn get_or_try_init(&self, f: F) -> Result<&T, E> 329 | where 330 | F: FnOnce() -> Fut, 331 | Fut: Future>, 332 | { 333 | if self.initialized() { 334 | // SAFETY: The OnceCell has been fully initialized. 335 | unsafe { Ok(self.get_unchecked()) } 336 | } else { 337 | // Here we try to acquire the semaphore permit. Holding the permit 338 | // will allow us to set the value of the OnceCell, and prevents 339 | // other tasks from initializing the OnceCell while we are holding 340 | // it. 
341 | match self.semaphore.acquire().await { 342 | Ok(permit) => { 343 | debug_assert!(!self.initialized()); 344 | 345 | // If `f()` panics or this future is dropped (e.g. by `select!`), 346 | // this `get_or_try_init` call is aborted and the semaphore 347 | // permit is dropped; the same happens when `f()` returns `Err`. 348 | let value = f().await; 349 | 350 | match value { 351 | Ok(value) => Ok(self.set_value(value, permit)), 352 | Err(e) => Err(e), 353 | } 354 | } 355 | Err(_) => { 356 | debug_assert!(self.initialized()); 357 | 358 | // SAFETY: The semaphore has been closed. This only happens 359 | // when the OnceCell is fully initialized. 360 | unsafe { Ok(self.get_unchecked()) } 361 | } 362 | } 363 | } 364 | } 365 | 366 | /// Take the value from the cell, destroying the cell in the process. 367 | /// Returns `None` if the cell is empty. 368 | pub fn into_inner(mut self) -> Option { 369 | if self.initialized_mut() { 370 | // Set to uninitialized for the destructor of `OnceCell` to work properly 371 | *self.value_set.get_mut() = false; 372 | 373 | let ptr = self.value.get(); 374 | Some(unsafe { ptr::read(ptr).assume_init() }) 375 | } else { 376 | None 377 | } 378 | } 379 | 380 | /// Takes ownership of the current value, leaving the cell empty. Returns 381 | /// `None` if the cell is empty. 382 | pub fn take(&mut self) -> Option { 383 | std::mem::take(self).into_inner() 384 | } 385 | } 386 | 387 | /// Errors that can be returned from [`OnceCell::set`]. 388 | /// 389 | /// [`OnceCell::set`]: crate::OnceCell::set 390 | #[derive(Debug, PartialEq)] 391 | pub enum SetError { 392 | /// The cell was already initialized when [`OnceCell::set`] was called. 393 | /// 394 | /// [`OnceCell::set`]: crate::OnceCell::set 395 | AlreadyInitializedError(T), 396 | 397 | /// The cell is currently being initialized.
398 | InitializingError(T), 399 | } 400 | 401 | impl fmt::Display for SetError { 402 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 403 | match self { 404 | SetError::AlreadyInitializedError(_) => write!(f, "AlreadyInitializedError"), 405 | SetError::InitializingError(_) => write!(f, "InitializingError"), 406 | } 407 | } 408 | } 409 | 410 | impl Error for SetError {} 411 | 412 | impl SetError { 413 | /// Whether `SetError` is `SetError::AlreadyInitializedError`. 414 | pub fn is_already_init_err(&self) -> bool { 415 | match self { 416 | SetError::AlreadyInitializedError(_) => true, 417 | SetError::InitializingError(_) => false, 418 | } 419 | } 420 | 421 | /// Whether `SetError` is `SetError::InitializingError` 422 | pub fn is_initializing_err(&self) -> bool { 423 | match self { 424 | SetError::AlreadyInitializedError(_) => false, 425 | SetError::InitializingError(_) => true, 426 | } 427 | } 428 | } 429 | 430 | #[cfg(test)] 431 | mod tests { 432 | use std::ptr::NonNull; 433 | 434 | use super::OnceCell; 435 | 436 | #[monoio::test] 437 | async fn test_once_cell_global() { 438 | thread_local! { 439 | static ONCE: OnceCell = OnceCell::new(); 440 | } 441 | async fn get_global_integer() -> &'static u32 { 442 | let once = ONCE.with(|once| unsafe { 443 | NonNull::new_unchecked(once as *const _ as *mut OnceCell).as_ref() 444 | }); 445 | once.get_or_init(|| async { 1 + 1 }).await 446 | } 447 | 448 | assert_eq!(*get_global_integer().await, 2); 449 | assert_eq!(*get_global_integer().await, 2); 450 | } 451 | 452 | #[monoio::test] 453 | async fn test_once_cell() { 454 | let once: OnceCell = OnceCell::new(); 455 | assert_eq!(once.get_or_init(|| async { 1 + 1 }).await, &2); 456 | assert_eq!(once.get_or_init(|| async { 1 + 2 }).await, &2); 457 | } 458 | } 459 | -------------------------------------------------------------------------------- /src/linked_list.rs: -------------------------------------------------------------------------------- 1 | //! 
Linked list borrowed from tokio. 2 | //! 3 | //! An intrusive double linked list of data 4 | //! 5 | //! The data structure supports tracking pinned nodes. Most of the data 6 | //! structure's APIs are `unsafe` as they require the caller to ensure the 7 | //! specified node is actually contained by the list. 8 | #![allow(unused)] 9 | 10 | use core::cell::UnsafeCell; 11 | use core::fmt; 12 | use core::marker::{PhantomData, PhantomPinned}; 13 | use core::mem::ManuallyDrop; 14 | use core::ptr::{self, NonNull}; 15 | 16 | /// An intrusive linked list. 17 | /// 18 | /// Currently, the list is not emptied on drop. It is the caller's 19 | /// responsibility to ensure the list is empty before dropping it. 20 | pub(crate) struct LinkedList { 21 | /// Linked list head 22 | head: Option>, 23 | 24 | /// Linked list tail 25 | tail: Option>, 26 | 27 | /// Node type marker. 28 | _marker: PhantomData<*const L>, 29 | } 30 | 31 | unsafe impl Send for LinkedList where L::Target: Send {} 32 | unsafe impl Sync for LinkedList where L::Target: Sync {} 33 | 34 | /// Defines how a type is tracked within a linked list. 35 | /// 36 | /// In order to support storing a single type within multiple lists, accessing 37 | /// the list pointers is decoupled from the entry type. 38 | /// 39 | /// # Safety 40 | /// 41 | /// Implementations must guarantee that `Target` types are pinned in memory. In 42 | /// other words, when a node is inserted, the value will not be moved as long as 43 | /// it is stored in the list. 44 | pub(crate) unsafe trait Link { 45 | /// Handle to the list entry. 46 | /// 47 | /// This is usually a pointer-ish type. 
type Handle; 49 | 50 | /// Node type 51 | type Target; 52 | 53 | /// Convert the handle to a raw pointer without consuming the handle 54 | #[allow(clippy::wrong_self_convention)] 55 | fn as_raw(handle: &Self::Handle) -> NonNull; 56 | 57 | /// Convert the raw pointer to a handle 58 | unsafe fn from_raw(ptr: NonNull) -> Self::Handle; 59 | 60 | /// Return the pointers for a node 61 | unsafe fn pointers(target: NonNull) -> NonNull>; 62 | } 63 | 64 | /// Previous / next pointers 65 | pub(crate) struct Pointers { 66 | inner: UnsafeCell>, 67 | } 68 | /// We do not want the compiler to put the `noalias` attribute on mutable 69 | /// references to this type, so the type has been made `!Unpin` with a 70 | /// `PhantomPinned` field. 71 | /// 72 | /// Additionally, we never access the `prev` or `next` fields directly, as any 73 | /// such access would implicitly involve the creation of a reference to the 74 | /// field, which we want to avoid since the fields are not `!Unpin`, and would 75 | /// hence be given the `noalias` attribute if we were to do such an access. 76 | /// As an alternative to accessing the fields directly, the `Pointers` type 77 | /// provides getters and setters for the two fields, and those are implemented 78 | /// using raw pointer casts and offsets, which is valid since the struct is 79 | /// #[repr(C)]. 80 | /// 81 | /// See this link for more information: 82 | /// 83 | #[repr(C)] 84 | struct PointersInner { 85 | /// The previous node in the list. `None` if there is no previous node. 86 | /// 87 | /// This field is accessed through pointer manipulation, so it is not dead code. 88 | #[allow(dead_code)] 89 | prev: Option>, 90 | 91 | /// The next node in the list. `None` if there is no next node. 92 | /// 93 | /// This field is accessed through pointer manipulation, so it is not dead code.
94 | #[allow(dead_code)] 95 | next: Option>, 96 | 97 | /// This type is !Unpin due to the heuristic from: 98 | /// 99 | _pin: PhantomPinned, 100 | } 101 | 102 | unsafe impl Send for Pointers {} 103 | unsafe impl Sync for Pointers {} 104 | 105 | // ===== impl LinkedList ===== 106 | 107 | impl LinkedList { 108 | /// Creates an empty linked list. 109 | pub(crate) const fn new() -> LinkedList { 110 | LinkedList { 111 | head: None, 112 | tail: None, 113 | _marker: PhantomData, 114 | } 115 | } 116 | } 117 | 118 | impl LinkedList { 119 | /// Adds an element first in the list. 120 | pub(crate) fn push_front(&mut self, val: L::Handle) { 121 | // The value should not be dropped, it is being inserted into the list 122 | let val = ManuallyDrop::new(val); 123 | let ptr = L::as_raw(&*val); 124 | assert_ne!(self.head, Some(ptr)); 125 | unsafe { 126 | L::pointers(ptr).as_mut().set_next(self.head); 127 | L::pointers(ptr).as_mut().set_prev(None); 128 | 129 | if let Some(head) = self.head { 130 | L::pointers(head).as_mut().set_prev(Some(ptr)); 131 | } 132 | 133 | self.head = Some(ptr); 134 | 135 | if self.tail.is_none() { 136 | self.tail = Some(ptr); 137 | } 138 | } 139 | } 140 | 141 | /// Removes the last element from a list and returns it, or None if it is 142 | /// empty. 
143 | pub(crate) fn pop_back(&mut self) -> Option { 144 | unsafe { 145 | let last = self.tail?; 146 | self.tail = L::pointers(last).as_ref().get_prev(); 147 | 148 | if let Some(prev) = L::pointers(last).as_ref().get_prev() { 149 | L::pointers(prev).as_mut().set_next(None); 150 | } else { 151 | self.head = None 152 | } 153 | 154 | L::pointers(last).as_mut().set_prev(None); 155 | L::pointers(last).as_mut().set_next(None); 156 | 157 | Some(L::from_raw(last)) 158 | } 159 | } 160 | 161 | /// Returns whether the linked list does not contain any node 162 | pub(crate) fn is_empty(&self) -> bool { 163 | if self.head.is_some() { 164 | return false; 165 | } 166 | 167 | assert!(self.tail.is_none()); 168 | true 169 | } 170 | 171 | /// Removes the specified node from the list 172 | /// 173 | /// # Safety 174 | /// 175 | /// The caller **must** ensure that `node` is currently contained by 176 | /// `self` or not contained by any other list. 177 | pub(crate) unsafe fn remove(&mut self, node: NonNull) -> Option { 178 | if let Some(prev) = L::pointers(node).as_ref().get_prev() { 179 | debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); 180 | L::pointers(prev) 181 | .as_mut() 182 | .set_next(L::pointers(node).as_ref().get_next()); 183 | } else { 184 | if self.head != Some(node) { 185 | return None; 186 | } 187 | 188 | self.head = L::pointers(node).as_ref().get_next(); 189 | } 190 | 191 | if let Some(next) = L::pointers(node).as_ref().get_next() { 192 | debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); 193 | L::pointers(next) 194 | .as_mut() 195 | .set_prev(L::pointers(node).as_ref().get_prev()); 196 | } else { 197 | // This might be the last item in the list 198 | if self.tail != Some(node) { 199 | return None; 200 | } 201 | 202 | self.tail = L::pointers(node).as_ref().get_prev(); 203 | } 204 | 205 | L::pointers(node).as_mut().set_next(None); 206 | L::pointers(node).as_mut().set_prev(None); 207 | 208 | Some(L::from_raw(node)) 209 | } 210 | } 211 | 
212 | impl fmt::Debug for LinkedList { 213 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 214 | f.debug_struct("LinkedList") 215 | .field("head", &self.head) 216 | .field("tail", &self.tail) 217 | .finish() 218 | } 219 | } 220 | 221 | impl LinkedList { 222 | pub(crate) fn last(&self) -> Option<&L::Target> { 223 | let tail = self.tail.as_ref()?; 224 | unsafe { Some(&*tail.as_ptr()) } 225 | } 226 | } 227 | 228 | impl Default for LinkedList { 229 | fn default() -> Self { 230 | Self::new() 231 | } 232 | } 233 | 234 | // ===== impl Pointers ===== 235 | 236 | impl Pointers { 237 | /// Create a new set of empty pointers 238 | pub(crate) fn new() -> Pointers { 239 | Pointers { 240 | inner: UnsafeCell::new(PointersInner { 241 | prev: None, 242 | next: None, 243 | _pin: PhantomPinned, 244 | }), 245 | } 246 | } 247 | 248 | fn get_prev(&self) -> Option> { 249 | // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. 250 | unsafe { 251 | let inner = self.inner.get(); 252 | let prev = inner as *const Option>; 253 | ptr::read(prev) 254 | } 255 | } 256 | fn get_next(&self) -> Option> { 257 | // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. 258 | unsafe { 259 | let inner = self.inner.get(); 260 | let prev = inner as *const Option>; 261 | let next = prev.add(1); 262 | ptr::read(next) 263 | } 264 | } 265 | 266 | fn set_prev(&mut self, value: Option>) { 267 | // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. 268 | unsafe { 269 | let inner = self.inner.get(); 270 | let prev = inner as *mut Option>; 271 | ptr::write(prev, value); 272 | } 273 | } 274 | fn set_next(&mut self, value: Option>) { 275 | // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. 
276 | unsafe { 277 | let inner = self.inner.get(); 278 | let prev = inner as *mut Option>; 279 | let next = prev.add(1); 280 | ptr::write(next, value); 281 | } 282 | } 283 | } 284 | 285 | impl fmt::Debug for Pointers { 286 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 287 | let prev = self.get_prev(); 288 | let next = self.get_next(); 289 | f.debug_struct("Pointers") 290 | .field("prev", &prev) 291 | .field("next", &next) 292 | .finish() 293 | } 294 | } 295 | 296 | #[cfg(test)] 297 | mod tests { 298 | use super::*; 299 | 300 | use std::pin::Pin; 301 | 302 | #[derive(Debug)] 303 | struct Entry { 304 | pointers: Pointers, 305 | val: i32, 306 | } 307 | 308 | unsafe impl<'a> Link for &'a Entry { 309 | type Handle = Pin<&'a Entry>; 310 | type Target = Entry; 311 | 312 | fn as_raw(handle: &Pin<&'_ Entry>) -> NonNull { 313 | NonNull::from(handle.get_ref()) 314 | } 315 | 316 | unsafe fn from_raw(ptr: NonNull) -> Pin<&'a Entry> { 317 | Pin::new_unchecked(&*ptr.as_ptr()) 318 | } 319 | 320 | unsafe fn pointers(mut target: NonNull) -> NonNull> { 321 | NonNull::from(&mut target.as_mut().pointers) 322 | } 323 | } 324 | 325 | fn entry(val: i32) -> Pin> { 326 | Box::pin(Entry { 327 | pointers: Pointers::new(), 328 | val, 329 | }) 330 | } 331 | 332 | fn ptr(r: &Pin>) -> NonNull { 333 | r.as_ref().get_ref().into() 334 | } 335 | 336 | fn collect_list(list: &mut LinkedList<&'_ Entry, <&'_ Entry as Link>::Target>) -> Vec { 337 | let mut ret = vec![]; 338 | 339 | while let Some(entry) = list.pop_back() { 340 | ret.push(entry.val); 341 | } 342 | 343 | ret 344 | } 345 | 346 | fn push_all<'a>( 347 | list: &mut LinkedList<&'a Entry, <&'_ Entry as Link>::Target>, 348 | entries: &[Pin<&'a Entry>], 349 | ) { 350 | for entry in entries.iter() { 351 | list.push_front(*entry); 352 | } 353 | } 354 | 355 | macro_rules! 
assert_clean { 356 | ($e:ident) => {{ 357 | assert!($e.pointers.get_next().is_none()); 358 | assert!($e.pointers.get_prev().is_none()); 359 | }}; 360 | } 361 | 362 | macro_rules! assert_ptr_eq { 363 | ($a:expr, $b:expr) => {{ 364 | // Deal with mapping a Pin<&mut T> -> Option> 365 | assert_eq!(Some($a.as_ref().get_ref().into()), $b) 366 | }}; 367 | } 368 | 369 | #[test] 370 | fn const_new() { 371 | const _: LinkedList<&Entry, <&Entry as Link>::Target> = LinkedList::new(); 372 | } 373 | 374 | #[test] 375 | fn push_and_drain() { 376 | let a = entry(5); 377 | let b = entry(7); 378 | let c = entry(31); 379 | 380 | let mut list = LinkedList::new(); 381 | assert!(list.is_empty()); 382 | 383 | list.push_front(a.as_ref()); 384 | assert!(!list.is_empty()); 385 | list.push_front(b.as_ref()); 386 | list.push_front(c.as_ref()); 387 | 388 | let items: Vec = collect_list(&mut list); 389 | assert_eq!([5, 7, 31].to_vec(), items); 390 | 391 | assert!(list.is_empty()); 392 | } 393 | 394 | #[test] 395 | fn push_pop_push_pop() { 396 | let a = entry(5); 397 | let b = entry(7); 398 | 399 | let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); 400 | 401 | list.push_front(a.as_ref()); 402 | 403 | let entry = list.pop_back().unwrap(); 404 | assert_eq!(5, entry.val); 405 | assert!(list.is_empty()); 406 | 407 | list.push_front(b.as_ref()); 408 | 409 | let entry = list.pop_back().unwrap(); 410 | assert_eq!(7, entry.val); 411 | 412 | assert!(list.is_empty()); 413 | assert!(list.pop_back().is_none()); 414 | } 415 | 416 | #[test] 417 | fn remove_by_address() { 418 | let a = entry(5); 419 | let b = entry(7); 420 | let c = entry(31); 421 | 422 | unsafe { 423 | // Remove first 424 | let mut list = LinkedList::new(); 425 | 426 | push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); 427 | assert!(list.remove(ptr(&a)).is_some()); 428 | assert_clean!(a); 429 | // `a` should be no longer there and can't be removed twice 430 | assert!(list.remove(ptr(&a)).is_none()); 431 | 
assert!(!list.is_empty()); 432 | 433 | assert!(list.remove(ptr(&b)).is_some()); 434 | assert_clean!(b); 435 | // `b` should be no longer there and can't be removed twice 436 | assert!(list.remove(ptr(&b)).is_none()); 437 | assert!(!list.is_empty()); 438 | 439 | assert!(list.remove(ptr(&c)).is_some()); 440 | assert_clean!(c); 441 | // `c` should be no longer there and can't be removed twice 442 | assert!(list.remove(ptr(&c)).is_none()); 443 | assert!(list.is_empty()); 444 | } 445 | 446 | unsafe { 447 | // Remove first 448 | let mut list = LinkedList::new(); 449 | 450 | push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); 451 | 452 | assert!(list.remove(ptr(&a)).is_some()); 453 | assert_clean!(a); 454 | 455 | assert_ptr_eq!(b, list.head); 456 | assert_ptr_eq!(c, b.pointers.get_next()); 457 | assert_ptr_eq!(b, c.pointers.get_prev()); 458 | 459 | let items = collect_list(&mut list); 460 | assert_eq!([31, 7].to_vec(), items); 461 | } 462 | 463 | unsafe { 464 | // Remove middle 465 | let mut list = LinkedList::new(); 466 | 467 | push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); 468 | 469 | assert!(list.remove(ptr(&b)).is_some()); 470 | assert_clean!(b); 471 | 472 | assert_ptr_eq!(c, a.pointers.get_next()); 473 | assert_ptr_eq!(a, c.pointers.get_prev()); 474 | 475 | let items = collect_list(&mut list); 476 | assert_eq!([31, 5].to_vec(), items); 477 | } 478 | 479 | unsafe { 480 | // Remove last 481 | // (the tail of a three-element list) 482 | let mut list = LinkedList::new(); 483 | 484 | push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); 485 | 486 | assert!(list.remove(ptr(&c)).is_some()); 487 | assert_clean!(c); 488 | 489 | assert!(b.pointers.get_next().is_none()); 490 | assert_ptr_eq!(b, list.tail); 491 | 492 | let items = collect_list(&mut list); 493 | assert_eq!([7, 5].to_vec(), items); 494 | } 495 | 496 | unsafe { 497 | // Remove first of two 498 | let mut list = LinkedList::new(); 499 | 500 | push_all(&mut list, &[b.as_ref(), a.as_ref()]); 501 | 502 |
assert!(list.remove(ptr(&a)).is_some()); 503 | 504 | assert_clean!(a); 505 | 506 | // a should be no longer there and can't be removed twice 507 | assert!(list.remove(ptr(&a)).is_none()); 508 | 509 | assert_ptr_eq!(b, list.head); 510 | assert_ptr_eq!(b, list.tail); 511 | 512 | assert!(b.pointers.get_next().is_none()); 513 | assert!(b.pointers.get_prev().is_none()); 514 | 515 | let items = collect_list(&mut list); 516 | assert_eq!([7].to_vec(), items); 517 | } 518 | 519 | unsafe { 520 | // Remove last of two 521 | let mut list = LinkedList::new(); 522 | 523 | push_all(&mut list, &[b.as_ref(), a.as_ref()]); 524 | 525 | assert!(list.remove(ptr(&b)).is_some()); 526 | 527 | assert_clean!(b); 528 | 529 | assert_ptr_eq!(a, list.head); 530 | assert_ptr_eq!(a, list.tail); 531 | 532 | assert!(a.pointers.get_next().is_none()); 533 | assert!(a.pointers.get_prev().is_none()); 534 | 535 | let items = collect_list(&mut list); 536 | assert_eq!([5].to_vec(), items); 537 | } 538 | 539 | unsafe { 540 | // Remove last item 541 | let mut list = LinkedList::new(); 542 | 543 | push_all(&mut list, &[a.as_ref()]); 544 | 545 | assert!(list.remove(ptr(&a)).is_some()); 546 | assert_clean!(a); 547 | 548 | assert!(list.head.is_none()); 549 | assert!(list.tail.is_none()); 550 | let items = collect_list(&mut list); 551 | assert!(items.is_empty()); 552 | } 553 | 554 | unsafe { 555 | // Remove missing 556 | let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); 557 | 558 | list.push_front(b.as_ref()); 559 | list.push_front(a.as_ref()); 560 | 561 | assert!(list.remove(ptr(&c)).is_none()); 562 | } 563 | } 564 | } 565 | -------------------------------------------------------------------------------- /src/oneshot.rs: -------------------------------------------------------------------------------- 1 | //! Oneshot borrowed from tokio. 2 | //! 3 | //! A one-shot channel is used for sending a single message between 4 | //! asynchronous tasks. 
The [`channel`] function is used to create a 5 | //! [`Sender`] and [`Receiver`] handle pair that form the channel. 6 | //! 7 | //! The `Sender` handle is used by the producer to send the value. 8 | //! The `Receiver` handle is used by the consumer to receive the value. 9 | //! 10 | //! Each handle can be used on separate tasks. 11 | //! 12 | //! # Examples 13 | //! 14 | //! ``` 15 | //! use local_sync::oneshot; 16 | //! 17 | //! #[monoio::main] 18 | //! async fn main() { 19 | //! let (tx, rx) = oneshot::channel(); 20 | //! 21 | //! monoio::spawn(async move { 22 | //! if let Err(_) = tx.send(3) { 23 | //! println!("the receiver dropped"); 24 | //! } 25 | //! }); 26 | //! 27 | //! match rx.await { 28 | //! Ok(v) => println!("got = {:?}", v), 29 | //! Err(_) => println!("the sender dropped"), 30 | //! } 31 | //! } 32 | //! ``` 33 | //! 34 | //! If the sender is dropped without sending, the receiver will fail with 35 | //! [`error::RecvError`]: 36 | //! 37 | //! ``` 38 | //! use local_sync::oneshot; 39 | //! 40 | //! #[monoio::main] 41 | //! async fn main() { 42 | //! let (tx, rx) = oneshot::channel::(); 43 | //! 44 | //! monoio::spawn(async move { 45 | //! drop(tx); 46 | //! }); 47 | //! 48 | //! match rx.await { 49 | //! Ok(_) => panic!("This doesn't happen"), 50 | //! Err(_) => println!("the sender dropped"), 51 | //! } 52 | //! } 53 | //! ``` 54 | 55 | use std::cell::{RefCell, UnsafeCell}; 56 | use std::fmt; 57 | use std::future::Future; 58 | use std::mem::MaybeUninit; 59 | use std::pin::Pin; 60 | use std::rc::Rc; 61 | use std::task::Poll::{Pending, Ready}; 62 | use std::task::{Context, Poll, Waker}; 63 | 64 | /// Sends a value to the associated [`Receiver`]. 65 | /// 66 | /// A pair of both a [`Sender`] and a [`Receiver`] are created by the 67 | /// [`channel`](fn@channel) function. 68 | #[derive(Debug)] 69 | pub struct Sender { 70 | inner: Option>>, 71 | } 72 | 73 | /// Receive a value from the associated [`Sender`]. 
74 | /// 75 | /// A pair of both a [`Sender`] and a [`Receiver`] are created by the 76 | /// [`channel`](fn@channel) function. 77 | /// 78 | /// # Examples 79 | /// 80 | /// ``` 81 | /// use local_sync::oneshot; 82 | /// 83 | /// #[monoio::main] 84 | /// async fn main() { 85 | /// let (tx, rx) = oneshot::channel(); 86 | /// 87 | /// monoio::spawn(async move { 88 | /// if let Err(_) = tx.send(3) { 89 | /// println!("the receiver dropped"); 90 | /// } 91 | /// }); 92 | /// 93 | /// match rx.await { 94 | /// Ok(v) => println!("got = {:?}", v), 95 | /// Err(_) => println!("the sender dropped"), 96 | /// } 97 | /// } 98 | /// ``` 99 | /// 100 | /// If the sender is dropped without sending, the receiver will fail with 101 | /// [`error::RecvError`]: 102 | /// 103 | /// ``` 104 | /// use local_sync::oneshot; 105 | /// 106 | /// #[monoio::main] 107 | /// async fn main() { 108 | /// let (tx, rx) = oneshot::channel::(); 109 | /// 110 | /// monoio::spawn(async move { 111 | /// drop(tx); 112 | /// }); 113 | /// 114 | /// match rx.await { 115 | /// Ok(_) => panic!("This doesn't happen"), 116 | /// Err(_) => println!("the sender dropped"), 117 | /// } 118 | /// } 119 | /// ``` 120 | #[derive(Debug)] 121 | pub struct Receiver { 122 | inner: Option>>, 123 | } 124 | 125 | pub mod error { 126 | //! Oneshot error types 127 | 128 | use std::fmt; 129 | 130 | /// Error returned by the `Future` implementation for `Receiver`. 131 | #[derive(Debug, Eq, PartialEq)] 132 | pub struct RecvError(pub(super) ()); 133 | 134 | /// Error returned by the `try_recv` function on `Receiver`. 135 | #[derive(Debug, Eq, PartialEq)] 136 | pub enum TryRecvError { 137 | /// The send half of the channel has not yet sent a value. 138 | Empty, 139 | 140 | /// The send half of the channel was dropped without sending a value. 
141 | Closed, 142 | } 143 | 144 | // ===== impl RecvError ===== 145 | 146 | impl fmt::Display for RecvError { 147 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 148 | write!(fmt, "channel closed") 149 | } 150 | } 151 | 152 | impl std::error::Error for RecvError {} 153 | 154 | // ===== impl TryRecvError ===== 155 | 156 | impl fmt::Display for TryRecvError { 157 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 158 | match self { 159 | TryRecvError::Empty => write!(fmt, "channel empty"), 160 | TryRecvError::Closed => write!(fmt, "channel closed"), 161 | } 162 | } 163 | } 164 | 165 | impl std::error::Error for TryRecvError {} 166 | } 167 | 168 | use futures_core::ready; 169 | 170 | use self::error::*; 171 | 172 | struct Inner { 173 | /// Manages the state of the inner cell 174 | state: RefCell, 175 | 176 | /// The value. This is set by `Sender` and read by `Receiver`. The state of 177 | /// the cell is tracked by `state`. 178 | value: UnsafeCell>, 179 | 180 | /// The task to notify when the receiver drops without consuming the value. 181 | tx_task: Task, 182 | 183 | /// The task to notify when the value is sent. 
184 | rx_task: Task, 185 | } 186 | 187 | struct Task(UnsafeCell>); 188 | 189 | impl Task { 190 | unsafe fn will_wake(&self, cx: &mut Context<'_>) -> bool { 191 | self.with_task(|w| w.will_wake(cx.waker())) 192 | } 193 | 194 | unsafe fn with_task(&self, f: F) -> R 195 | where 196 | F: FnOnce(&Waker) -> R, 197 | { 198 | let ptr = self.0.get(); 199 | let waker: *const Waker = (&*ptr).as_ptr(); 200 | f(&*waker) 201 | } 202 | 203 | unsafe fn drop_task(&self) { 204 | let ptr: *mut Waker = (&mut *self.0.get()).as_mut_ptr(); 205 | ptr.drop_in_place(); 206 | } 207 | 208 | unsafe fn set_task(&self, cx: &mut Context<'_>) { 209 | let ptr: *mut Waker = (&mut *self.0.get()).as_mut_ptr(); 210 | ptr.write(cx.waker().clone()); 211 | } 212 | } 213 | 214 | #[derive(Clone, Copy)] 215 | struct State(usize); 216 | 217 | /// Create a new one-shot channel for sending single values across asynchronous 218 | /// tasks. 219 | /// 220 | /// The function returns separate "send" and "receive" handles. The `Sender` 221 | /// handle is used by the producer to send the value. The `Receiver` handle is 222 | /// used by the consumer to receive the value. 223 | /// 224 | /// Each handle can be used on separate tasks. 
225 | /// 226 | /// # Examples 227 | /// 228 | /// ``` 229 | /// use local_sync::oneshot; 230 | /// 231 | /// #[monoio::main] 232 | /// async fn main() { 233 | /// let (tx, rx) = oneshot::channel(); 234 | /// 235 | /// monoio::spawn(async move { 236 | /// if let Err(_) = tx.send(3) { 237 | /// println!("the receiver dropped"); 238 | /// } 239 | /// }); 240 | /// 241 | /// match rx.await { 242 | /// Ok(v) => println!("got = {:?}", v), 243 | /// Err(_) => println!("the sender dropped"), 244 | /// } 245 | /// } 246 | /// ``` 247 | pub fn channel() -> (Sender, Receiver) { 248 | let inner = Rc::new(Inner { 249 | state: RefCell::new(State::new().as_usize()), 250 | value: UnsafeCell::new(None), 251 | tx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), 252 | rx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), 253 | }); 254 | 255 | let tx = Sender { 256 | inner: Some(inner.clone()), 257 | }; 258 | let rx = Receiver { inner: Some(inner) }; 259 | 260 | (tx, rx) 261 | } 262 | 263 | impl Sender { 264 | /// Attempts to send a value on this channel, returning it back if it could 265 | /// not be sent. 266 | /// 267 | /// This method consumes `self` as only one value may ever be sent on a oneshot 268 | /// channel. It is not marked async because sending a message to an oneshot 269 | /// channel never requires any form of waiting. Because of this, the `send` 270 | /// method can be used in both synchronous and asynchronous code without 271 | /// problems. 272 | /// 273 | /// A successful send occurs when it is determined that the other end of the 274 | /// channel has not hung up already. An unsuccessful send would be one where 275 | /// the corresponding receiver has already been deallocated. Note that a 276 | /// return value of `Err` means that the data will never be received, but 277 | /// a return value of `Ok` does *not* mean that the data will be received. 
278 | /// It is possible for the corresponding receiver to hang up immediately 279 | /// after this function returns `Ok`. 280 | /// 281 | /// # Examples 282 | /// 283 | /// Send a value to another task 284 | /// 285 | /// ``` 286 | /// use local_sync::oneshot; 287 | /// 288 | /// #[monoio::main] 289 | /// async fn main() { 290 | /// let (tx, rx) = oneshot::channel(); 291 | /// 292 | /// monoio::spawn(async move { 293 | /// if let Err(_) = tx.send(3) { 294 | /// println!("the receiver dropped"); 295 | /// } 296 | /// }); 297 | /// 298 | /// match rx.await { 299 | /// Ok(v) => println!("got = {:?}", v), 300 | /// Err(_) => println!("the sender dropped"), 301 | /// } 302 | /// } 303 | /// ``` 304 | pub fn send(mut self, t: T) -> Result<(), T> { 305 | let inner = self.inner.take().unwrap(); 306 | let ptr = inner.value.get(); 307 | unsafe { 308 | *ptr = Some(t); 309 | } 310 | 311 | if !inner.complete() { 312 | unsafe { 313 | return Err(inner.consume_value().unwrap()); 314 | } 315 | } 316 | 317 | Ok(()) 318 | } 319 | 320 | /// Waits for the associated [`Receiver`] handle to close. 321 | /// 322 | /// A [`Receiver`] is closed by either calling [`close`] explicitly or the 323 | /// [`Receiver`] value is dropped. 324 | /// 325 | /// This function is useful when paired with `select!` to abort a 326 | /// computation when the receiver is no longer interested in the result. 327 | /// 328 | /// # Return 329 | /// 330 | /// Returns a `Future` which must be awaited on. 
331 | /// 332 | /// [`Receiver`]: Receiver 333 | /// [`close`]: Receiver::close 334 | /// 335 | /// # Examples 336 | /// 337 | /// Basic usage 338 | /// 339 | /// ``` 340 | /// use local_sync::oneshot; 341 | /// 342 | /// #[monoio::main] 343 | /// async fn main() { 344 | /// let (mut tx, rx) = oneshot::channel::<()>(); 345 | /// 346 | /// monoio::spawn(async move { 347 | /// drop(rx); 348 | /// }); 349 | /// 350 | /// tx.closed().await; 351 | /// println!("the receiver dropped"); 352 | /// } 353 | /// ``` 354 | /// 355 | /// Paired with select 356 | /// 357 | /// ``` 358 | /// use local_sync::oneshot; 359 | /// use monoio::time::{self, Duration}; 360 | /// 361 | /// async fn compute() -> String { 362 | /// // Complex computation returning a `String` 363 | /// # "hello".to_string() 364 | /// } 365 | /// 366 | /// #[monoio::main] 367 | /// async fn main() { 368 | /// let (mut tx, rx) = oneshot::channel(); 369 | /// 370 | /// monoio::spawn(async move { 371 | /// monoio::select! { 372 | /// _ = tx.closed() => { 373 | /// // The receiver dropped, no need to do any further work 374 | /// } 375 | /// value = compute() => { 376 | /// // The send can fail if the channel was closed at the exact same 377 | /// // time as when compute() finished, so just ignore the failure. 378 | /// let _ = tx.send(value); 379 | /// } 380 | /// } 381 | /// }); 382 | /// 383 | /// // Wait for up to 10 seconds 384 | /// let _ = time::timeout(Duration::from_secs(10), rx).await; 385 | /// } 386 | /// ``` 387 | pub async fn closed(&mut self) { 388 | use futures_util::future::poll_fn; 389 | 390 | poll_fn(|cx| self.poll_closed(cx)).await 391 | } 392 | 393 | /// Returns `true` if the associated [`Receiver`] handle has been dropped. 394 | /// 395 | /// A [`Receiver`] is closed by either calling [`close`] explicitly or the 396 | /// [`Receiver`] value is dropped. 397 | /// 398 | /// If `true` is returned, a call to `send` will always result in an error. 
399 | /// 400 | /// [`Receiver`]: Receiver 401 | /// [`close`]: Receiver::close 402 | /// 403 | /// # Examples 404 | /// 405 | /// ``` 406 | /// use local_sync::oneshot; 407 | /// 408 | /// #[monoio::main] 409 | /// async fn main() { 410 | /// let (tx, rx) = oneshot::channel(); 411 | /// 412 | /// assert!(!tx.is_closed()); 413 | /// 414 | /// drop(rx); 415 | /// 416 | /// assert!(tx.is_closed()); 417 | /// assert!(tx.send("never received").is_err()); 418 | /// } 419 | /// ``` 420 | pub fn is_closed(&self) -> bool { 421 | let inner = self.inner.as_ref().unwrap(); 422 | 423 | let state = State(*inner.state.borrow()); 424 | state.is_closed() 425 | } 426 | 427 | /// Check whether the oneshot channel has been closed, and if not, schedules the 428 | /// `Waker` in the provided `Context` to receive a notification when the channel is 429 | /// closed. 430 | /// 431 | /// A [`Receiver`] is closed by either calling [`close`] explicitly, or when the 432 | /// [`Receiver`] value is dropped. 433 | /// 434 | /// Note that on multiple calls to poll, only the `Waker` from the `Context` passed 435 | /// to the most recent call will be scheduled to receive a wakeup. 436 | /// 437 | /// [`Receiver`]: struct@crate::sync::oneshot::Receiver 438 | /// [`close`]: fn@crate::sync::oneshot::Receiver::close 439 | /// 440 | /// # Return value 441 | /// 442 | /// This function returns: 443 | /// 444 | /// * `Poll::Pending` if the channel is still open. 445 | /// * `Poll::Ready(())` if the channel is closed. 
446 | /// 447 | /// # Examples 448 | /// 449 | /// ``` 450 | /// use local_sync::oneshot; 451 | /// 452 | /// use futures_util::future::poll_fn; 453 | /// 454 | /// #[monoio::main] 455 | /// async fn main() { 456 | /// let (mut tx, mut rx) = oneshot::channel::<()>(); 457 | /// 458 | /// monoio::spawn(async move { 459 | /// rx.close(); 460 | /// }); 461 | /// 462 | /// poll_fn(|cx| tx.poll_closed(cx)).await; 463 | /// 464 | /// println!("the receiver dropped"); 465 | /// } 466 | /// ``` 467 | pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { 468 | let inner = self.inner.as_ref().unwrap(); 469 | 470 | let mut state = State(*inner.state.borrow()); 471 | 472 | if state.is_closed() { 473 | return Poll::Ready(()); 474 | } 475 | 476 | if state.is_tx_task_set() { 477 | let will_notify = unsafe { inner.tx_task.will_wake(cx) }; 478 | 479 | if !will_notify { 480 | state = State::unset_tx_task(&inner.state); 481 | 482 | if state.is_closed() { 483 | // Set the flag again so that the waker is released in drop 484 | State::set_tx_task(&inner.state); 485 | return Ready(()); 486 | } else { 487 | unsafe { inner.tx_task.drop_task() }; 488 | } 489 | } 490 | } 491 | 492 | if !state.is_tx_task_set() { 493 | // Attempt to set the task 494 | unsafe { 495 | inner.tx_task.set_task(cx); 496 | } 497 | 498 | // Update the state 499 | state = State::set_tx_task(&inner.state); 500 | 501 | if state.is_closed() { 502 | return Ready(()); 503 | } 504 | } 505 | 506 | Pending 507 | } 508 | } 509 | 510 | impl Drop for Sender { 511 | fn drop(&mut self) { 512 | if let Some(inner) = self.inner.as_ref() { 513 | inner.complete(); 514 | } 515 | } 516 | } 517 | 518 | impl Receiver { 519 | /// Prevents the associated [`Sender`] handle from sending a value. 520 | /// 521 | /// Any `send` operation which happens after calling `close` is guaranteed 522 | /// to fail. 
After calling `close`, [`try_recv`] should be called to 523 | /// receive a value if one was sent **before** the call to `close` 524 | /// completed. 525 | /// 526 | /// This function is useful to perform a graceful shutdown and ensure that a 527 | /// value will not be sent into the channel and never received. 528 | /// 529 | /// `close` is no-op if a message is already received or the channel 530 | /// is already closed. 531 | /// 532 | /// [`Sender`]: Sender 533 | /// [`try_recv`]: Receiver::try_recv 534 | /// 535 | /// # Examples 536 | /// 537 | /// Prevent a value from being sent 538 | /// 539 | /// ``` 540 | /// use local_sync::oneshot; 541 | /// use local_sync::oneshot::error::TryRecvError; 542 | /// 543 | /// #[monoio::main] 544 | /// async fn main() { 545 | /// let (tx, mut rx) = oneshot::channel(); 546 | /// 547 | /// assert!(!tx.is_closed()); 548 | /// 549 | /// rx.close(); 550 | /// 551 | /// assert!(tx.is_closed()); 552 | /// assert!(tx.send("never received").is_err()); 553 | /// 554 | /// match rx.try_recv() { 555 | /// Err(TryRecvError::Closed) => {} 556 | /// _ => unreachable!(), 557 | /// } 558 | /// } 559 | /// ``` 560 | /// 561 | /// Receive a value sent **before** calling `close` 562 | /// 563 | /// ``` 564 | /// use local_sync::oneshot; 565 | /// 566 | /// #[monoio::main] 567 | /// async fn main() { 568 | /// let (tx, mut rx) = oneshot::channel(); 569 | /// 570 | /// assert!(tx.send("will receive").is_ok()); 571 | /// 572 | /// rx.close(); 573 | /// 574 | /// let msg = rx.try_recv().unwrap(); 575 | /// assert_eq!(msg, "will receive"); 576 | /// } 577 | /// ``` 578 | pub fn close(&mut self) { 579 | if let Some(inner) = self.inner.as_ref() { 580 | inner.close(); 581 | } 582 | } 583 | 584 | /// Attempts to receive a value. 585 | /// 586 | /// If a pending value exists in the channel, it is returned. If no value 587 | /// has been sent, the current task **will not** be registered for 588 | /// future notification. 
589 | /// 590 | /// This function is useful to call from outside the context of an 591 | /// asynchronous task. 592 | /// 593 | /// # Return 594 | /// 595 | /// - `Ok(T)` if a value is pending in the channel. 596 | /// - `Err(TryRecvError::Empty)` if no value has been sent yet. 597 | /// - `Err(TryRecvError::Closed)` if the sender has dropped without sending 598 | /// a value. 599 | /// 600 | /// # Examples 601 | /// 602 | /// `try_recv` before a value is sent, then after. 603 | /// 604 | /// ``` 605 | /// use local_sync::oneshot; 606 | /// use local_sync::oneshot::error::TryRecvError; 607 | /// 608 | /// #[monoio::main] 609 | /// async fn main() { 610 | /// let (tx, mut rx) = oneshot::channel(); 611 | /// 612 | /// match rx.try_recv() { 613 | /// // The channel is currently empty 614 | /// Err(TryRecvError::Empty) => {} 615 | /// _ => unreachable!(), 616 | /// } 617 | /// 618 | /// // Send a value 619 | /// tx.send("hello").unwrap(); 620 | /// 621 | /// match rx.try_recv() { 622 | /// Ok(value) => assert_eq!(value, "hello"), 623 | /// _ => unreachable!(), 624 | /// } 625 | /// } 626 | /// ``` 627 | /// 628 | /// `try_recv` when the sender dropped before sending a value 629 | /// 630 | /// ``` 631 | /// use local_sync::oneshot; 632 | /// use local_sync::oneshot::error::TryRecvError; 633 | /// 634 | /// #[monoio::main] 635 | /// async fn main() { 636 | /// let (tx, mut rx) = oneshot::channel::<()>(); 637 | /// 638 | /// drop(tx); 639 | /// 640 | /// match rx.try_recv() { 641 | /// // The channel will never receive a value. 
642 | /// Err(TryRecvError::Closed) => {} 643 | /// _ => unreachable!(), 644 | /// } 645 | /// } 646 | /// ``` 647 | pub fn try_recv(&mut self) -> Result { 648 | let result = if let Some(inner) = self.inner.as_ref() { 649 | let state = State(*inner.state.borrow()); 650 | 651 | if state.is_complete() { 652 | match unsafe { inner.consume_value() } { 653 | Some(value) => Ok(value), 654 | None => Err(TryRecvError::Closed), 655 | } 656 | } else if state.is_closed() { 657 | Err(TryRecvError::Closed) 658 | } else { 659 | // Not ready, this does not clear `inner` 660 | return Err(TryRecvError::Empty); 661 | } 662 | } else { 663 | Err(TryRecvError::Closed) 664 | }; 665 | 666 | self.inner = None; 667 | result 668 | } 669 | } 670 | 671 | impl Drop for Receiver { 672 | fn drop(&mut self) { 673 | if let Some(inner) = self.inner.as_ref() { 674 | inner.close(); 675 | } 676 | } 677 | } 678 | 679 | impl Future for Receiver { 680 | type Output = Result; 681 | 682 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 683 | // If `inner` is `None`, then `poll()` has already completed. 684 | let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() { 685 | ready!(inner.poll_recv(cx))? 686 | } else { 687 | panic!("called after complete"); 688 | }; 689 | 690 | self.inner = None; 691 | Ready(Ok(ret)) 692 | } 693 | } 694 | 695 | impl Inner { 696 | fn complete(&self) -> bool { 697 | let prev = State::set_complete(&self.state); 698 | 699 | if prev.is_closed() { 700 | return false; 701 | } 702 | 703 | if prev.is_rx_task_set() { 704 | // TODO: Consume waker? 
705 | unsafe { 706 | self.rx_task.with_task(Waker::wake_by_ref); 707 | } 708 | } 709 | 710 | true 711 | } 712 | 713 | fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { 714 | // Load the state 715 | let mut state = State(*self.state.borrow()); 716 | 717 | if state.is_complete() { 718 | match unsafe { self.consume_value() } { 719 | Some(value) => Ready(Ok(value)), 720 | None => Ready(Err(RecvError(()))), 721 | } 722 | } else if state.is_closed() { 723 | Ready(Err(RecvError(()))) 724 | } else { 725 | if state.is_rx_task_set() { 726 | let will_notify = unsafe { self.rx_task.will_wake(cx) }; 727 | 728 | // Check if the task is still the same 729 | if !will_notify { 730 | // Unset the task 731 | state = State::unset_rx_task(&self.state); 732 | if state.is_complete() { 733 | // Set the flag again so that the waker is released in drop 734 | State::set_rx_task(&self.state); 735 | 736 | return match unsafe { self.consume_value() } { 737 | Some(value) => Ready(Ok(value)), 738 | None => Ready(Err(RecvError(()))), 739 | }; 740 | } else { 741 | unsafe { self.rx_task.drop_task() }; 742 | } 743 | } 744 | } 745 | 746 | if !state.is_rx_task_set() { 747 | // Attempt to set the task 748 | unsafe { 749 | self.rx_task.set_task(cx); 750 | } 751 | 752 | // Update the state 753 | state = State::set_rx_task(&self.state); 754 | 755 | if state.is_complete() { 756 | match unsafe { self.consume_value() } { 757 | Some(value) => Ready(Ok(value)), 758 | None => Ready(Err(RecvError(()))), 759 | } 760 | } else { 761 | Pending 762 | } 763 | } else { 764 | Pending 765 | } 766 | } 767 | } 768 | 769 | /// Called by `Receiver` to indicate that the value will never be received. 770 | fn close(&self) { 771 | let prev = State::set_closed(&self.state); 772 | 773 | if prev.is_tx_task_set() && !prev.is_complete() { 774 | unsafe { 775 | self.tx_task.with_task(Waker::wake_by_ref); 776 | } 777 | } 778 | } 779 | 780 | /// Consumes the value. This function does not check `state`. 
781 | unsafe fn consume_value(&self) -> Option { 782 | let ptr = self.value.get(); 783 | (*ptr).take() 784 | } 785 | } 786 | 787 | impl Drop for Inner { 788 | fn drop(&mut self) { 789 | let state = State(*self.state.borrow()); 790 | 791 | if state.is_rx_task_set() { 792 | unsafe { 793 | self.rx_task.drop_task(); 794 | } 795 | } 796 | 797 | if state.is_tx_task_set() { 798 | unsafe { 799 | self.tx_task.drop_task(); 800 | } 801 | } 802 | } 803 | } 804 | 805 | impl fmt::Debug for Inner { 806 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 807 | fmt.debug_struct("Inner") 808 | .field("state", &self.state.borrow()) 809 | .finish() 810 | } 811 | } 812 | 813 | const RX_TASK_SET: usize = 0b00001; 814 | const VALUE_SENT: usize = 0b00010; 815 | const CLOSED: usize = 0b00100; 816 | const TX_TASK_SET: usize = 0b01000; 817 | 818 | impl State { 819 | fn new() -> State { 820 | State(0) 821 | } 822 | 823 | fn is_complete(self) -> bool { 824 | self.0 & VALUE_SENT == VALUE_SENT 825 | } 826 | 827 | fn set_complete(cell: &RefCell) -> State { 828 | let mut val = cell.borrow_mut(); 829 | *val |= VALUE_SENT; 830 | State(*val) 831 | } 832 | 833 | fn is_rx_task_set(self) -> bool { 834 | self.0 & RX_TASK_SET == RX_TASK_SET 835 | } 836 | 837 | fn set_rx_task(cell: &RefCell) -> State { 838 | let mut val = cell.borrow_mut(); 839 | *val |= RX_TASK_SET; 840 | State(*val) 841 | } 842 | 843 | fn unset_rx_task(cell: &RefCell) -> State { 844 | let mut val = cell.borrow_mut(); 845 | *val &= !RX_TASK_SET; 846 | State(*val) 847 | } 848 | 849 | fn is_closed(self) -> bool { 850 | self.0 & CLOSED == CLOSED 851 | } 852 | 853 | fn set_closed(cell: &RefCell) -> State { 854 | // Acquire because we want all later writes (attempting to poll) to be 855 | // ordered after this. 
856 | let mut val = cell.borrow_mut(); 857 | *val |= CLOSED; 858 | State(*val) 859 | } 860 | 861 | fn set_tx_task(cell: &RefCell) -> State { 862 | let mut val = cell.borrow_mut(); 863 | *val |= TX_TASK_SET; 864 | State(*val) 865 | } 866 | 867 | fn unset_tx_task(cell: &RefCell) -> State { 868 | let mut val = cell.borrow_mut(); 869 | *val &= !TX_TASK_SET; 870 | State(*val) 871 | } 872 | 873 | fn is_tx_task_set(self) -> bool { 874 | self.0 & TX_TASK_SET == TX_TASK_SET 875 | } 876 | 877 | fn as_usize(self) -> usize { 878 | self.0 879 | } 880 | } 881 | 882 | impl fmt::Debug for State { 883 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 884 | fmt.debug_struct("State") 885 | .field("is_complete", &self.is_complete()) 886 | .field("is_closed", &self.is_closed()) 887 | .field("is_rx_task_set", &self.is_rx_task_set()) 888 | .field("is_tx_task_set", &self.is_tx_task_set()) 889 | .finish() 890 | } 891 | } 892 | 893 | #[cfg(test)] 894 | mod tests { 895 | use super::channel; 896 | 897 | #[monoio::test] 898 | async fn it_works() { 899 | let (tx, rx) = channel(); 900 | let join = monoio::spawn(async move { rx.await }); 901 | tx.send(1).unwrap(); 902 | assert_eq!(join.await.unwrap(), 1); 903 | } 904 | } 905 | -------------------------------------------------------------------------------- /src/broadcast.rs: -------------------------------------------------------------------------------- 1 | //! A multi-producer, multi-consumer broadcast channel. Each sent value is received by 2 | //! all consumers. 3 | //! 4 | //! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`] 5 | //! instances. [`Sender`] handles are cloneable, allowing concurrent send and 6 | //! receive actions. [`Sender`] and [`Receiver`] are both non-thread-safe and should be 7 | //! used within the same thread. 8 | //! 9 | //! When a value is sent, **all** [`Receiver`] handles are notified and will 10 | //! receive the value. 
The value is stored once inside the channel and cloned on 11 | //! demand for each receiver. Once all receivers have received a clone of the value, 12 | //! the value is released from the channel. 13 | //! 14 | //! A channel is created by calling [`channel`], specifying the maximum number 15 | //! of messages the channel can retain at any given time; this capacity is rounded up to the next power of two. 16 | //! 17 | //! New [`Receiver`] handles are created by calling [`Sender::subscribe`]. The 18 | //! returned [`Receiver`] will receive values sent **after** the call to 19 | //! `subscribe`. 20 | //! 21 | //! This channel is also suitable for the single-producer multi-consumer 22 | //! use-case, where a single sender broadcasts values to many receivers. 23 | //! 24 | //! # Lagging 25 | //! 26 | //! As sent messages must be retained until **all** [`Receiver`] handles receive 27 | //! a clone, broadcast channels are susceptible to the "slow receiver" problem. 28 | //! In this case, all receivers but one are able to receive values at the rate 29 | //! they are sent. Because one receiver is stalled, the channel starts to fill 30 | //! up. 31 | //! 32 | //! This broadcast channel implementation handles this case by setting a hard 33 | //! upper bound on the number of values the channel may retain at any given 34 | //! time. This upper bound is passed to the [`channel`] function as an argument. 35 | //! 36 | //! If a value is sent when the channel is at capacity, the oldest value 37 | //! currently held by the channel is released. This frees up space for the new 38 | //! value. Any receiver that has not yet seen the released value will return 39 | //! [`RecvError::Lagged`](error::RecvError::Lagged) the next time [`recv`](Receiver::recv) is called. 40 | //! 41 | //! Once [`RecvError::Lagged`](error::RecvError::Lagged) is returned, the lagging receiver's position is 42 | //! updated to the oldest value contained by the channel. The next call to 43 | //! [`recv`](Receiver::recv) will return this value. 44 | //! 45 | //! This behavior enables a receiver to detect when it has lagged so far behind 46 | //!
that data has been dropped. The caller may decide how to respond to this: 47 | //! either by aborting its task or by tolerating lost messages and resuming 48 | //! consumption of the channel. 49 | //! 50 | //! # Closing 51 | //! 52 | //! When **all** [`Sender`] handles have been dropped, no new values may be 53 | //! sent. At this point, the channel is "closed". Once a receiver has received 54 | //! all values retained by the channel, the next call to [`recv`](Receiver::recv) will return 55 | //! with [`RecvError::Closed`](error::RecvError::Closed). 56 | //! 57 | //! When a [`Receiver`] handle is dropped, any messages not read by the receiver 58 | //! will be marked as read. If this receiver was the only one not to have read 59 | //! that message, the message will be dropped at this point. 60 | //! 61 | //! # Examples 62 | //! 63 | //! Basic usage: 64 | //! 65 | //! ``` 66 | //! use local_sync::broadcast; 67 | //! 68 | //! #[monoio::main] 69 | //! async fn main() { 70 | //! let (tx, mut rx1) = broadcast::channel(16); 71 | //! let mut rx2 = tx.subscribe(); 72 | //! 73 | //! monoio::spawn(async move { 74 | //! assert_eq!(rx1.recv().await.unwrap(), 10); 75 | //! assert_eq!(rx1.recv().await.unwrap(), 20); 76 | //! }); 77 | //! 78 | //! monoio::spawn(async move { 79 | //! assert_eq!(rx2.recv().await.unwrap(), 10); 80 | //! assert_eq!(rx2.recv().await.unwrap(), 20); 81 | //! }); 82 | //! 83 | //! tx.send(10).unwrap(); 84 | //! tx.send(20).unwrap(); 85 | //! } 86 | //! ``` 87 | //! 88 | //! Handling lag: 89 | //! 90 | //! ``` 91 | //! use local_sync::broadcast::{self, error::RecvError}; 92 | //! 93 | //! #[monoio::main] 94 | //! async fn main() { 95 | //! let (tx, mut rx) = broadcast::channel(2); 96 | //! 97 | //! tx.send(10).unwrap(); 98 | //! tx.send(20).unwrap(); 99 | //! tx.send(30).unwrap(); 100 | //! 101 | //! // The receiver lagged behind 102 | //! assert!(matches!(rx.recv().await, Err(RecvError::Lagged(_)))); 103 | //! 104 | //! // At this point, we can abort or continue with lost messages 105 | //!
106 | //! assert_eq!(rx.recv().await.unwrap(), 20); 107 | //! assert_eq!(rx.recv().await.unwrap(), 30); 108 | //! } 109 | //! ``` 110 | 111 | use crate::wake_list::WakeList; 112 | use std::cell::{Cell, RefCell}; 113 | use std::future::Future; 114 | use std::pin::Pin; 115 | use std::rc::Rc; 116 | use std::task::{Context, Poll}; 117 | 118 | /// Create a bounded, multi-producer, multi-consumer channel where each sent 119 | /// value is broadcasted to all active receivers. 120 | /// 121 | /// **Note:** The actual capacity will be rounded up to the next power of 2. 122 | /// 123 | /// All data sent on [`Sender`] will become available on every active 124 | /// [`Receiver`] in the same order as it was sent. 125 | /// 126 | /// The `Sender` can be cloned to `send` to the same channel from multiple 127 | /// points in the process. New `Receiver` handles are created by calling 128 | /// [`Sender::subscribe`]. 129 | /// 130 | /// If all [`Receiver`] handles are dropped, the `send` method will return a 131 | /// [`SendError`]. Similarly, if all [`Sender`] handles are dropped, the [`recv`] 132 | /// method will return a [`RecvError`]. 133 | /// 134 | /// # Examples 135 | /// 136 | /// ``` 137 | /// use local_sync::broadcast; 138 | /// 139 | /// #[monoio::main] 140 | /// async fn main() { 141 | /// let (tx, mut rx1) = broadcast::channel(16); 142 | /// let mut rx2 = tx.subscribe(); 143 | /// 144 | /// monoio::spawn(async move { 145 | /// assert_eq!(rx1.recv().await.unwrap(), 10); 146 | /// assert_eq!(rx1.recv().await.unwrap(), 20); 147 | /// }); 148 | /// 149 | /// monoio::spawn(async move { 150 | /// assert_eq!(rx2.recv().await.unwrap(), 10); 151 | /// assert_eq!(rx2.recv().await.unwrap(), 20); 152 | /// }); 153 | /// 154 | /// tx.send(10).unwrap(); 155 | /// tx.send(20).unwrap(); 156 | /// } 157 | /// ``` 158 | /// 159 | /// # Panics 160 | /// 161 | /// This will panic if `capacity` is equal to `0`. 
162 | pub fn channel(capacity: usize) -> (Sender, Receiver) { 163 | assert!(capacity > 0, "broadcast channel capacity cannot be zero"); 164 | 165 | // Round up to the next power of 2 166 | let cap = capacity.next_power_of_two(); 167 | let mut buffer = Vec::with_capacity(cap); 168 | 169 | for _ in 0..cap { 170 | buffer.push(RefCell::new(Slot { 171 | rem: Cell::new(0), 172 | pos: 0, 173 | val: RefCell::new(None), 174 | })); 175 | } 176 | 177 | let shared = Rc::new(Shared { 178 | buffer, 179 | mask: cap - 1, 180 | tail: RefCell::new(Tail { 181 | pos: 0, 182 | rx_cnt: 1, 183 | closed: false, 184 | wakers: WakeList::new(), 185 | }), 186 | num_tx: Cell::new(1), 187 | }); 188 | 189 | let tx = Sender { 190 | shared: shared.clone(), 191 | }; 192 | 193 | let rx = Receiver { shared, next: 0 }; 194 | 195 | (tx, rx) 196 | } 197 | 198 | /// Sending-half of the [`broadcast`] channel. 199 | /// Must only be used from the same thread. Messages can be sent with 200 | /// [`send`][Sender::send]. 201 | /// 202 | /// # Examples 203 | /// 204 | /// ``` 205 | /// use local_sync::broadcast; 206 | /// 207 | /// #[monoio::main] 208 | /// async fn main() { 209 | /// let (tx, mut rx1) = broadcast::channel(16); 210 | /// let mut rx2 = tx.subscribe(); 211 | /// 212 | /// monoio::spawn(async move { 213 | /// assert_eq!(rx1.recv().await.unwrap(), 10); 214 | /// assert_eq!(rx1.recv().await.unwrap(), 20); 215 | /// }); 216 | /// 217 | /// monoio::spawn(async move { 218 | /// assert_eq!(rx2.recv().await.unwrap(), 10); 219 | /// assert_eq!(rx2.recv().await.unwrap(), 20); 220 | /// }); 221 | /// 222 | /// tx.send(10).unwrap(); 223 | /// tx.send(20).unwrap(); 224 | /// } 225 | /// ``` 226 | /// 227 | /// [`broadcast`]: crate::broadcast 228 | pub struct Sender { 229 | shared: Rc>, 230 | } 231 | 232 | /// Receiving-half of the [`broadcast`] channel. 233 | /// 234 | /// Must not be used concurrently. Messages may be retrieved using 235 | /// [`recv`][Receiver::recv]. 
236 | /// 237 | /// # Examples 238 | /// 239 | /// ``` 240 | /// use local_sync::broadcast; 241 | /// 242 | /// #[monoio::main] 243 | /// async fn main() { 244 | /// let (tx, mut rx1) = broadcast::channel(16); 245 | /// let mut rx2 = tx.subscribe(); 246 | /// 247 | /// monoio::spawn(async move { 248 | /// assert_eq!(rx1.recv().await.unwrap(), 10); 249 | /// assert_eq!(rx1.recv().await.unwrap(), 20); 250 | /// }); 251 | /// 252 | /// monoio::spawn(async move { 253 | /// assert_eq!(rx2.recv().await.unwrap(), 10); 254 | /// assert_eq!(rx2.recv().await.unwrap(), 20); 255 | /// }); 256 | /// 257 | /// tx.send(10).unwrap(); 258 | /// tx.send(20).unwrap(); 259 | /// } 260 | /// ``` 261 | /// 262 | /// [`broadcast`]: crate::broadcast 263 | pub struct Receiver { 264 | shared: Rc>, 265 | next: u64, 266 | } 267 | 268 | pub mod error { 269 | //! Broadcast error types 270 | //! 271 | //! This module contains the error types for the broadcast channel. 272 | 273 | use std::fmt; 274 | 275 | /// Error returned by the [`send`] function on a [`Sender`]. 276 | /// 277 | /// A **send** operation can only fail if there are no active receivers, 278 | /// implying that the message could never be received. The error contains the 279 | /// message being sent as a payload so it can be recovered. 280 | /// 281 | /// [`send`]: crate::broadcast::Sender::send 282 | /// [`Sender`]: crate::broadcast::Sender 283 | #[derive(Debug)] 284 | pub struct SendError(pub T); 285 | 286 | impl fmt::Display for SendError { 287 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 288 | write!(f, "channel closed") 289 | } 290 | } 291 | 292 | impl std::error::Error for SendError {} 293 | 294 | /// An error returned from the [`recv`] function on a [`Receiver`]. 
295 | /// 296 | /// [`recv`]: crate::broadcast::Receiver::recv 297 | /// [`Receiver`]: crate::broadcast::Receiver 298 | #[derive(Debug, PartialEq, Eq, Clone)] 299 | pub enum RecvError { 300 | /// There are no more active senders implying no further messages will ever 301 | /// be sent. 302 | Closed, 303 | 304 | /// The receiver lagged too far behind. Attempting to receive again will 305 | /// return the oldest message still retained by the channel. 306 | /// 307 | /// Includes the number of skipped messages. 308 | Lagged(u64), 309 | } 310 | 311 | impl fmt::Display for RecvError { 312 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 313 | match self { 314 | RecvError::Closed => write!(f, "channel closed"), 315 | RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), 316 | } 317 | } 318 | } 319 | 320 | impl std::error::Error for RecvError {} 321 | 322 | /// An error returned from the [`try_recv`] function on a [`Receiver`]. 323 | /// 324 | /// [`try_recv`]: crate::broadcast::Receiver::try_recv 325 | /// [`Receiver`]: crate::broadcast::Receiver 326 | #[derive(Debug, PartialEq, Eq, Clone)] 327 | pub enum TryRecvError { 328 | /// The channel is currently empty. There are still active 329 | /// [`Sender`] handles, so data may yet become available. 330 | /// 331 | /// [`Sender`]: crate::broadcast::Sender 332 | Empty, 333 | 334 | /// There are no more active senders implying no further messages will ever 335 | /// be sent. 336 | Closed, 337 | 338 | /// The receiver lagged too far behind and has been forcibly disconnected. 339 | /// Attempting to receive again will return the oldest message still 340 | /// retained by the channel. 341 | /// 342 | /// Includes the number of skipped messages. 
343 | Lagged(u64), 344 | } 345 | 346 | impl fmt::Display for TryRecvError { 347 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 348 | match self { 349 | TryRecvError::Empty => write!(f, "channel empty"), 350 | TryRecvError::Closed => write!(f, "channel closed"), 351 | TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), 352 | } 353 | } 354 | } 355 | 356 | impl std::error::Error for TryRecvError {} 357 | } 358 | 359 | use error::{RecvError, SendError, TryRecvError}; 360 | 361 | /// Data shared between senders and receivers. 362 | struct Shared { 363 | /// slots in the channel. 364 | buffer: Vec>>, 365 | 366 | /// Mask a position -> index. 367 | mask: usize, 368 | 369 | /// Tail of the queue. 370 | tail: RefCell, 371 | 372 | /// Number of outstanding Sender handles. 373 | num_tx: Cell, 374 | } 375 | 376 | /// Next position to write a value. 377 | struct Tail { 378 | /// Next position to write to. 379 | pos: u64, 380 | 381 | /// Number of active receivers. 382 | rx_cnt: usize, 383 | 384 | /// True if the channel is closed. 385 | closed: bool, 386 | 387 | /// Receivers waiting for a value. 388 | wakers: WakeList, 389 | } 390 | 391 | /// Slot in the buffer. 392 | struct Slot { 393 | /// Remaining number of receivers that are expected to see this value. 394 | /// 395 | /// When this goes to zero, the value is released. 396 | rem: Cell, 397 | 398 | /// Uniquely identifies the `send` stored in the slot. 399 | pos: u64, 400 | 401 | /// The value being broadcast. 402 | /// 403 | /// The value is set by `send`. When a reader drops, `rem` is decremented. 404 | /// When it hits zero, the value is dropped. 405 | val: RefCell>, 406 | } 407 | 408 | impl Sender { 409 | /// Attempts to send a value to all active [`Receiver`] handles. 410 | /// 411 | /// If this function returns an error, the value was not sent to any receivers 412 | /// and the channel has been closed. 
413 | /// 414 | /// # Examples 415 | /// 416 | /// ``` 417 | /// use local_sync::broadcast; 418 | /// 419 | /// #[monoio::main] 420 | /// async fn main() { 421 | /// let (tx, mut rx1) = broadcast::channel(16); 422 | /// let mut rx2 = tx.subscribe(); 423 | /// 424 | /// tx.send(10).unwrap(); 425 | /// tx.send(20).unwrap(); 426 | /// 427 | /// assert_eq!(rx1.recv().await.unwrap(), 10); 428 | /// assert_eq!(rx2.recv().await.unwrap(), 10); 429 | /// assert_eq!(rx1.recv().await.unwrap(), 20); 430 | /// assert_eq!(rx2.recv().await.unwrap(), 20); 431 | /// } 432 | /// ``` 433 | pub fn send(&self, value: T) -> Result<(), SendError> { 434 | let mut tail = self.shared.tail.borrow_mut(); 435 | if tail.rx_cnt == 0 || tail.closed { 436 | return Err(SendError(value)); 437 | } 438 | 439 | let idx = tail.pos as usize & self.shared.mask; 440 | let slot = &self.shared.buffer[idx]; 441 | let mut slot = slot.borrow_mut(); 442 | 443 | slot.pos = tail.pos; 444 | slot.rem.set(tail.rx_cnt); 445 | *slot.val.borrow_mut() = Some(value); 446 | 447 | tail.pos = tail.pos.wrapping_add(1); 448 | 449 | // Wake all waiting receivers 450 | tail.wakers.wake_all(); 451 | 452 | Ok(()) 453 | } 454 | 455 | /// Creates a new [`Receiver`] handle that will receive values sent **after** 456 | /// this call to `subscribe`. 457 | /// 458 | /// # Examples 459 | /// 460 | /// ``` 461 | /// use local_sync::broadcast; 462 | /// 463 | /// #[monoio::main] 464 | /// async fn main() { 465 | /// let (tx, _rx) = broadcast::channel(16); 466 | /// 467 | /// // Will not be seen 468 | /// tx.send(10).unwrap(); 469 | /// 470 | /// let mut rx = tx.subscribe(); 471 | /// 472 | /// tx.send(20).unwrap(); 473 | /// 474 | /// let value = rx.recv().await.unwrap(); 475 | /// assert_eq!(20, value); 476 | /// } 477 | /// ``` 478 | pub fn subscribe(&self) -> Receiver { 479 | let shared = self.shared.clone(); 480 | new_receiver(shared) 481 | } 482 | 483 | /// Returns the number of active receivers. 
484 | /// 485 | /// An active receiver is a [`Receiver`] handle returned from [`channel`] or 486 | /// [`subscribe`]. These are the handles that will receive values sent on 487 | /// this [`Sender`]. 488 | /// 489 | /// # Note 490 | /// 491 | /// It is not guaranteed that a sent message will reach this number of 492 | /// receivers. Active receivers may never call [`recv`] again before 493 | /// dropping. 494 | /// 495 | /// # Examples 496 | /// 497 | /// ``` 498 | /// use local_sync::broadcast; 499 | /// 500 | /// #[monoio::main] 501 | /// async fn main() { 502 | /// let (tx, _rx1) = broadcast::channel(16); 503 | /// 504 | /// assert_eq!(1, tx.receiver_count()); 505 | /// 506 | /// let mut _rx2 = tx.subscribe(); 507 | /// 508 | /// assert_eq!(2, tx.receiver_count()); 509 | /// 510 | /// tx.send(10).unwrap(); 511 | /// } 512 | /// ``` 513 | pub fn receiver_count(&self) -> usize { 514 | self.shared.tail.borrow().rx_cnt 515 | } 516 | 517 | /// Returns whether the channel is closed without needing to await. 518 | /// 519 | /// This happens when all receivers have been dropped. 520 | /// 521 | /// A return value of `true` means that a subsequent [`send`](Sender::send) 522 | /// will return an error. 523 | /// 524 | /// # Examples 525 | /// 526 | /// ``` 527 | /// use local_sync::broadcast; 528 | /// 529 | /// let (tx, rx) = broadcast::channel::<()>(100); 530 | /// 531 | /// assert!(!tx.is_closed()); 532 | /// 533 | /// drop(rx); 534 | /// 535 | /// assert!(tx.is_closed()); 536 | /// ``` 537 | pub fn is_closed(&self) -> bool { 538 | self.shared.tail.borrow().rx_cnt == 0 539 | } 540 | 541 | /// Closes the channel without sending a message. 542 | /// 543 | /// This prevents the channel from sending any new messages. Current 544 | /// receivers may still receive any values buffered, but will receive 545 | /// an error when attempting to receive additional messages after the buffer 546 | /// has been drained. 
547 | /// 548 | /// # Examples 549 | /// 550 | /// ``` 551 | /// use local_sync::broadcast; 552 | /// use local_sync::broadcast::error::RecvError; 553 | /// 554 | /// #[monoio::main] 555 | /// async fn main() { 556 | /// let (tx, mut rx) = broadcast::channel(16); 557 | /// 558 | /// // Close the channel 559 | /// tx.close(); 560 | /// 561 | /// // After closing, receivers should get a Closed error 562 | /// assert_eq!(rx.recv().await, Err(RecvError::Closed)); 563 | /// 564 | /// // Sending after close should fail 565 | /// assert!(tx.send(10).is_err()); 566 | /// } 567 | /// ``` 568 | pub fn close(&self) { 569 | let mut tail = self.shared.tail.borrow_mut(); 570 | tail.closed = true; 571 | tail.wakers.wake_all(); 572 | } 573 | 574 | /// Returns `true` if senders belong to the same channel. 575 | /// 576 | /// # Examples 577 | /// 578 | /// ``` 579 | /// use local_sync::broadcast; 580 | /// 581 | /// let (tx1, _) = broadcast::channel::<()>(16); 582 | /// let tx2 = tx1.clone(); 583 | /// 584 | /// assert!(tx1.same_channel(&tx2)); 585 | /// 586 | /// let (tx3, _) = broadcast::channel::<()>(16); 587 | /// 588 | /// assert!(!tx1.same_channel(&tx3)); 589 | /// ``` 590 | pub fn same_channel(&self, other: &Self) -> bool { 591 | Rc::ptr_eq(&self.shared, &other.shared) 592 | } 593 | } 594 | 595 | /// Create a new `Receiver` which reads starting from the tail. 
fn new_receiver<T>(shared: Rc<Shared<T>>) -> Receiver<T> {
    let next = {
        let mut tail = shared.tail.borrow_mut();
        tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
        tail.pos
    };

    Receiver { shared, next }
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        self.shared.num_tx.set(self.shared.num_tx.get() + 1);
        Self {
            shared: self.shared.clone(),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        let count = self.shared.num_tx.get();
        self.shared.num_tx.set(count - 1);
        if count == 1 {
            // Last sender gone: no further values can ever be sent.
            let mut tail = self.shared.tail.borrow_mut();
            tail.closed = true;
            // Wake all waiting receivers
            tail.wakers.wake_all();
        }
    }
}

impl<T> Receiver<T> {
    /// Creates a new [`Receiver`] attached to the same channel.
    ///
    /// The new receiver starts at the channel's current tail. It receives only
    /// values sent after this call, not the values this receiver still has
    /// pending.
    pub fn resubscribe(&self) -> Self {
        let shared = self.shared.clone();
        new_receiver(shared)
    }

    /// Returns the number of messages that were sent into the channel and that
    /// this [`Receiver`] has yet to receive.
    ///
    /// If the returned value from `len` is larger than the next largest power of 2
    /// of the capacity of the channel any call to [`recv`] will return an
    /// `Err(RecvError::Lagged)` and any call to [`try_recv`] will return an
    /// `Err(TryRecvError::Lagged)`, e.g. if the capacity of the channel is 10,
    /// [`recv`] will start to return `Err(RecvError::Lagged)` once `len` returns
    /// values larger than 16.
    ///
    /// [`recv`]: Receiver::recv
    /// [`try_recv`]: Receiver::try_recv
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::broadcast;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let (tx, mut rx1) = broadcast::channel(16);
    ///
    ///     tx.send(10).unwrap();
    ///     tx.send(20).unwrap();
    ///
    ///     assert_eq!(rx1.len(), 2);
    ///     assert_eq!(rx1.recv().await.unwrap(), 10);
    ///     assert_eq!(rx1.len(), 1);
    ///     assert_eq!(rx1.recv().await.unwrap(), 20);
    ///     assert_eq!(rx1.len(), 0);
    /// }
    /// ```
    pub fn len(&self) -> usize {
        let tail = self.shared.tail.borrow();
        // Positions advance with `wrapping_add`, so measure the gap the same way.
        tail.pos.wrapping_sub(self.next) as usize
    }

    /// Returns true if there aren't any messages in the channel that the [`Receiver`]
    /// has yet to receive.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::broadcast;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let (tx, mut rx1) = broadcast::channel(16);
    ///
    ///     assert!(rx1.is_empty());
    ///
    ///     tx.send(10).unwrap();
    ///     tx.send(20).unwrap();
    ///
    ///     assert!(!rx1.is_empty());
    ///     assert_eq!(rx1.recv().await.unwrap(), 10);
    ///     assert_eq!(rx1.recv().await.unwrap(), 20);
    ///     assert!(rx1.is_empty());
    /// }
    /// ```
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if receivers belong to the same channel.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::broadcast;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let (tx, rx) = broadcast::channel::<()>(16);
    ///     let rx2 = tx.subscribe();
    ///
    ///     assert!(rx.same_channel(&rx2));
    ///
    ///     let (_tx3, rx3) = broadcast::channel::<()>(16);
    ///
    ///     assert!(!rx3.same_channel(&rx2));
    /// }
    /// ```
    pub fn same_channel(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.shared, &other.shared)
    }
}

impl<T: Clone> Receiver<T> {
    /// Receives the next value for this receiver, waiting until one is sent.
    ///
    /// # Errors
    ///
    /// - `Err(RecvError::Closed)` once the channel is closed (or every
    ///   [`Sender`] is dropped) and all buffered values have been received.
    /// - `Err(RecvError::Lagged(n))` if `n` values were overwritten before this
    ///   receiver read them. The receiver's cursor then points at the oldest
    ///   value still retained, so the next call continues from there.
    ///
    /// If the returned future is dropped before it completes, no value is
    /// consumed.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        Recv { receiver: self }.await
    }

    /// Attempts to return a pending value on this receiver without awaiting.
    ///
    /// This is useful for a flavor of "optimistic check" before deciding to
    /// await on a receiver.
    ///
    /// Compared with [`recv`], this function has three failure cases instead of two
    /// (one for closed, one for an empty buffer, one for a lagging receiver).
    ///
    /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have
    /// dropped, indicating that no further values can be sent on the channel.
    ///
    /// If the [`Receiver`] handle falls behind, once the channel is full, newly
    /// sent values will overwrite old values. At this point, a call to [`try_recv`]
    /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s
    /// internal cursor is updated to point to the oldest value still held by
    /// the channel. A subsequent call to [`try_recv`] will return this value
    /// **unless** it has been since overwritten. If there are no values to
    /// receive, `Err(TryRecvError::Empty)` is returned.
    ///
    /// [`recv`]: Receiver::recv
    /// [`try_recv`]: Receiver::try_recv
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::broadcast;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let (tx, mut rx) = broadcast::channel(16);
    ///
    ///     assert!(rx.try_recv().is_err());
    ///
    ///     tx.send(10).unwrap();
    ///
    ///     let value = rx.try_recv().unwrap();
    ///     assert_eq!(10, value);
    /// }
    /// ```
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let tail = self.shared.tail.borrow();

        if self.next == tail.pos {
            // Nothing unread: distinguish "wait for more" from "no more will come".
            return Err(if tail.closed || self.shared.num_tx.get() == 0 {
                TryRecvError::Closed
            } else {
                TryRecvError::Empty
            });
        }

        let slot = self.shared.buffer[self.next as usize & self.shared.mask].borrow();

        if slot.pos != self.next {
            // The slot was overwritten: jump to the oldest retained value.
            let oldest = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);
            let missed = oldest.wrapping_sub(self.next);
            self.next = oldest;
            return Err(TryRecvError::Lagged(missed));
        }

        let rem = slot.rem.get();
        let value = if rem > 1 {
            // Other receivers still expect this value: hand out a clone.
            slot.val.borrow().clone()
        } else {
            // Last expected reader: move the value out instead of cloning it
            // and then dropping the original while the channel is borrowed.
            slot.val.borrow_mut().take()
        };

        match value {
            Some(value) => {
                slot.rem.set(rem.saturating_sub(1));
                self.next = self.next.wrapping_add(1);
                Ok(value)
            }
            None => Err(TryRecvError::Closed),
        }
    }
}

/// Future returned by [`Receiver::recv`]; resolves to the next value.
struct Recv<'a, T> {
    /// Receiver being waited on.
800 | receiver: &'a mut Receiver, 801 | } 802 | 803 | impl<'a, T: Clone> Future for Recv<'a, T> { 804 | type Output = Result; 805 | 806 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 807 | match self.receiver.try_recv() { 808 | Ok(value) => Poll::Ready(Ok(value)), 809 | Err(TryRecvError::Empty) => { 810 | // Register waker 811 | if self.receiver.shared.tail.borrow_mut().wakers.can_push() { 812 | self.receiver 813 | .shared 814 | .tail 815 | .borrow_mut() 816 | .wakers 817 | .push(cx.waker().clone()); 818 | } else { 819 | } 820 | Poll::Pending 821 | } 822 | Err(TryRecvError::Lagged(n)) => Poll::Ready(Err(RecvError::Lagged(n))), 823 | Err(TryRecvError::Closed) => Poll::Ready(Err(RecvError::Closed)), 824 | } 825 | } 826 | } 827 | 828 | impl Drop for Receiver { 829 | fn drop(&mut self) { 830 | self.shared.drop_receiver(self.next); 831 | } 832 | } 833 | 834 | impl Shared { 835 | fn drop_receiver(&self, next: u64) { 836 | let mut tail = self.tail.borrow_mut(); 837 | tail.rx_cnt -= 1; 838 | 839 | // Iterate from 'next' to 'tail.pos' to decrement 'rem' counters. 840 | for pos in next..tail.pos { 841 | let idx = (pos as usize) & self.mask; 842 | let slot = &self.buffer[idx]; 843 | let slot = slot.borrow(); 844 | 845 | if slot.pos == pos { 846 | if slot.rem.get() > 0 { 847 | slot.rem.set(slot.rem.get() - 1); 848 | } 849 | 850 | // If no receivers are waiting for this slot, drop the value. 851 | if slot.rem.get() == 0 { 852 | *slot.val.borrow_mut() = None; 853 | } 854 | } 855 | } 856 | 857 | // If no receivers are left and the channel is not closed, mark it as closed. 
858 | if tail.rx_cnt == 0 && !tail.closed { 859 | tail.closed = true; 860 | tail.wakers.wake_all(); 861 | } 862 | } 863 | } 864 | 865 | #[cfg(test)] 866 | mod tests { 867 | use super::*; 868 | 869 | #[monoio::test] 870 | async fn basic_usage() { 871 | let (tx, mut rx1) = channel(16); 872 | let mut rx2 = tx.subscribe(); 873 | 874 | tx.send(10).unwrap(); 875 | tx.send(20).unwrap(); 876 | 877 | assert_eq!(rx1.recv().await.unwrap(), 10); 878 | assert_eq!(rx2.recv().await.unwrap(), 10); 879 | assert_eq!(rx1.recv().await.unwrap(), 20); 880 | assert_eq!(rx2.recv().await.unwrap(), 20); 881 | } 882 | 883 | #[monoio::test] 884 | async fn lagged_receiver() { 885 | let (tx, mut rx) = channel(2); 886 | 887 | tx.send(10).unwrap(); 888 | tx.send(20).unwrap(); 889 | tx.send(30).unwrap(); 890 | 891 | assert!(matches!(rx.recv().await, Err(RecvError::Lagged(_)))); 892 | assert_eq!(rx.recv().await.unwrap(), 20); 893 | assert_eq!(rx.recv().await.unwrap(), 30); 894 | } 895 | 896 | #[test] 897 | fn receiver_count_on_channel_constructor() { 898 | let (sender, _) = channel::(16); 899 | assert_eq!(sender.receiver_count(), 0); 900 | 901 | let rx_1 = sender.subscribe(); 902 | assert_eq!(sender.receiver_count(), 1); 903 | 904 | let rx_2 = rx_1.resubscribe(); 905 | assert_eq!(sender.receiver_count(), 2); 906 | 907 | let rx_3 = sender.subscribe(); 908 | assert_eq!(sender.receiver_count(), 3); 909 | 910 | drop(rx_3); 911 | drop(rx_1); 912 | assert_eq!(sender.receiver_count(), 1); 913 | 914 | drop(rx_2); 915 | assert_eq!(sender.receiver_count(), 0); 916 | } 917 | } 918 | -------------------------------------------------------------------------------- /src/semaphore.rs: -------------------------------------------------------------------------------- 1 | //! Semaphore borrowed from tokio. 
2 | 3 | #![allow(unused)] 4 | 5 | use core::future::Future; 6 | use std::{ 7 | cell::{RefCell, UnsafeCell}, 8 | cmp, fmt, 9 | marker::PhantomPinned, 10 | pin::Pin, 11 | ptr::NonNull, 12 | task::{Context, Poll, Waker}, 13 | }; 14 | 15 | use crate::{ 16 | linked_list::{self, LinkedList}, 17 | wake_list::WakeList, 18 | }; 19 | 20 | /// Low level semaphore. 21 | pub(crate) struct Inner { 22 | waiters: RefCell, 23 | /// The current number of available permits in the semaphore. 24 | permits: RefCell, 25 | } 26 | 27 | struct Waitlist { 28 | queue: LinkedList::Target>, 29 | closed: bool, 30 | } 31 | 32 | /// Error returned from the [`Semaphore::try_acquire`] function. 33 | /// 34 | /// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire 35 | #[derive(Debug, PartialEq)] 36 | pub enum TryAcquireError { 37 | /// The semaphore has been [closed] and cannot issue new permits. 38 | /// 39 | /// [closed]: crate::sync::Semaphore::close 40 | Closed, 41 | 42 | /// The semaphore has no available permits. 43 | NoPermits, 44 | } 45 | 46 | /// Error returned from the [`Semaphore::acquire`] function. 47 | /// 48 | /// An `acquire` operation can only fail if the semaphore has been 49 | /// [closed]. 50 | /// 51 | /// [closed]: crate::sync::Semaphore::close 52 | /// [`Semaphore::acquire`]: crate::sync::Semaphore::acquire 53 | #[derive(Debug)] 54 | pub struct AcquireError(()); 55 | 56 | pub(crate) struct Acquire<'a> { 57 | node: Waiter, 58 | semaphore: &'a Inner, 59 | num_permits: u32, 60 | queued: bool, 61 | } 62 | 63 | struct Waiter { 64 | /// The current state of the waiter. 65 | /// 66 | /// This is either the number of remaining permits required by 67 | /// the waiter, or a flag indicating that the waiter is not yet queued. 68 | state: RefCell, 69 | 70 | /// The waker to notify the task awaiting permits. 71 | /// 72 | /// # Safety 73 | /// 74 | /// This may only be accessed while the wait queue is locked. 75 | waker: UnsafeCell>, 76 | 77 | /// Intrusive linked-list pointers. 
78 | /// 79 | /// # Safety 80 | /// 81 | /// This may only be accessed while the wait queue is locked. 82 | pointers: linked_list::Pointers, 83 | 84 | /// Should not be `Unpin`. 85 | _p: PhantomPinned, 86 | } 87 | 88 | impl Waiter { 89 | fn new(num_permits: u32) -> Self { 90 | Waiter { 91 | waker: UnsafeCell::new(None), 92 | state: RefCell::new(num_permits as usize), 93 | pointers: linked_list::Pointers::new(), 94 | _p: PhantomPinned, 95 | } 96 | } 97 | 98 | /// Assign permits to the waiter. 99 | /// 100 | /// Returns `true` if the waiter should be removed from the queue 101 | fn assign_permits(&self, n: &mut usize) -> bool { 102 | let mut curr = self.state.borrow_mut(); 103 | let assign = cmp::min(*curr, *n); 104 | *curr -= assign; 105 | *n -= assign; 106 | 107 | *curr == 0 108 | } 109 | } 110 | 111 | unsafe impl linked_list::Link for Waiter { 112 | // XXX: ideally, we would be able to use `Pin` here, to enforce the 113 | // invariant that list entries may not move while in the list. However, we 114 | // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle` 115 | // type would require `Semaphore` to be generic over a lifetime. We can't 116 | // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether 117 | // or not they dereference to an `!Unpin` target. 118 | type Handle = NonNull; 119 | type Target = Waiter; 120 | 121 | fn as_raw(handle: &Self::Handle) -> NonNull { 122 | *handle 123 | } 124 | 125 | unsafe fn from_raw(ptr: NonNull) -> NonNull { 126 | ptr 127 | } 128 | 129 | unsafe fn pointers(mut target: NonNull) -> NonNull> { 130 | NonNull::from(&mut target.as_mut().pointers) 131 | } 132 | } 133 | 134 | impl Inner { 135 | /// The maximum number of permits which a semaphore can hold. 136 | /// 137 | /// Note that this reserves three bits of flags in the permit counter, but 138 | /// we only actually use one of them. 
However, the previous semaphore 139 | /// implementation used three bits, so we will continue to reserve them to 140 | /// avoid a breaking change if additional flags need to be added in the 141 | /// future. 142 | pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; 143 | const CLOSED: usize = 1; 144 | // The least-significant bit in the number of permits is reserved to use 145 | // as a flag indicating that the semaphore has been closed. Consequently 146 | // PERMIT_SHIFT is used to leave that bit for that purpose. 147 | const PERMIT_SHIFT: usize = 1; 148 | 149 | /// Creates a new semaphore with the initial number of permits 150 | /// 151 | /// Maximum number of permits on 32-bit platforms is `1<<29`. 152 | pub(crate) const fn new(mut permits: usize) -> Self { 153 | permits &= Self::MAX_PERMITS; 154 | 155 | Self { 156 | permits: RefCell::new(permits << Self::PERMIT_SHIFT), 157 | waiters: RefCell::new(Waitlist { 158 | queue: LinkedList::new(), 159 | closed: false, 160 | }), 161 | } 162 | } 163 | 164 | /// Returns the current number of available permits 165 | pub(crate) fn available_permits(&self) -> usize { 166 | *self.permits.borrow() >> Self::PERMIT_SHIFT 167 | } 168 | 169 | /// Adds `added` new permits to the semaphore. 170 | /// 171 | /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. 172 | pub(crate) fn release(&self, added: usize) { 173 | if added == 0 { 174 | return; 175 | } 176 | 177 | // Assign permits to the wait queue 178 | self.add_permits(added); 179 | } 180 | 181 | /// Closes the semaphore. This prevents the semaphore from issuing new 182 | /// permits and notifies all pending waiters. 
183 | pub(crate) fn close(&self) { 184 | *self.permits.borrow_mut() |= Self::CLOSED; 185 | (*self.waiters.borrow_mut()).closed = true; 186 | 187 | let mut waiters = self.waiters.borrow_mut(); 188 | 189 | while let Some(mut waiter) = waiters.queue.pop_back() { 190 | let waker = unsafe { (*waiter.as_mut().waker.get()).take() }; 191 | if let Some(waker) = waker { 192 | waker.wake(); 193 | } 194 | } 195 | } 196 | 197 | /// Returns true if the semaphore is closed 198 | pub(crate) fn is_closed(&self) -> bool { 199 | *self.permits.borrow() & Self::CLOSED != 0 200 | } 201 | 202 | pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { 203 | assert!( 204 | num_permits as usize <= Self::MAX_PERMITS, 205 | "a semaphore may not have more than MAX_PERMITS permits ({})", 206 | Self::MAX_PERMITS 207 | ); 208 | let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; 209 | let mut curr = self.permits.borrow_mut(); 210 | 211 | // Has the semaphore closed? 212 | if *curr & Self::CLOSED == Self::CLOSED { 213 | return Err(TryAcquireError::Closed); 214 | } 215 | 216 | // Are there enough permits remaining? 217 | if *curr < num_permits { 218 | return Err(TryAcquireError::NoPermits); 219 | } 220 | 221 | *curr -= num_permits; 222 | Ok(()) 223 | } 224 | 225 | pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> { 226 | Acquire::new(self, num_permits) 227 | } 228 | 229 | /// Release `rem` permits to the semaphore's wait list, starting from the 230 | /// end of the queue. 231 | /// 232 | /// If `rem` exceeds the number of permits needed by the wait list, the 233 | /// remainder are assigned back to the semaphore. 234 | fn add_permits(&self, mut rem: usize) { 235 | let mut waiters = self.waiters.borrow_mut(); 236 | let mut wakers = WakeList::new(); 237 | let mut is_empty = false; 238 | while rem > 0 { 239 | 'inner: while wakers.can_push() { 240 | // Was the waiter assigned enough permits to wake it? 
// Hand released permits to queued waiters, oldest first: `poll_acquire` pushes new
// waiters to the front of the queue, so the oldest one sits at the back.
                match waiters.queue.last() {
                    Some(waiter) => {
                        if !waiter.assign_permits(&mut rem) {
                            break 'inner;
                        }
                    }
                    None => {
                        is_empty = true;
                        // If we assigned permits to all the waiters in the queue, and there are
                        // still permits left over, assign them back to the semaphore.
                        break 'inner;
                    }
                };
                let mut waiter = waiters.queue.pop_back().unwrap();
                if let Some(waker) = unsafe { (*waiter.as_mut().waker.get()).take() } {
                    wakers.push(waker);
                }
            }

            if rem > 0 && is_empty {
                let permits = rem;
                assert!(
                    permits <= Self::MAX_PERMITS,
                    "cannot add more than MAX_PERMITS permits ({})",
                    Self::MAX_PERMITS
                );
                *self.permits.borrow_mut() += rem << Self::PERMIT_SHIFT;
                rem = 0;
            }

            wakers.wake_all();
        }

        assert_eq!(rem, 0);
    }

    /// Polls for permits on behalf of the waiter `node`.
    ///
    /// * If the semaphore is closed, returns `Ready(Err(..))`.
    /// * If `node` is not queued yet and enough permits are available, takes
    ///   `num_permits` permits and returns `Ready(Ok(()))`.
    /// * Otherwise every currently available permit is moved into `node`. If that
    ///   completes the request, the surplus is handed back and `Ready(Ok(()))` is
    ///   returned. If not, the task's waker is stored in `node`, `node` is pushed
    ///   onto the wait queue (unless it was queued before) and `Pending` is returned.
    ///
    /// `queued` is `true` when an earlier call for this same `node` returned
    /// `Pending`, i.e. the node has already been pushed onto the wait queue.
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        num_permits: u32,
        node: Pin<&mut Waiter>,
        queued: bool,
    ) -> Poll<Result<(), AcquireError>> {
        // A queued node may already hold part of its permits; its `state` records how
        // many are still outstanding.
        let needed = if queued {
            *node.state.borrow() << Self::PERMIT_SHIFT
        } else {
            (num_permits as usize) << Self::PERMIT_SHIFT
        };

        let mut curr = self.permits.borrow_mut();

        // If closed, return error immediately.
        if *curr & Self::CLOSED > 0 {
            return Poll::Ready(Err(AcquireError::closed()));
        }
        // Fast path: a fresh (not yet queued) request that can be satisfied right now.
        if *curr >= needed && !queued {
            *curr -= needed;
            return Poll::Ready(Ok(()));
        }

        // Move every available permit into the node. The `RefMut` is dropped before
        // calling back into the semaphore so `add_permits` can borrow `permits` again.
        let mut permits = *curr >> Self::PERMIT_SHIFT;
        *curr = 0;
        drop(curr);
        if node.assign_permits(&mut permits) {
            // The node now holds everything it asked for. NOTE(review): for a queued node
            // this is presumably the normal completion path after the release loop popped
            // and woke it (assuming `assign_permits` reports `true` once the remaining
            // count is zero). Whatever the node did not need is handed back.
            self.add_permits(permits);
            return Poll::Ready(Ok(()));
        }

        // Store the current task's waker, replacing it only when it would wake a
        // different task.
        // SAFETY (review): the waker cell is accessed without a `RefCell`; this relies on
        // the type being single-threaded and on no other reference to this cell being
        // alive while this one is used.
        let waker = unsafe { &mut *node.waker.get() };
        if waker
            .as_ref()
            .map(|waker| !waker.will_wake(cx.waker()))
            .unwrap_or(true)
        {
            *waker = Some(cx.waker().clone());
        }

        // If the waiter is not already in the wait queue, enqueue it.
        if !queued {
            // SAFETY: the node is pinned and never moved out; `Acquire::drop` unlinks it
            // from the queue (while `queued` is set) before its memory goes away.
            let node = unsafe {
                let node = Pin::into_inner_unchecked(node) as *mut _;
                NonNull::new_unchecked(node)
            };

            self.waiters.borrow_mut().queue.push_front(node);
        }

        Poll::Pending
    }
}

impl fmt::Debug for Inner {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Semaphore")
            .field("permits", &self.available_permits())
            .finish()
    }
}

impl Future for Acquire<'_> {
    type Output = Result<(), AcquireError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (node, semaphore, needed, queued) = self.project();

        match semaphore.poll_acquire(cx, needed, node, *queued) {
            Poll::Pending => {
                // From now on the node is linked into the wait queue.
                *queued = true;
                Poll::Pending
            }
            Poll::Ready(r) => {
                // On error `queued` is left untouched so `Drop` still unlinks the node.
                r?;
                *queued = false;
                Poll::Ready(Ok(()))
            }
        }
    }
}

impl<'a> Acquire<'a> {
    /// Creates a not-yet-queued request for `num_permits` permits from `semaphore`.
    fn new(semaphore: &'a Inner, num_permits: u32) -> Self {
        Self {
            node: Waiter::new(num_permits),
            semaphore,
            num_permits,
            queued: false,
        }
    }

    /// Pin projection: a pinned reference to `node` plus plain access to the other
    /// fields.
    fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Inner, u32, &mut bool) {
        fn is_unpin<T: Unpin>() {}
        unsafe {
            // Safety: all fields other than `node` are `Unpin`

            is_unpin::<&Inner>();
            is_unpin::<&mut bool>();
            is_unpin::<u32>();

            let this = self.get_unchecked_mut();
            (
                Pin::new_unchecked(&mut this.node),
                this.semaphore,
                this.num_permits,
                &mut this.queued,
            )
        }
    }
}

impl Drop for Acquire<'_> {
    fn drop(&mut self) {
        // If the future is completed, there is no node in the wait list, so we
        // can skip acquiring the lock.
        if !self.queued {
            return;
        }

        {
            // This is where we ensure safety. The future is being dropped,
            // which means we must ensure that the waiter entry is no longer stored
            // in the linked list.
            let mut waiters = self.semaphore.waiters.borrow_mut();

            // remove the entry from the list
            let node = NonNull::from(&mut self.node);
            // Safety: we have locked the wait list.
            unsafe { waiters.queue.remove(node) };
        }

        // Permits already handed to this waiter (requested minus still outstanding)
        // go back to the semaphore so they are not leaked.
        let acquired_permits = self.num_permits as usize - *self.node.state.borrow();
        if acquired_permits > 0 {
            self.semaphore.add_permits(acquired_permits);
        }
    }
}

impl AcquireError {
    fn closed() -> AcquireError {
        AcquireError(())
    }
}

impl fmt::Display for AcquireError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "semaphore closed")
    }
}

impl std::error::Error for AcquireError {}

impl TryAcquireError {
    /// Returns `true` if the error was caused by a closed semaphore.
    #[allow(dead_code)] // may be used later!
    pub(crate) fn is_closed(&self) -> bool {
        matches!(self, TryAcquireError::Closed)
    }

    /// Returns `true` if the error was caused by calling `try_acquire` on a
    /// semaphore with no available permits.
    #[allow(dead_code)] // may be used later!
    pub(crate) fn is_no_permits(&self) -> bool {
        matches!(self, TryAcquireError::NoPermits)
    }
}

impl fmt::Display for TryAcquireError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TryAcquireError::Closed => write!(fmt, "semaphore closed"),
            TryAcquireError::NoPermits => write!(fmt, "no permits available"),
        }
    }
}

impl std::error::Error for TryAcquireError {}

/// Counting semaphore performing asynchronous permit acquisition.
///
/// A semaphore maintains a set of permits. Permits are used to synchronize
/// access to a shared resource. A semaphore differs from a mutex in that it
/// can allow more than one concurrent caller to access the shared resource at a
/// time.
///
/// When `acquire` is called and the semaphore has remaining permits, the
/// function immediately returns a permit. However, if no remaining permits are
/// available, `acquire` (asynchronously) waits until an outstanding permit is
/// dropped. At this point, the freed permit is assigned to the caller.
///
/// This `Semaphore` is fair, which means that permits are given out in the order
/// they were requested. This fairness is also applied when `acquire_many` gets
/// involved, so if a call to `acquire_many` at the front of the queue requests
/// more permits than currently available, this can prevent a call to `acquire`
/// from completing, even if the semaphore has enough permits to complete the call
/// to `acquire`.
///
/// The semaphore is not thread-safe; share it within one thread by reference or
/// through an [`Rc`].
///
/// # Examples
///
/// Basic usage:
///
/// ```
/// use local_sync::semaphore::{Semaphore, TryAcquireError};
///
/// #[monoio::main]
/// async fn main() {
///     let semaphore = Semaphore::new(3);
///
///     let a_permit = semaphore.acquire().await.unwrap();
///     let two_permits = semaphore.acquire_many(2).await.unwrap();
///
///     assert_eq!(semaphore.available_permits(), 0);
///
///     let permit_attempt = semaphore.try_acquire();
///     assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits));
/// }
/// ```
///
/// Use [`Semaphore::acquire_owned`] to move permits across tasks:
///
/// ```
/// use std::rc::Rc;
/// use local_sync::semaphore::Semaphore;
///
/// #[monoio::main]
/// async fn main() {
///     let semaphore = Rc::new(Semaphore::new(3));
///     let mut join_handles = Vec::new();
///
///     for _ in 0..5 {
///         let permit = semaphore.clone().acquire_owned().await.unwrap();
///         join_handles.push(monoio::spawn(async move {
///             // perform task...
///             // explicitly own `permit` in the task
///             drop(permit);
///         }));
///     }
///
///     for handle in join_handles {
///         handle.await;
///     }
/// }
/// ```
///
/// [`Rc`]: std::rc::Rc
#[derive(Debug)]
pub struct Semaphore(Inner);

/// A permit from the semaphore.
///
/// This type is created by the [`acquire`] method. Dropping it returns its
/// permits to the semaphore; see [`SemaphorePermit::forget`] to keep them out.
///
/// [`acquire`]: Semaphore::acquire
#[must_use]
#[derive(Debug)]
pub struct SemaphorePermit<'a> {
    sem: &'a Semaphore,
    permits: u32,
}

/// An owned permit from the semaphore.
///
/// This type is created by the [`acquire_owned`] method. Dropping it returns its
/// permits to the semaphore; see [`OwnedSemaphorePermit::forget`] to keep them out.
///
/// [`acquire_owned`]: Semaphore::acquire_owned
#[must_use]
#[derive(Debug)]
pub struct OwnedSemaphorePermit {
    sem: std::rc::Rc<Semaphore>,
    permits: u32,
}

/// Future returned by [`Semaphore::acquire`] and [`Semaphore::acquire_many`].
///
/// Resolves to a [`SemaphorePermit`] holding the requested number of permits, or
/// to an [`AcquireError`] if the semaphore has been closed.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AcquireResult<'a>(Acquire<'a>, &'a Semaphore, u32);

impl<'a> Future for AcquireResult<'a> {
    type Output = Result<SemaphorePermit<'a>, AcquireError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let sem = self.1;
        let permits = self.2;
        // SAFETY: structural pin projection onto field `0`; the `Acquire` inside is
        // never moved out of `self` or accessed unpinned anywhere else.
        let inner = unsafe { self.map_unchecked_mut(|me| &mut me.0) };
        futures_util::ready!(inner.poll(cx))?;
        Poll::Ready(Ok(SemaphorePermit { sem, permits }))
    }
}

impl Semaphore {
    /// Creates a new semaphore with the initial number of permits.
    pub const fn new(permits: usize) -> Self {
        Self(Inner::new(permits))
    }

    /// Returns the current number of available permits.
    pub fn available_permits(&self) -> usize {
        self.0.available_permits()
    }

    /// Adds `n` new permits to the semaphore.
    ///
    /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded.
    pub fn add_permits(&self, n: usize) {
        self.0.release(n);
    }

    /// Acquires a permit from the semaphore.
    ///
    /// If the semaphore has been closed, this returns an [`AcquireError`].
    /// Otherwise, this returns a [`SemaphorePermit`] representing the
    /// acquired permit.
    ///
    /// # Cancel safety
    ///
    /// This method uses a queue to fairly distribute permits in the order they
    /// were requested. Cancelling a call to `acquire` makes you lose your place
    /// in the queue.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::semaphore::Semaphore;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let semaphore = Semaphore::new(2);
    ///
    ///     let permit_1 = semaphore.acquire().await.unwrap();
    ///     assert_eq!(semaphore.available_permits(), 1);
    ///
    ///     let permit_2 = semaphore.acquire().await.unwrap();
    ///     assert_eq!(semaphore.available_permits(), 0);
    ///
    ///     drop(permit_1);
    ///     assert_eq!(semaphore.available_permits(), 1);
    /// }
    /// ```
    pub fn acquire(&self) -> AcquireResult<'_> {
        self.acquire_many(1)
    }

    /// Acquires `n` permits from the semaphore.
    ///
    /// If the semaphore has been closed, this returns an [`AcquireError`].
    /// Otherwise, this returns a [`SemaphorePermit`] representing the
    /// acquired permits.
    ///
    /// # Cancel safety
    ///
    /// This method uses a queue to fairly distribute permits in the order they
    /// were requested. Cancelling a call to `acquire_many` makes you lose your
    /// place in the queue.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::semaphore::Semaphore;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let semaphore = Semaphore::new(5);
    ///
    ///     let permit = semaphore.acquire_many(3).await.unwrap();
    ///     assert_eq!(semaphore.available_permits(), 2);
    /// }
    /// ```
    pub fn acquire_many(&self, n: u32) -> AcquireResult<'_> {
        AcquireResult(self.0.acquire(n), self, n)
    }

    /// Tries to acquire a permit from the semaphore.
    ///
    /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`]
    /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise,
    /// this returns a [`SemaphorePermit`] representing the acquired permits.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::semaphore::{Semaphore, TryAcquireError};
    ///
    /// # fn main() {
    /// let semaphore = Semaphore::new(2);
    ///
    /// let permit_1 = semaphore.try_acquire().unwrap();
    /// assert_eq!(semaphore.available_permits(), 1);
    ///
    /// let permit_2 = semaphore.try_acquire().unwrap();
    /// assert_eq!(semaphore.available_permits(), 0);
    ///
    /// let permit_3 = semaphore.try_acquire();
    /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits));
    /// # }
    /// ```
    pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
        self.try_acquire_many(1)
    }

    /// Tries to acquire `n` permits from the semaphore.
    ///
    /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`]
    /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left.
    /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::semaphore::{Semaphore, TryAcquireError};
    ///
    /// # fn main() {
    /// let semaphore = Semaphore::new(4);
    ///
    /// let permit_1 = semaphore.try_acquire_many(3).unwrap();
    /// assert_eq!(semaphore.available_permits(), 1);
    ///
    /// let permit_2 = semaphore.try_acquire_many(2);
    /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits));
    /// # }
    /// ```
    pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
        self.0.try_acquire(n)?;
        Ok(SemaphorePermit {
            sem: self,
            permits: n,
        })
    }

    /// Acquires a permit from the semaphore.
    ///
    /// The semaphore must be wrapped in an [`Rc`] to call this method.
    /// If the semaphore has been closed, this returns an [`AcquireError`].
    /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
    /// acquired permit.
    ///
    /// # Cancel safety
    ///
    /// This method uses a queue to fairly distribute permits in the order they
    /// were requested. Cancelling a call to `acquire_owned` makes you lose your
    /// place in the queue.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::rc::Rc;
    /// use local_sync::semaphore::Semaphore;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let semaphore = Rc::new(Semaphore::new(3));
    ///     let mut join_handles = Vec::new();
    ///
    ///     for _ in 0..5 {
    ///         let permit = semaphore.clone().acquire_owned().await.unwrap();
    ///         join_handles.push(monoio::spawn(async move {
    ///             // perform task...
    ///             // explicitly own `permit` in the task
    ///             drop(permit);
    ///         }));
    ///     }
    ///
    ///     for handle in join_handles {
    ///         handle.await;
    ///     }
    /// }
    /// ```
    ///
    /// [`Rc`]: std::rc::Rc
    pub async fn acquire_owned(
        self: std::rc::Rc<Self>,
    ) -> Result<OwnedSemaphorePermit, AcquireError> {
        self.0.acquire(1).await?;
        Ok(OwnedSemaphorePermit {
            sem: self,
            permits: 1,
        })
    }

    /// Acquires `n` permits from the semaphore.
    ///
    /// The semaphore must be wrapped in an [`Rc`] to call this method.
    /// If the semaphore has been closed, this returns an [`AcquireError`].
    /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
    /// acquired permit.
    ///
    /// # Cancel safety
    ///
    /// This method uses a queue to fairly distribute permits in the order they
    /// were requested. Cancelling a call to `acquire_many_owned` makes you lose
    /// your place in the queue.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::rc::Rc;
    /// use local_sync::semaphore::Semaphore;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let semaphore = Rc::new(Semaphore::new(10));
    ///     let mut join_handles = Vec::new();
    ///
    ///     for _ in 0..5 {
    ///         let permit = semaphore.clone().acquire_many_owned(2).await.unwrap();
    ///         join_handles.push(monoio::spawn(async move {
    ///             // perform task...
    ///             // explicitly own `permit` in the task
    ///             drop(permit);
    ///         }));
    ///     }
    ///
    ///     for handle in join_handles {
    ///         handle.await;
    ///     }
    /// }
    /// ```
    ///
    /// [`Rc`]: std::rc::Rc
    pub async fn acquire_many_owned(
        self: std::rc::Rc<Self>,
        n: u32,
    ) -> Result<OwnedSemaphorePermit, AcquireError> {
        self.0.acquire(n).await?;
        Ok(OwnedSemaphorePermit {
            sem: self,
            permits: n,
        })
    }

    /// Tries to acquire a permit from the semaphore.
    ///
    /// The semaphore must be wrapped in an [`Rc`] to call this method. If
    /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`]
    /// and a [`TryAcquireError::NoPermits`] if there are no permits left.
    /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
    /// acquired permit.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::rc::Rc;
    /// use local_sync::semaphore::{Semaphore, TryAcquireError};
    ///
    /// # fn main() {
    /// let semaphore = Rc::new(Semaphore::new(2));
    ///
    /// let permit_1 = Rc::clone(&semaphore).try_acquire_owned().unwrap();
    /// assert_eq!(semaphore.available_permits(), 1);
    ///
    /// let permit_2 = Rc::clone(&semaphore).try_acquire_owned().unwrap();
    /// assert_eq!(semaphore.available_permits(), 0);
    ///
    /// let permit_3 = semaphore.try_acquire_owned();
    /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits));
    /// # }
    /// ```
    ///
    /// [`Rc`]: std::rc::Rc
    pub fn try_acquire_owned(
        self: std::rc::Rc<Self>,
    ) -> Result<OwnedSemaphorePermit, TryAcquireError> {
        self.try_acquire_many_owned(1)
    }

    /// Tries to acquire `n` permits from the semaphore.
    ///
    /// The semaphore must be wrapped in an [`Rc`] to call this method. If
    /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`]
    /// and a [`TryAcquireError::NoPermits`] if there are no permits left.
    /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
    /// acquired permit.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::rc::Rc;
    /// use local_sync::semaphore::{Semaphore, TryAcquireError};
    ///
    /// # fn main() {
    /// let semaphore = Rc::new(Semaphore::new(4));
    ///
    /// let permit_1 = Rc::clone(&semaphore).try_acquire_many_owned(3).unwrap();
    /// assert_eq!(semaphore.available_permits(), 1);
    ///
    /// let permit_2 = semaphore.try_acquire_many_owned(2);
    /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits));
    /// # }
    /// ```
    ///
    /// [`Rc`]: std::rc::Rc
    pub fn try_acquire_many_owned(
        self: std::rc::Rc<Self>,
        n: u32,
    ) -> Result<OwnedSemaphorePermit, TryAcquireError> {
        self.0.try_acquire(n)?;
        Ok(OwnedSemaphorePermit {
            sem: self,
            permits: n,
        })
    }

    /// Closes the semaphore.
    ///
    /// This prevents the semaphore from issuing new permits and notifies all pending waiters.
    ///
    /// # Examples
    ///
    /// ```
    /// use local_sync::semaphore::{Semaphore, TryAcquireError};
    /// use std::rc::Rc;
    ///
    /// #[monoio::main]
    /// async fn main() {
    ///     let semaphore = Rc::new(Semaphore::new(1));
    ///     let semaphore2 = semaphore.clone();
    ///
    ///     monoio::spawn(async move {
    ///         let permit = semaphore.acquire_many(2).await;
    ///         assert!(permit.is_err());
    ///         println!("waiter received error");
    ///     });
    ///
    ///     println!("closing semaphore");
    ///     semaphore2.close();
    ///
    ///     // Cannot obtain more permits
    ///     assert_eq!(semaphore2.try_acquire().err(), Some(TryAcquireError::Closed))
    /// }
    /// ```
    pub fn close(&self) {
        self.0.close();
    }

    /// Returns true if the semaphore is closed
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }
}

impl SemaphorePermit<'_> {
    /// Forgets the permit **without** releasing it back to the semaphore.
    /// This can be used to reduce the amount of permits available from a
    /// semaphore.
    pub fn forget(mut self) {
        // `Drop` still runs, but now releases zero permits.
        self.permits = 0;
    }
}

impl OwnedSemaphorePermit {
    /// Forgets the permit **without** releasing it back to the semaphore.
    /// This can be used to reduce the amount of permits available from a
    /// semaphore.
    pub fn forget(mut self) {
        // `Drop` still runs, but now releases zero permits.
        self.permits = 0;
    }
}

impl Drop for SemaphorePermit<'_> {
    /// Returns the held permits to the semaphore.
    fn drop(&mut self) {
        self.sem.add_permits(self.permits as usize);
    }
}

impl Drop for OwnedSemaphorePermit {
    /// Returns the held permits to the semaphore.
    fn drop(&mut self) {
        self.sem.add_permits(self.permits as usize);
    }
}

#[cfg(test)]
mod tests {
    use super::{Inner, Semaphore};

    #[monoio::test]
    async fn inner_works() {
        let s = Inner::new(10);
        for _ in 0..10 {
            s.acquire(1).await.unwrap();
        }
    }

    #[monoio::test]
    async fn inner_release_after_acquire() {
        let s = std::rc::Rc::new(Inner::new(0));

        let s_move = s.clone();
        let join = monoio::spawn(async move {
            let _ = s_move.acquire(1).await.unwrap();
            let _ = s_move.acquire(1).await.unwrap();
        });
        s.release(2);
        join.await;
    }

    #[monoio::test]
    async fn it_works() {
        let s = Semaphore::new(0);
        s.add_permits(1);
        let _ = s.acquire().await.unwrap();
    }

    #[test]
    fn permits_return_on_drop_but_not_on_forget() {
        let s = Semaphore::new(3);
        let two = s.try_acquire_many(2).unwrap();
        assert_eq!(s.available_permits(), 1);
        drop(two);
        assert_eq!(s.available_permits(), 3);
        s.try_acquire().unwrap().forget();
        assert_eq!(s.available_permits(), 2);
    }

    #[test]
    fn closed_semaphore_rejects_try_acquire() {
        let s = Semaphore::new(1);
        s.close();
        assert!(s.is_closed());
        assert!(s.try_acquire().unwrap_err().is_closed());
    }
}