├── .gitignore ├── CHANGELOG.md ├── src ├── error.rs ├── utils.rs ├── options.rs ├── lib.rs ├── response.rs ├── callback.rs ├── method.rs └── client.rs ├── Cargo.toml ├── LICENSE ├── tests ├── test.rs └── method.rs └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | 3 | ## 0.3.0 4 | 5 | - Fix visibility of response structs 6 | - Update `tokio-tungstenite` to `0.17` 7 | - Methods that require a GID now take `&str` 8 | 9 | ## 0.4.0 10 | 11 | - Refactor code to use channels instead of mutexes 12 | - Remove default timeout 13 | - Rename `Hooks` to `Callbacks` 14 | 15 | ## 0.5.0 16 | 17 | Dependency updates: 18 | 19 | ```toml 20 | tokio-tungstenite = "0.21" 21 | base64 = "0.22" 22 | snafu = "0.8" 23 | serde_with = { version = "3", features = ["chrono"] } 24 | ``` 25 | 26 | - Improve callback execution so that callbacks are no longer missed. 27 | - Fix `announce_list` type.
28 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use snafu::prelude::*; 2 | use tokio_tungstenite::tungstenite::Error as WsError; 3 | 4 | #[derive(Debug, Snafu)] 5 | #[snafu(visibility(pub(crate)))] 6 | pub enum Error { 7 | #[snafu(display("aria2: responsed error: {source}"))] 8 | Aria2 { source: crate::Aria2Error }, 9 | #[snafu(display("aria2: cannot parse value {value:?} as {to}"))] 10 | Parse { value: String, to: String }, 11 | #[snafu(display("aria2: websocket error: {source}"))] 12 | WebsocketIo { source: WsError }, 13 | #[snafu(display("aria2: json error: {source}"))] 14 | Json { source: serde_json::Error }, 15 | #[snafu(display("aria2: websocket closed: {message}"))] 16 | WebsocketClosed { message: String }, 17 | #[snafu(display("aria2: reconnect task timeout"))] 18 | ReconnectTaskTimeout { source: tokio::time::error::Elapsed }, 19 | } 20 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aria2-ws" 3 | version = "0.5.1" 4 | edition = "2021" 5 | description = "An aria2 websocket jsonrpc API with notification support" 6 | repository = "https://github.com/WOo0W/aria2-ws-rs" 7 | license = "MIT" 8 | keywords = ["aria2", "jsonrpc"] 9 | categories = ["api-bindings"] 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dev-dependencies] 13 | tokio = { version = "1", features = ["full"] } 14 | env_logger = "0.11" 15 | test-log = "0.2" 16 | 17 | [dependencies] 18 | tokio = { version = "1", features = ["sync", "time", "macros", "rt"] } 19 | serde = { version = "1", features = ["derive"] } 20 | serde_json = "1" 21 | tokio-tungstenite = "0.26" 22 | futures = "0.3" 23 | base64 = "0.22" 24 | snafu = "0.8" 25 | log = "0.4" 26 | 
serde_with = { version = "3", features = ["chrono"] } 27 | chrono = { version = "0.4", features = ["serde"] } 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 WOo0W 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use log::info; 2 | use serde::Serialize; 3 | use serde_json::{to_value, Value}; 4 | use snafu::ResultExt; 5 | 6 | use crate::{error, Result}; 7 | 8 | pub trait PushExt { 9 | fn push_some(&mut self, t: Option) -> Result<()>; 10 | 11 | fn push_else(&mut self, t: Option, v: Value) -> Result<()>; 12 | } 13 | 14 | impl PushExt for Vec { 15 | fn push_some(&mut self, t: Option) -> Result<()> { 16 | if let Some(t) = t { 17 | self.push(to_value(t).context(error::JsonSnafu)?); 18 | } 19 | Ok(()) 20 | } 21 | 22 | fn push_else(&mut self, t: Option, v: Value) -> Result<()> { 23 | if let Some(t) = t { 24 | self.push(to_value(t).context(error::JsonSnafu)?); 25 | } else { 26 | self.push(v); 27 | } 28 | Ok(()) 29 | } 30 | } 31 | 32 | /// Convert `Value` into `Vec` 33 | /// 34 | /// # Panics 35 | /// 36 | /// Panic if value is not of type `Value::Array` 37 | pub fn value_into_vec(value: Value) -> Vec { 38 | if let Value::Array(v) = value { 39 | return v; 40 | } 41 | panic!("value is not Value::Array"); 42 | } 43 | 44 | /// Print error if the result is an Err. 45 | pub fn print_error(res: std::result::Result<(), E>) 46 | where 47 | E: std::fmt::Display, 48 | { 49 | if let Err(err) = res { 50 | info!("{}", err); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/options.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use serde_json::{Map, Value}; 3 | use serde_with::{serde_as, skip_serializing_none, DisplayFromStr}; 4 | 5 | /// Regular options of aria2 download tasks. 6 | /// 7 | /// For more options, add them to `extra_options` field, which is Object in `serde_json`. 
8 | /// 9 | /// You can find all options in 10 | #[serde_as] 11 | #[skip_serializing_none] 12 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] 13 | #[serde(rename_all = "kebab-case")] 14 | pub struct TaskOptions { 15 | pub header: Option>, 16 | 17 | #[serde_as(as = "Option")] 18 | pub split: Option, 19 | 20 | pub all_proxy: Option, 21 | 22 | pub dir: Option, 23 | 24 | pub out: Option, 25 | 26 | pub gid: Option, 27 | 28 | #[serde_as(as = "Option")] 29 | pub r#continue: Option, 30 | 31 | #[serde_as(as = "Option")] 32 | pub auto_file_renaming: Option, 33 | 34 | #[serde_as(as = "Option")] 35 | pub check_integrity: Option, 36 | 37 | /// Close connection if download speed is lower than or equal to this value(bytes per sec). 38 | /// 39 | /// 0 means aria2 does not have a lowest speed limit. 40 | /// 41 | /// You can append K or M (1K = 1024, 1M = 1024K). 42 | /// 43 | /// This option does not affect BitTorrent downloads. 44 | /// 45 | /// Default: 0 46 | pub lowest_speed_limit: Option, 47 | 48 | /// Set max download speed per each download in bytes/sec. 0 means unrestricted. 49 | /// 50 | /// You can append K or M (1K = 1024, 1M = 1024K). 51 | /// 52 | /// To limit the overall download speed, use --max-overall-download-limit option. 
53 | /// 54 | /// Default: 0 55 | pub max_download_limit: Option, 56 | 57 | #[serde_as(as = "Option")] 58 | pub max_connection_per_server: Option, 59 | 60 | #[serde_as(as = "Option")] 61 | pub max_tries: Option, 62 | 63 | #[serde_as(as = "Option")] 64 | pub timeout: Option, 65 | 66 | #[serde(flatten)] 67 | pub extra_options: Map, 68 | } 69 | -------------------------------------------------------------------------------- /tests/test.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use aria2_ws::{Callbacks, Client, TaskOptions}; 4 | use futures::FutureExt; 5 | use serde_json::json; 6 | use test_log::test; 7 | use tokio::{ 8 | spawn, 9 | sync::{broadcast, Semaphore}, 10 | }; 11 | 12 | #[tokio::test] 13 | #[ignore] 14 | async fn drop_test() { 15 | Client::connect("ws://127.0.0.1:6800/jsonrpc", None) 16 | .await 17 | .unwrap(); 18 | } 19 | 20 | #[test(tokio::test)] 21 | // #[ignore] 22 | async fn example() { 23 | let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None) 24 | .await 25 | .unwrap(); 26 | let options = TaskOptions { 27 | split: Some(2), 28 | header: Some(vec!["Referer: https://www.pixiv.net/".to_string()]), 29 | // Add extra options which are not included in TaskOptions. 30 | extra_options: json!({"max-download-limit": "100K"}) 31 | .as_object() 32 | .unwrap() 33 | .clone(), 34 | ..Default::default() 35 | }; 36 | 37 | let mut not = client.subscribe_notifications(); 38 | spawn(async move { 39 | loop { 40 | match not.recv().await { 41 | Ok(msg) => println!("Received notification {:?}", &msg), 42 | Err(broadcast::error::RecvError::Closed) => { 43 | println!("Notification channel closed"); 44 | break; 45 | } 46 | Err(broadcast::error::RecvError::Lagged(_)) => { 47 | println!("Notification channel lagged"); 48 | } 49 | } 50 | } 51 | }); 52 | 53 | // use `tokio::sync::Semaphore` to wait for all tasks to finish. 
54 | let semaphore = Arc::new(Semaphore::new(0)); 55 | client 56 | .add_uri( 57 | vec![ 58 | "https://i.pximg.net/img-original/img/2020/05/15/06/56/03/81572512_p0.png" 59 | .to_string(), 60 | ], 61 | Some(options.clone()), 62 | None, 63 | Some(Callbacks { 64 | on_download_complete: Some({ 65 | let s = semaphore.clone(); 66 | async move { 67 | s.add_permits(1); 68 | println!("Task 1 completed!"); 69 | } 70 | .boxed() 71 | }), 72 | on_error: Some({ 73 | let s = semaphore.clone(); 74 | async move { 75 | s.add_permits(1); 76 | println!("Task 1 error!"); 77 | } 78 | .boxed() 79 | }), 80 | }), 81 | ) 82 | .await 83 | .unwrap(); 84 | 85 | // Will 404 86 | client 87 | .add_uri( 88 | vec![ 89 | "https://i.pximg.net/img-original/img/2022/01/05/23/32/16/95326322_p0.pngxxxx" 90 | .to_string(), 91 | ], 92 | Some(options.clone()), 93 | None, 94 | Some(Callbacks { 95 | on_download_complete: Some({ 96 | let s = semaphore.clone(); 97 | async move { 98 | s.add_permits(1); 99 | println!("Task 2 completed!"); 100 | } 101 | .boxed() 102 | }), 103 | on_error: Some({ 104 | let s = semaphore.clone(); 105 | async move { 106 | s.add_permits(1); 107 | println!("Task 2 error!"); 108 | } 109 | .boxed() 110 | }), 111 | }), 112 | ) 113 | .await 114 | .unwrap(); 115 | 116 | // Wait for 2 tasks to finish. 117 | let _ = semaphore.acquire_many(2).await.unwrap(); 118 | 119 | // Force shutdown aria2. 120 | // client.force_shutdown().await.unwrap(); 121 | } 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aria2-ws 2 | 3 | An aria2 websocket jsonrpc in Rust. 4 | 5 | Built with `tokio`. 
6 | 7 | [Docs.rs](https://docs.rs/aria2_ws) 8 | 9 | [aria2 RPC docs](https://aria2.github.io/manual/en/html/aria2c.html#methods) 10 | 11 | [Changelog](./CHANGELOG.md) 12 | 13 | ## Features 14 | 15 | - Methods and typed responses 16 | - Auto reconnect 17 | - Ensures `on_download_complete` and `on_error` callbacks are executed even after reconnecting. 18 | - Notification subscription 19 | 20 | ## Example 21 | 22 | ```rust 23 | use std::sync::Arc; 24 | 25 | use aria2_ws::{Client, Callbacks, TaskOptions}; 26 | use futures::FutureExt; 27 | use serde_json::json; 28 | use tokio::{spawn, sync::{broadcast, Semaphore}}; 29 | 30 | async fn example() { 31 | let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None) 32 | .await 33 | .unwrap(); 34 | let options = TaskOptions { 35 | split: Some(2), 36 | header: Some(vec!["Referer: https://www.pixiv.net/".to_string()]), 37 | // Add extra options which are not included in TaskOptions. 38 | extra_options: json!({"max-download-limit": "100K"}) 39 | .as_object() 40 | .unwrap() 41 | .clone(), 42 | ..Default::default() 43 | }; 44 | 45 | let mut not = client.subscribe_notifications(); 46 | spawn(async move { 47 | loop { 48 | match not.recv().await { 49 | Ok(msg) => println!("Received notification {:?}", &msg), 50 | Err(broadcast::error::RecvError::Closed) => { 51 | println!("Notification channel closed"); 52 | break; 53 | } 54 | Err(broadcast::error::RecvError::Lagged(_)) => { 55 | println!("Notification channel lagged"); 56 | } 57 | } 58 | } 59 | }); 60 | 61 | // use `tokio::sync::Semaphore` to wait for all tasks to finish.
62 | let semaphore = Arc::new(Semaphore::new(0)); 63 | client 64 | .add_uri( 65 | vec![ 66 | "https://i.pximg.net/img-original/img/2020/05/15/06/56/03/81572512_p0.png" 67 | .to_string(), 68 | ], 69 | Some(options.clone()), 70 | None, 71 | Some(Callbacks { 72 | on_download_complete: Some({ 73 | let s = semaphore.clone(); 74 | async move { 75 | s.add_permits(1); 76 | println!("Task 1 completed!"); 77 | } 78 | .boxed() 79 | }), 80 | on_error: Some({ 81 | let s = semaphore.clone(); 82 | async move { 83 | s.add_permits(1); 84 | println!("Task 1 error!"); 85 | } 86 | .boxed() 87 | }), 88 | }), 89 | ) 90 | .await 91 | .unwrap(); 92 | 93 | // Will 404 94 | client 95 | .add_uri( 96 | vec![ 97 | "https://i.pximg.net/img-original/img/2022/01/05/23/32/16/95326322_p0.pngxxxx" 98 | .to_string(), 99 | ], 100 | Some(options.clone()), 101 | None, 102 | Some(Callbacks { 103 | on_download_complete: Some({ 104 | let s = semaphore.clone(); 105 | async move { 106 | s.add_permits(1); 107 | println!("Task 2 completed!"); 108 | } 109 | .boxed() 110 | }), 111 | on_error: Some({ 112 | let s = semaphore.clone(); 113 | async move { 114 | s.add_permits(1); 115 | println!("Task 2 error!"); 116 | } 117 | .boxed() 118 | }), 119 | }), 120 | ) 121 | .await 122 | .unwrap(); 123 | 124 | // Wait for 2 tasks to finish. 125 | let _ = semaphore.acquire_many(2).await.unwrap(); 126 | 127 | // Force shutdown aria2. 
128 | // client.force_shutdown().await.unwrap(); 129 | } 130 | 131 | ``` 132 | -------------------------------------------------------------------------------- /tests/method.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use aria2_ws::{Client, TaskOptions}; 4 | use serde::{Deserialize, Serialize}; 5 | use serde_json::json; 6 | use serde_with::{serde_as, skip_serializing_none}; 7 | use test_log::test; 8 | 9 | async fn test_global(c: &Client) { 10 | let r = c.get_version().await.unwrap(); 11 | println!("{:?}\n", r); 12 | 13 | let r = c.get_global_option().await.unwrap(); 14 | println!("{:?}\n", r); 15 | 16 | let r = c.get_global_stat().await.unwrap(); 17 | println!("{:?}\n", r); 18 | 19 | let r = c.get_global_option().await.unwrap(); 20 | println!("{:?}\n", r); 21 | 22 | let r = c.get_session_info().await.unwrap(); 23 | println!("{:?}\n", r); 24 | 25 | let r = c.tell_active().await.unwrap(); 26 | println!("{:?}\n", r); 27 | 28 | let r = c.tell_stopped(0, 100).await.unwrap(); 29 | println!("{:?}\n", r); 30 | 31 | let r = c.tell_waiting(0, 100).await.unwrap(); 32 | println!("{:?}\n", r); 33 | } 34 | 35 | async fn test_metadata(c: &Client, gid: &str) { 36 | let r = c.get_option(gid).await.unwrap(); 37 | println!("{r:?}\n"); 38 | 39 | let r = c.get_files(gid).await.unwrap(); 40 | println!("{r:?}\n"); 41 | 42 | // let r = c.get_peers(gid).await.unwrap(); 43 | // println!("{:?}\n", r); 44 | 45 | let r = c.get_servers(gid).await.unwrap(); 46 | println!("{r:?}\n"); 47 | 48 | let r = c.get_uris(gid).await.unwrap(); 49 | println!("{r:?}\n"); 50 | 51 | let r = c.tell_status(gid).await.unwrap(); 52 | println!("{r:?}\n"); 53 | } 54 | 55 | async fn sleep(secs: u64) { 56 | tokio::time::sleep(Duration::from_secs(secs)).await; 57 | } 58 | 59 | #[test(tokio::test)] 60 | async fn global() { 61 | sleep(3).await; 62 | let c = Client::connect("ws://localhost:6800/jsonrpc", None) 63 | .await 64 | .unwrap(); 65 | 
test_global(&c).await; 66 | } 67 | 68 | #[test(tokio::test)] 69 | async fn torrent() { 70 | let c = Client::connect("ws://localhost:6800/jsonrpc", None) 71 | .await 72 | .unwrap(); 73 | 74 | let options = TaskOptions { 75 | max_download_limit: Some("100K".to_string()), 76 | extra_options: json!({ 77 | "file-allocation": "none", 78 | }) 79 | .as_object() 80 | .unwrap() 81 | .clone(), 82 | ..Default::default() 83 | }; 84 | 85 | let gid = c 86 | .add_uri(vec!["magnet:?xt=urn:btih:9b4c1489bfccd8205d152345f7a8aad52d9a1f57&dn=archlinux-2022.05.01-x86_64.iso".to_string()], Some(options), None, None) 87 | .await.unwrap(); 88 | 89 | sleep(10).await; 90 | 91 | // test_global(&c).await.unwrap(); 92 | test_metadata(&c, &gid).await; 93 | c.remove(&gid).await.unwrap(); 94 | c.remove_download_result(&gid).await.unwrap(); 95 | } 96 | 97 | #[test(tokio::test)] 98 | async fn http() { 99 | let c = Client::connect("ws://localhost:6800/jsonrpc", None) 100 | .await 101 | .unwrap(); 102 | 103 | let options = TaskOptions { 104 | max_download_limit: Some("100K".to_string()), 105 | extra_options: json!({ 106 | "file-allocation": "none", 107 | }) 108 | .as_object() 109 | .unwrap() 110 | .clone(), 111 | ..Default::default() 112 | }; 113 | 114 | let gid = c 115 | .add_uri( 116 | vec!["https://mirror.hoster.kz/archlinux/iso/latest/archlinux-x86_64.iso".to_string()], 117 | Some(options), 118 | None, 119 | None, 120 | ) 121 | .await 122 | .unwrap(); 123 | 124 | sleep(10).await; 125 | 126 | // test_global(&c).await.unwrap(); 127 | test_metadata(&c, &gid).await; 128 | c.remove(&gid).await.unwrap(); 129 | c.remove_download_result(&gid).await.unwrap(); 130 | } 131 | 132 | use serde_with::DisplayFromStr; 133 | #[serde_as] 134 | #[skip_serializing_none] 135 | #[derive(Deserialize, Serialize, Debug)] 136 | struct A { 137 | #[serde_as(as = "Option")] 138 | a: Option, 139 | b: (i32, i32), 140 | } 141 | 142 | #[test] 143 | fn serde_test() { 144 | let a = A { a: None, b: (1, 2) }; 145 | let j = 
serde_json::to_string(&a).unwrap(); 146 | println!("{}", j); 147 | let a = serde_json::from_str::(&j).unwrap(); 148 | println!("{:?}", a); 149 | } 150 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! An aria2 websocket jsonrpc in Rust. 3 | //! 4 | //! [aria2 RPC docs](https://aria2.github.io/manual/en/html/aria2c.html#methods) 5 | //! 6 | //! ## Features 7 | //! 8 | //! - Almost all methods and structed responses 9 | //! - Auto reconnect 10 | //! - Ensures `on_complete` and `on_error` hook to be executed even after reconnected. 11 | //! - Supports notifications 12 | //! 13 | //! ## Example 14 | //! 15 | //! ```no_run 16 | //! use std::sync::Arc; 17 | //! 18 | //! use aria2_ws::{Client, Callbacks, TaskOptions}; 19 | //! use futures::FutureExt; 20 | //! use serde_json::json; 21 | //! use tokio::{spawn, sync::Semaphore}; 22 | //! 23 | //! #[tokio::main] 24 | //! async fn main() { 25 | //! let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None) 26 | //! .await 27 | //! .unwrap(); 28 | //! let options = TaskOptions { 29 | //! split: Some(2), 30 | //! header: Some(vec!["Referer: https://www.pixiv.net/".to_string()]), 31 | //! all_proxy: Some("http://127.0.0.1:10809".to_string()), 32 | //! // Add extra options which are not included in TaskOptions. 33 | //! extra_options: json!({"max-download-limit": "200K"}) 34 | //! .as_object() 35 | //! .unwrap() 36 | //! .clone(), 37 | //! ..Default::default() 38 | //! }; 39 | //! 40 | //! // use `tokio::sync::Semaphore` to wait for all tasks to finish. 41 | //! let semaphore = Arc::new(Semaphore::new(0)); 42 | //! client 43 | //! .add_uri( 44 | //! vec![ 45 | //! "https://i.pximg.net/img-original/img/2020/05/15/06/56/03/81572512_p0.png" 46 | //! .to_string(), 47 | //! ], 48 | //! Some(options.clone()), 49 | //! None, 50 | //! Some(Callbacks { 51 | //! 
on_download_complete: Some({ 52 | //! let s = semaphore.clone(); 53 | //! async move { 54 | //! s.add_permits(1); 55 | //! println!("Task 1 completed!"); 56 | //! } 57 | //! .boxed() 58 | //! }), 59 | //! on_error: Some({ 60 | //! let s = semaphore.clone(); 61 | //! async move { 62 | //! s.add_permits(1); 63 | //! println!("Task 1 error!"); 64 | //! } 65 | //! .boxed() 66 | //! }), 67 | //! }), 68 | //! ) 69 | //! .await 70 | //! .unwrap(); 71 | //! 72 | //! // Will 404 73 | //! client 74 | //! .add_uri( 75 | //! vec![ 76 | //! "https://i.pximg.net/img-original/img/2022/01/05/23/32/16/95326322_p0.pngxxxx" 77 | //! .to_string(), 78 | //! ], 79 | //! Some(options.clone()), 80 | //! None, 81 | //! Some(Callbacks { 82 | //! on_download_complete: Some({ 83 | //! let s = semaphore.clone(); 84 | //! async move { 85 | //! s.add_permits(1); 86 | //! println!("Task 2 completed!"); 87 | //! } 88 | //! .boxed() 89 | //! }), 90 | //! on_error: Some({ 91 | //! let s = semaphore.clone(); 92 | //! async move { 93 | //! s.add_permits(1); 94 | //! println!("Task 2 error!"); 95 | //! } 96 | //! .boxed() 97 | //! }), 98 | //! }), 99 | //! ) 100 | //! .await 101 | //! .unwrap(); 102 | //! 103 | //! let mut not = client.subscribe_notifications(); 104 | //! 105 | //! spawn(async move { 106 | //! while let Ok(msg) = not.recv().await { 107 | //! println!("Received notification {:?}", &msg); 108 | //! } 109 | //! }); 110 | //! 111 | //! // Wait for 2 tasks to finish. 112 | //! let _ = semaphore.acquire_many(2).await.unwrap(); 113 | //! 114 | //! client.shutdown().await.unwrap(); 115 | //! } 116 | //! 117 | //! ``` 118 | 119 | mod callback; 120 | mod client; 121 | mod error; 122 | mod method; 123 | mod options; 124 | pub mod response; 125 | mod utils; 126 | 127 | pub use error::Error; 128 | pub use options::TaskOptions; 129 | // Re-export `Map` for `TaskOptions`. 
130 | pub use callback::Callbacks; 131 | pub use client::{Client, InnerClient}; 132 | pub use serde_json::Map; 133 | 134 | use serde::{Deserialize, Serialize}; 135 | use serde_json::Value; 136 | use snafu::OptionExt; 137 | 138 | pub(crate) type Result = std::result::Result; 139 | 140 | #[derive(Serialize, Deserialize, Debug, Clone)] 141 | pub struct RpcRequest { 142 | pub id: Option, 143 | pub jsonrpc: String, 144 | pub method: String, 145 | #[serde(default)] 146 | pub params: Vec, 147 | } 148 | 149 | /// Error returned by RPC calls. 150 | #[derive(Serialize, Deserialize, Debug, Clone)] 151 | pub struct Aria2Error { 152 | pub code: i32, 153 | pub message: String, 154 | } 155 | impl std::fmt::Display for Aria2Error { 156 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 157 | write!( 158 | f, 159 | "aria2 responsed error: code {}: {}", 160 | self.code, self.message 161 | ) 162 | } 163 | } 164 | impl std::error::Error for Aria2Error {} 165 | 166 | #[derive(Deserialize, Debug, Clone)] 167 | pub struct RpcResponse { 168 | pub id: Option, 169 | pub jsonrpc: String, 170 | pub result: Option, 171 | pub error: Option, 172 | } 173 | 174 | /// Events about download progress from aria2. 175 | #[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] 176 | pub enum Event { 177 | Start, 178 | Pause, 179 | Stop, 180 | Complete, 181 | Error, 182 | /// This notification will be sent when a torrent download is complete but seeding is still going on. 
183 | BtComplete, 184 | } 185 | 186 | impl TryFrom<&str> for Event { 187 | type Error = crate::Error; 188 | 189 | fn try_from(value: &str) -> Result { 190 | use Event::*; 191 | let event = match value { 192 | "aria2.onDownloadStart" => Start, 193 | "aria2.onDownloadPause" => Pause, 194 | "aria2.onDownloadStop" => Stop, 195 | "aria2.onDownloadComplete" => Complete, 196 | "aria2.onDownloadError" => Error, 197 | "aria2.onBtDownloadComplete" => BtComplete, 198 | _ => return error::ParseSnafu { value, to: "Event" }.fail(), 199 | }; 200 | Ok(event) 201 | } 202 | } 203 | 204 | #[derive(Debug, Clone, PartialEq, Eq)] 205 | pub enum Notification { 206 | Aria2 { gid: String, event: Event }, 207 | WebSocketConnected, 208 | WebsocketClosed, 209 | } 210 | 211 | impl TryFrom<&RpcRequest> for Notification { 212 | type Error = crate::Error; 213 | 214 | fn try_from(req: &RpcRequest) -> Result { 215 | let gid = (|| req.params.get(0)?.get("gid")?.as_str())() 216 | .with_context(|| error::ParseSnafu { 217 | value: format!("{:?}", req), 218 | to: "Notification", 219 | })? 220 | .to_string(); 221 | let event = req.method.as_str().try_into()?; 222 | Ok(Notification::Aria2 { gid, event }) 223 | } 224 | } 225 | 226 | #[cfg(test)] 227 | mod tests { 228 | fn check_if_send() {} 229 | 230 | #[test] 231 | fn t() { 232 | check_if_send::(); 233 | check_if_send::(); 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /src/response.rs: -------------------------------------------------------------------------------- 1 | use chrono::{DateTime, Utc}; 2 | use serde::{Deserialize, Serialize}; 3 | use serde_with::{serde_as, DisplayFromStr, TimestampSeconds}; 4 | 5 | #[derive(Serialize, Deserialize, Debug, Clone)] 6 | #[serde(rename_all = "camelCase")] 7 | pub struct Version { 8 | pub enabled_features: Vec, 9 | 10 | pub version: String, 11 | } 12 | 13 | /// Full status of a task. 
14 | /// 15 | /// 16 | #[serde_as] 17 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 18 | #[serde(rename_all = "camelCase")] 19 | pub struct Status { 20 | /// GID of the download. 21 | pub gid: String, 22 | 23 | pub status: TaskStatus, 24 | 25 | #[serde_as(as = "DisplayFromStr")] 26 | pub total_length: u64, 27 | 28 | #[serde_as(as = "DisplayFromStr")] 29 | pub completed_length: u64, 30 | 31 | #[serde_as(as = "DisplayFromStr")] 32 | pub upload_length: u64, 33 | 34 | /// Hexadecimal representation of the download progress. 35 | /// 36 | /// The highest bit corresponds to the piece at index 0. 37 | /// 38 | /// Any set bits indicate loaded pieces, while 39 | /// unset bits indicate not yet loaded and/or missing pieces. 40 | /// 41 | /// Any overflow bits at the end are set to zero. 42 | /// 43 | /// When the download was not started yet, 44 | /// this key will not be included in the response. 45 | pub bitfield: Option, 46 | 47 | #[serde_as(as = "DisplayFromStr")] 48 | pub download_speed: u64, 49 | 50 | #[serde_as(as = "DisplayFromStr")] 51 | pub upload_speed: u64, 52 | 53 | /// InfoHash. BitTorrent only 54 | pub info_hash: Option, 55 | 56 | #[serde_as(as = "Option")] 57 | pub num_seeders: Option, 58 | 59 | /// true if the local endpoint is a seeder. Otherwise false. BitTorrent only. 60 | #[serde_as(as = "Option")] 61 | pub seeder: Option, 62 | 63 | #[serde_as(as = "DisplayFromStr")] 64 | pub piece_length: u64, 65 | 66 | #[serde_as(as = "DisplayFromStr")] 67 | pub num_pieces: u64, 68 | 69 | #[serde_as(as = "DisplayFromStr")] 70 | pub connections: u64, 71 | 72 | pub error_code: Option, 73 | 74 | pub error_message: Option, 75 | /// List of GIDs which are generated as the result of this download. 76 | /// 77 | /// For example, when aria2 downloads a Metalink file, 78 | /// it generates downloads described in the Metalink (see the --follow-metalink option). 79 | /// 80 | /// This value is useful to track auto-generated downloads. 
81 | /// 82 | /// If there are no such downloads, this key will not be included in the response. 83 | pub followed_by: Option>, 84 | 85 | /// The reverse link for followedBy. 86 | /// 87 | /// A download included in followedBy has this object's GID in its following value. 88 | pub following: Option, 89 | 90 | /// GID of a parent download. 91 | /// 92 | /// Some downloads are a part of another download. 93 | /// 94 | /// For example, if a file in a Metalink has BitTorrent resources, 95 | /// the downloads of ".torrent" files are parts of that parent. 96 | /// 97 | /// If this download has no parent, this key will not be included in the response. 98 | pub belongs_to: Option, 99 | 100 | pub dir: String, 101 | 102 | pub files: Vec, 103 | 104 | pub bittorrent: Option, 105 | 106 | /// The number of verified number of bytes while the files are being hash checked. 107 | /// 108 | /// This key exists only when this download is being hash checked. 109 | #[serde_as(as = "Option")] 110 | pub verified_length: Option, 111 | 112 | /// `true` if this download is waiting for the hash check in a queue. 113 | /// 114 | /// This key exists only when this download is in the queue. 
115 | #[serde_as(as = "Option")] 116 | pub verify_integrity_pending: Option, 117 | } 118 | 119 | #[serde_as] 120 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 121 | #[serde(rename_all = "camelCase")] 122 | pub struct BittorrentStatus { 123 | pub announce_list: Vec>, 124 | 125 | pub comment: Option, 126 | 127 | #[serde_as(as = "Option>")] 128 | pub creation_date: Option>, 129 | 130 | pub mode: Option, 131 | } 132 | 133 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 134 | #[serde(rename_all = "lowercase")] 135 | pub enum BitTorrentFileMode { 136 | Single, 137 | Multi, 138 | } 139 | 140 | #[serde_as] 141 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 142 | #[serde(rename_all = "camelCase")] 143 | pub struct File { 144 | #[serde_as(as = "DisplayFromStr")] 145 | pub index: u64, 146 | 147 | pub path: String, 148 | 149 | #[serde_as(as = "DisplayFromStr")] 150 | pub length: u64, 151 | 152 | #[serde_as(as = "DisplayFromStr")] 153 | pub completed_length: u64, 154 | 155 | #[serde_as(as = "DisplayFromStr")] 156 | pub selected: bool, 157 | 158 | pub uris: Vec, 159 | } 160 | 161 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 162 | #[serde(rename_all = "camelCase")] 163 | pub struct Uri { 164 | pub status: UriStatus, 165 | 166 | pub uri: String, 167 | } 168 | 169 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 170 | #[serde(rename_all = "lowercase")] 171 | pub enum UriStatus { 172 | Used, 173 | Waiting, 174 | } 175 | 176 | /// Task status returned by `aria2.tellStatus`. 177 | /// 178 | /// `Active` for currently downloading/seeding downloads. 179 | /// 180 | /// `Waiting` for downloads in the queue; download is not started. 181 | /// 182 | /// `Paused` for paused downloads. 183 | /// 184 | /// `Error` for downloads that were stopped because of error. 185 | /// 186 | /// `Complete` for stopped and completed downloads. 187 | /// 188 | /// `Removed` for the downloads removed by user. 
189 | /// 190 | /// 191 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 192 | #[serde(rename_all = "lowercase")] 193 | pub enum TaskStatus { 194 | Active, 195 | Waiting, 196 | Paused, 197 | Error, 198 | Complete, 199 | Removed, 200 | } 201 | 202 | #[serde_as] 203 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 204 | #[serde(rename_all = "camelCase")] 205 | pub struct Peer { 206 | #[serde_as(as = "DisplayFromStr")] 207 | pub am_choking: bool, 208 | 209 | pub bitfield: String, 210 | 211 | #[serde_as(as = "DisplayFromStr")] 212 | pub download_speed: u64, 213 | 214 | pub ip: String, 215 | 216 | #[serde_as(as = "DisplayFromStr")] 217 | pub peer_choking: bool, 218 | 219 | pub peer_id: String, 220 | 221 | #[serde_as(as = "DisplayFromStr")] 222 | pub port: u16, 223 | 224 | #[serde_as(as = "DisplayFromStr")] 225 | pub seeder: bool, 226 | 227 | #[serde_as(as = "DisplayFromStr")] 228 | pub upload_speed: u64, 229 | } 230 | 231 | #[serde_as] 232 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 233 | #[serde(rename_all = "camelCase")] 234 | pub struct GlobalStat { 235 | #[serde_as(as = "DisplayFromStr")] 236 | pub download_speed: u64, 237 | 238 | #[serde_as(as = "DisplayFromStr")] 239 | pub upload_speed: u64, 240 | 241 | #[serde_as(as = "DisplayFromStr")] 242 | pub num_active: i32, 243 | 244 | #[serde_as(as = "DisplayFromStr")] 245 | pub num_waiting: i32, 246 | 247 | #[serde_as(as = "DisplayFromStr")] 248 | pub num_stopped: i32, 249 | 250 | #[serde_as(as = "DisplayFromStr")] 251 | pub num_stopped_total: i32, 252 | } 253 | 254 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 255 | #[serde(rename_all = "camelCase")] 256 | pub struct SessionInfo { 257 | pub session_id: String, 258 | } 259 | 260 | #[serde_as] 261 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 262 | #[serde(rename_all = "camelCase")] 263 | pub struct GetServersResult { 264 | #[serde_as(as = "DisplayFromStr")] 265 | pub index: i32, 266
| 267 | pub servers: Vec, 268 | } 269 | 270 | #[serde_as] 271 | #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] 272 | #[serde(rename_all = "camelCase")] 273 | pub struct Server { 274 | pub uri: String, 275 | 276 | pub current_uri: String, 277 | 278 | #[serde_as(as = "DisplayFromStr")] 279 | pub download_speed: u64, 280 | } 281 | -------------------------------------------------------------------------------- /src/callback.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{hash_map::Entry, HashMap}, 3 | fmt, 4 | sync::Weak, 5 | time::Duration, 6 | }; 7 | 8 | use futures::future::BoxFuture; 9 | use log::{debug, info}; 10 | use serde::Deserialize; 11 | use snafu::ResultExt; 12 | use tokio::{ 13 | select, spawn, 14 | sync::{broadcast, mpsc}, 15 | time::timeout, 16 | }; 17 | 18 | use crate::{error, utils::print_error, Event, InnerClient, Notification, Result}; 19 | 20 | type Callback = Option>; 21 | 22 | /// Callbacks that will be executed on notifications. 23 | /// 24 | /// If the connection is lost, all callbacks are checked once reconnected to see whether they need to be executed. 25 | /// 26 | /// It executes at most once for each task. That means a task can either be completed or failed. 27 | /// 28 | /// If you need to customize the behavior, you can use `Client::subscribe_notifications` 29 | /// to receive notifications and handle them yourself, 30 | /// or use `tell_status` to check the status of the task. 31 | #[derive(Default)] 32 | pub struct Callbacks { 33 | /// Will trigger on `Event::Complete` or `Event::BtComplete`. 34 | pub on_download_complete: Callback, 35 | /// Will trigger on `Event::Error`.
36 | pub on_error: Callback, 37 | } 38 | 39 | impl Callbacks { 40 | pub(crate) fn is_empty(&self) -> bool { 41 | self.on_download_complete.is_none() && self.on_error.is_none() 42 | } 43 | } 44 | 45 | impl fmt::Debug for Callbacks { 46 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 47 | f.debug_struct("Callbacks") 48 | .field("on_download_complete", &self.on_download_complete.is_some()) 49 | .field("on_error", &self.on_error.is_some()) 50 | .finish() 51 | } 52 | } 53 | 54 | /// Fire callbacks for tasks that stopped while notifications may have been missed (e.g. while disconnected), by inspecting the first 1000 stopped tasks returned by `tellStopped`. 55 | async fn on_reconnect( 56 | inner: &InnerClient, 57 | callbacks_map: &mut HashMap, 58 | ) -> Result<()> { 59 | // Response from `custom_tell_stopped` call 60 | #[derive(Debug, Clone, Deserialize)] 61 | #[serde(rename_all = "camelCase")] 62 | struct TaskStatus { 63 | status: String, 64 | total_length: String, 65 | completed_length: String, 66 | gid: String, 67 | } 68 | 69 | if callbacks_map.is_empty() { 70 | return Ok(()); 71 | } 72 | let mut tasks = HashMap::new(); 73 | let req = inner.custom_tell_stopped( 74 | 0, 75 | 1000, 76 | Some( 77 | ["status", "totalLength", "completedLength", "gid"] 78 | .into_iter() 79 | .map(|x| x.to_string()) 80 | .collect(), 81 | ), 82 | ); 83 | // Give up after 10 seconds so the worker is not blocked indefinitely 84 | for map in timeout(Duration::from_secs(10), req) 85 | .await 86 | .context(error::ReconnectTaskTimeoutSnafu)?? 87 | { 88 | let task: TaskStatus = 89 | serde_json::from_value(serde_json::Value::Object(map)).context(error::JsonSnafu)?; 90 | tasks.insert(task.gid.clone(), task); 91 | } 92 | 93 | for (gid, callbacks) in callbacks_map.iter_mut() { 94 | if let Some(status) = tasks.get(gid) { 95 | debug!("checking callbacks for gid {} after reconnected", gid); 96 | // Check the error status first: a task that failed early may report `0` for both lengths, which must not be mistaken for a finished download.
97 | if status.status == "error" { 98 | if let Some(h) = callbacks.on_error.take() { 99 | spawn(h); 100 | } 101 | } else if status.status == "complete" || (status.total_length != "0" && status.total_length == status.completed_length) { 102 | if let Some(h) = callbacks.on_download_complete.take() { 103 | spawn(h); 104 | } 105 | } 106 | } 107 | } 108 | callbacks_map.retain(|_, callbacks| !callbacks.is_empty()); // drop tasks whose callbacks have all fired 109 | Ok(()) 110 | } 111 | /// Spawn the callback matching `event`, if it is still present. Returns `true` for terminal events (complete / error), after which the caller drops the task's entry. 112 | fn invoke_callbacks_on_event(event: Event, callbacks: &mut Callbacks) -> bool { 113 | match event { 114 | Event::Complete | Event::BtComplete => { 115 | if let Some(callback) = callbacks.on_download_complete.take() { 116 | // Spawn a new task to avoid blocking the notification receiver. 117 | spawn(callback); 118 | } 119 | } 120 | Event::Error => { 121 | if let Some(callback) = callbacks.on_error.take() { 122 | spawn(callback); 123 | } 124 | } 125 | _ => return false, 126 | } 127 | true 128 | } 129 | /// Callbacks for one task, sent by `Client::add_callbacks` to the callback worker. 130 | #[derive(Debug)] 131 | pub(crate) struct TaskCallbacks { 132 | pub gid: String, 133 | pub callbacks: Callbacks, 134 | } 135 | /// Background task that matches aria2 notifications to registered callbacks. Exits when the notification channel closes or every callback sender is dropped. 136 | pub(crate) async fn callback_worker( 137 | weak: Weak, 138 | mut rx_notification: broadcast::Receiver, 139 | mut rx_callback: mpsc::UnboundedReceiver, 140 | ) { 141 | use broadcast::error::RecvError; 142 | 143 | 144 | let mut callbacks_map = HashMap::new(); 145 | let mut yet_processed_notifications: HashMap> = HashMap::new(); 146 | 147 | loop { 148 | select! { 149 | r = rx_notification.recv() => { 150 | match r { 151 | Ok(notification) => { 152 | match notification { 153 | Notification::WebSocketConnected => { 154 | // Checked on every connection, including the first: this receiver may be 155 | // subscribed only after the initial notification was sent, so skipping "the first" 156 | // one could swallow a real reconnect. With no callbacks registered yet, 157 | // `on_reconnect` returns immediately without any RPC. 158 | 159 | // We might miss some notifications when the connection is lost. 160 | // So we need to check whether the callbacks need to be executed after reconnected.
161 | if let Some(inner) = weak.upgrade() { 162 | print_error(on_reconnect(inner.as_ref(), &mut callbacks_map).await); 163 | } 164 | }, 165 | Notification::Aria2 { gid, event } => { 166 | match callbacks_map.entry(gid.clone()) { 167 | Entry::Occupied(mut e) => { 168 | let invoked = invoke_callbacks_on_event(event, e.get_mut()); 169 | if invoked { 170 | e.remove(); 171 | } 172 | } 173 | _ => { 174 | // If the task is not in the map, store the event for possible later processing. NOTE(review): entries for gids that never get callbacks registered (e.g. tasks added without callbacks) are never removed, so this map can grow without bound. 175 | yet_processed_notifications 176 | .entry(gid.clone()) 177 | .or_insert_with(Vec::new) 178 | .push(event); 179 | } 180 | } 181 | }, 182 | _ => {} 183 | } 184 | } 185 | Err(RecvError::Closed) => { 186 | return; 187 | } 188 | Err(RecvError::Lagged(_)) => { // notifications were dropped: re-check stopped tasks as after a reconnect 189 | info!("unexpected lag in notifications"); if let Some(inner) = weak.upgrade() { print_error(on_reconnect(inner.as_ref(), &mut callbacks_map).await); } 190 | } 191 | } 192 | }, 193 | r = rx_callback.recv() => { 194 | match r { 195 | Some(TaskCallbacks { gid, mut callbacks }) => { 196 | if let Some(events) = yet_processed_notifications.remove(&gid) { 197 | let mut invoked = false; 198 | for event in events { 199 | invoked = invoke_callbacks_on_event(event, &mut callbacks); 200 | if invoked { 201 | break; 202 | } 203 | } 204 | if !invoked { 205 | callbacks_map.insert(gid, callbacks); 206 | } 207 | } else { 208 | callbacks_map.insert(gid, callbacks); 209 | } 210 | } 211 | None => { 212 | return; 213 | } 214 | } 215 | }, 216 | } 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /src/method.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error, 3 | options::TaskOptions, 4 | response, 5 | utils::{value_into_vec, PushExt}, 6 | Callbacks, Client, InnerClient, Result, 7 | }; 8 | use base64::prelude::*; 9 | use serde::Serialize; 10 | use serde_json::{json, to_value, Map, Value}; 11 | use snafu::prelude::*; 12 | 13 | /// The parameter `how` in `changePosition`.
14 | /// 15 | /// 16 | #[derive(Serialize, Debug, Clone, PartialEq, Eq)] 17 | pub enum PositionHow { 18 | #[serde(rename = "POS_SET")] 19 | Set, 20 | #[serde(rename = "POS_CUR")] 21 | Cur, 22 | #[serde(rename = "POS_END")] 23 | End, 24 | } 25 | 26 | impl InnerClient { 27 | async fn custom_tell_multi( 28 | &self, 29 | method: &str, 30 | offset: i32, 31 | num: i32, 32 | keys: Option>, 33 | ) -> Result>> { 34 | let mut params = value_into_vec(json!([offset, num])); 35 | params.push_some(keys)?; 36 | self.call_and_wait(method, params).await 37 | } 38 | 39 | pub async fn get_version(&self) -> Result { 40 | self.call_and_wait("getVersion", vec![]).await 41 | } 42 | /// Call `method` with a single gid parameter and discard the reply. 43 | async fn returning_gid(&self, method: &str, gid: &str) -> Result<()> { 44 | self.call_and_wait::(method, vec![Value::String(gid.to_string())]) 45 | .await?; 46 | Ok(()) 47 | } 48 | 49 | pub async fn remove(&self, gid: &str) -> Result<()> { 50 | self.returning_gid("remove", gid).await 51 | } 52 | 53 | pub async fn force_remove(&self, gid: &str) -> Result<()> { 54 | self.returning_gid("forceRemove", gid).await 55 | } 56 | 57 | pub async fn pause(&self, gid: &str) -> Result<()> { 58 | self.returning_gid("pause", gid).await 59 | } 60 | 61 | pub async fn pause_all(&self) -> Result<()> { 62 | self.call_and_wait::("pauseAll", vec![]).await?; 63 | Ok(()) 64 | } 65 | 66 | pub async fn force_pause(&self, gid: &str) -> Result<()> { 67 | self.returning_gid("forcePause", gid).await 68 | } 69 | 70 | pub async fn force_pause_all(&self) -> Result<()> { 71 | self.call_and_wait::("forcePauseAll", vec![]) 72 | .await?; 73 | Ok(()) 74 | } 75 | 76 | pub async fn unpause(&self, gid: &str) -> Result<()> { 77 | self.returning_gid("unpause", gid).await 78 | } 79 | 80 | pub async fn unpause_all(&self) -> Result<()> { 81 | self.call_and_wait::("unpauseAll", vec![]).await?; 82 | Ok(()) 83 | } 84 | 85 | pub async fn custom_tell_status( 86 | &self, 87 | gid: &str, 88 | keys: Option>, 89 | ) -> Result> { 90 | let mut params =
vec![Value::String(gid.to_string())]; 91 | params.push_some(keys)?; 92 | self.call_and_wait("tellStatus", params).await 93 | } 94 | 95 | pub async fn tell_status(&self, gid: &str) -> Result { 96 | self.call_and_wait("tellStatus", vec![Value::String(gid.to_string())]) 97 | .await 98 | } 99 | 100 | pub async fn get_uris(&self, gid: &str) -> Result> { 101 | self.call_and_wait("getUris", vec![Value::String(gid.to_string())]) 102 | .await 103 | } 104 | 105 | pub async fn get_files(&self, gid: &str) -> Result> { 106 | self.call_and_wait("getFiles", vec![Value::String(gid.to_string())]) 107 | .await 108 | } 109 | 110 | pub async fn get_peers(&self, gid: &str) -> Result> { 111 | self.call_and_wait("getPeers", vec![Value::String(gid.to_string())]) 112 | .await 113 | } 114 | 115 | pub async fn get_servers(&self, gid: &str) -> Result> { 116 | self.call_and_wait("getServers", vec![Value::String(gid.to_string())]) 117 | .await 118 | } 119 | 120 | pub async fn tell_active(&self) -> Result> { 121 | self.call_and_wait("tellActive", vec![]).await 122 | } 123 | 124 | pub async fn tell_waiting(&self, offset: i32, num: i32) -> Result> { 125 | self.call_and_wait("tellWaiting", value_into_vec(json!([offset, num]))) 126 | .await 127 | } 128 | 129 | pub async fn tell_stopped(&self, offset: i32, num: i32) -> Result> { 130 | self.call_and_wait("tellStopped", value_into_vec(json!([offset, num]))) 131 | .await 132 | } 133 | 134 | pub async fn custom_tell_active( 135 | &self, 136 | keys: Option>, 137 | ) -> Result>> { 138 | let mut params = Vec::new(); 139 | params.push_some(keys)?; 140 | self.call_and_wait("tellActive", params).await 141 | } 142 | 143 | pub async fn custom_tell_waiting( 144 | &self, 145 | offset: i32, 146 | num: i32, 147 | keys: Option>, 148 | ) -> Result>> { 149 | self.custom_tell_multi("tellWaiting", offset, num, keys) 150 | .await 151 | } 152 | 153 | pub async fn custom_tell_stopped( 154 | &self, 155 | offset: i32, 156 | num: i32, 157 | keys: Option>, 158 | ) -> Result>> {
159 | self.custom_tell_multi("tellStopped", offset, num, keys) 160 | .await 161 | } 162 | 163 | pub async fn change_position(&self, gid: &str, pos: i32, how: PositionHow) -> Result { 164 | let params = value_into_vec(json!([gid, pos, how])); 165 | self.call_and_wait("changePosition", params).await 166 | } 167 | 168 | /// # Returns 169 | /// This method returns a list which contains two integers. 170 | /// 171 | /// The first integer is the number of URIs deleted. 172 | /// The second integer is the number of URIs added. 173 | pub async fn change_uri( 174 | &self, 175 | gid: &str, 176 | file_index: i32, 177 | del_uris: Vec, 178 | add_uris: Vec, 179 | position: Option, 180 | ) -> Result<(i32, i32)> { 181 | let mut params = value_into_vec(json!([gid, file_index, del_uris, add_uris])); 182 | params.push_some(position)?; 183 | self.call_and_wait("changeUri", params).await 184 | } 185 | 186 | pub async fn get_option(&self, gid: &str) -> Result { 187 | self.call_and_wait("getOption", vec![Value::String(gid.to_string())]) 188 | .await 189 | } 190 | /// aria2 replies `"OK"`; the reply is decoded as a `String` and discarded, because a `()` target only accepts `null`. 191 | pub async fn change_option(&self, gid: &str, options: TaskOptions) -> Result<()> { 192 | self.call_and_wait( 193 | "changeOption", 194 | vec![ 195 | Value::String(gid.to_string()), 196 | to_value(options).context(error::JsonSnafu)?, 197 | ], 198 | ) 199 | .await.map(|_: String| ()) 200 | } 201 | 202 | pub async fn get_global_option(&self) -> Result { 203 | self.call_and_wait("getGlobalOption", vec![]).await 204 | } 205 | /// aria2 replies `"OK"`; the reply is decoded as a `String` and discarded, because a `()` target only accepts `null`. 206 | pub async fn change_global_option(&self, options: TaskOptions) -> Result<()> { 207 | self.call_and_wait( 208 | "changeGlobalOption", 209 | vec![to_value(options).context(error::JsonSnafu)?], 210 | ) 211 | .await.map(|_: String| ()) 212 | } 213 | 214 | pub async fn get_global_stat(&self) -> Result { 215 | self.call_and_wait("getGlobalStat", vec![]).await 216 | } 217 | 218 | pub async fn purge_download_result(&self) -> Result<()> { 219 | self.call_and_wait::("purgeDownloadResult", vec![]) 220 | .await?; 221 | Ok(()) 222 | } 223 | 224 | pub
async fn remove_download_result(&self, gid: &str) -> Result<()> { 225 | self.call_and_wait::("removeDownloadResult", vec![Value::String(gid.to_string())]) 226 | .await?; 227 | Ok(()) 228 | } 229 | 230 | pub async fn get_session_info(&self) -> Result { 231 | self.call_and_wait("getSessionInfo", vec![]).await 232 | } 233 | 234 | pub async fn shutdown(&self) -> Result<()> { 235 | self.call_and_wait::("shutdown", vec![]).await?; 236 | Ok(()) 237 | } 238 | 239 | pub async fn force_shutdown(&self) -> Result<()> { 240 | self.call_and_wait::("forceShutdown", vec![]) 241 | .await?; 242 | Ok(()) 243 | } 244 | 245 | pub async fn save_session(&self) -> Result<()> { 246 | self.call_and_wait::("saveSession", vec![]).await?; 247 | Ok(()) 248 | } 249 | } 250 | 251 | impl Client { 252 | fn add_callbacks_option(&self, gid: &str, callbacks: Option) { 253 | if let Some(callbacks) = callbacks { 254 | self.add_callbacks(gid.to_string(), callbacks); 255 | } 256 | } 257 | 258 | pub async fn add_uri( 259 | &self, 260 | uris: Vec, 261 | options: Option, 262 | position: Option, 263 | callbacks: Option, 264 | ) -> Result { 265 | let mut params = vec![to_value(uris).context(error::JsonSnafu)?]; 266 | params.push_else(options, json!({}))?; 267 | params.push_some(position)?; 268 | 269 | let gid: String = self.call_and_wait("addUri", params).await?; 270 | self.add_callbacks_option(&gid, callbacks); 271 | Ok(gid) 272 | } 273 | 274 | pub async fn add_torrent( 275 | &self, 276 | torrent: impl AsRef<[u8]>, 277 | uris: Option>, 278 | options: Option, 279 | position: Option, 280 | callbacks: Option, 281 | ) -> Result { 282 | let mut params = vec![Value::String(BASE64_STANDARD.encode(torrent))]; 283 | params.push_else(uris, json!([]))?; 284 | params.push_else(options, json!({}))?; 285 | params.push_some(position)?; 286 | 287 | let gid: String = self.call_and_wait("addTorrent", params).await?; 288 | self.add_callbacks_option(&gid, callbacks); 289 | Ok(gid) 290 | } 291 | 292 | pub async fn add_metalink(
293 | &self, 294 | metalink: impl AsRef<[u8]>, 295 | options: Option, 296 | position: Option, 297 | callbacks: Option, 298 | ) -> Result { 299 | let mut params = vec![Value::String(BASE64_STANDARD.encode(metalink))]; 300 | params.push_else(options, json!({}))?; 301 | params.push_some(position)?; 302 | 303 | let gid: String = self.call_and_wait("addMetalink", params).await?; 304 | self.add_callbacks_option(&gid, callbacks); 305 | Ok(gid) 306 | } 307 | } 308 | -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | callback::{callback_worker, TaskCallbacks}, 3 | error, 4 | utils::print_error, 5 | Callbacks, Notification, Result, RpcRequest, RpcResponse, 6 | }; 7 | use futures::prelude::*; 8 | use log::{debug, info}; 9 | use serde::de::DeserializeOwned; 10 | use serde_json::Value; 11 | use snafu::prelude::*; 12 | use std::{ 13 | collections::HashMap, 14 | ops::Deref, 15 | sync::{ 16 | atomic::{AtomicI32, Ordering}, 17 | Arc, 18 | }, 19 | time::Duration, 20 | }; 21 | use tokio::{ 22 | select, spawn, 23 | sync::{broadcast, mpsc, oneshot, Notify}, 24 | time::sleep, 25 | }; 26 | use tokio_tungstenite::tungstenite::Message; 27 | type WebSocket = 28 | tokio_tungstenite::WebSocketStream>; 29 | 30 | #[derive(Debug)] 31 | pub(crate) struct Subscription { 32 | pub id: i32, 33 | pub tx: oneshot::Sender, 34 | } 35 | pub struct InnerClient { 36 | token: Option, 37 | id: AtomicI32, 38 | /// Channel for sending messages to the websocket. 39 | tx_ws_sink: mpsc::Sender, 40 | tx_notification: broadcast::Sender, 41 | tx_subscription: mpsc::Sender, 42 | /// Notified when the client is dropped so the spawned websocket task shuts down. 43 | shutdown: Arc, 44 | } 45 | 46 | /// An aria2 websocket rpc client.
47 | /// 48 | /// # Example 49 | /// 50 | /// ``` 51 | /// use aria2_ws::Client; 52 | /// 53 | /// #[tokio::main] 54 | /// async fn main() { 55 | /// let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None) 56 | /// .await 57 | /// .unwrap(); 58 | /// let version = client.get_version().await.unwrap(); 59 | /// println!("{:?}", version); 60 | /// } 61 | /// ``` 62 | #[derive(Clone)] 63 | pub struct Client { 64 | inner: Arc, 65 | // The sender can be cloned like `Arc`. 66 | tx_callback: mpsc::UnboundedSender, 67 | } 68 | 69 | impl Drop for InnerClient { 70 | fn drop(&mut self) { 71 | // `notify_one` stores a permit when the websocket task is not waiting right now (e.g. it is sleeping between reconnect attempts), so the shutdown request is never lost. 72 | debug!("InnerClient dropped, notify shutdown"); 73 | self.shutdown.notify_one(); 74 | } 75 | } 76 | /// Drive one websocket connection: forward queued outgoing messages, route responses to pending subscriptions by id, and broadcast notifications. Returns when reading fails, the stream ends, or the outgoing channel closes. 77 | async fn process_ws( 78 | ws: WebSocket, 79 | rx_ws_sink: &mut mpsc::Receiver, 80 | tx_notification: broadcast::Sender, 81 | rx_subscription: &mut mpsc::Receiver, 82 | ) { 83 | let (mut sink, mut stream) = ws.split(); 84 | let mut subscriptions = HashMap::>::new(); 85 | 86 | let on_stream = |msg: String, 87 | subscriptions: &mut HashMap>| 88 | -> Result<()> { 89 | let v: Value = serde_json::from_str(&msg).context(error::JsonSnafu)?; 90 | if let Value::Object(obj) = &v { 91 | if obj.contains_key("method") { 92 | // The message should be a notification. 93 | // https://aria2.github.io/manual/en/html/aria2c.html#notifications 94 | let req: RpcRequest = serde_json::from_value(v).context(error::JsonSnafu)?; 95 | let notification = (&req).try_into()?; 96 | let _ = tx_notification.send(notification); 97 | return Ok(()); 98 | } 99 | } 100 | 101 | // The message should be a response. 102 | let res: RpcResponse = serde_json::from_value(v).context(error::JsonSnafu)?; 103 | if let Some(ref id) = res.id { 104 | let tx = subscriptions.remove(id); 105 | if let Some(tx) = tx { 106 | let _ = tx.send(res); 107 | } 108 | } 109 | Ok(()) 110 | }; 111 | 112 | loop { 113 | select!
{ 114 | msg = stream.try_next() => { 115 | debug!("websocket received message: {:?}", msg); 116 | let Ok(Some(msg)) = msg else { 117 | break; // read error or end of stream (e.g. after a close frame): return so the caller reconnects 118 | }; 119 | if let Message::Text(s) = msg { 120 | print_error(on_stream(s.to_string(), &mut subscriptions)); 121 | } 122 | }, 123 | msg = rx_ws_sink.recv() => { 124 | debug!("writing message to websocket: {:?}", msg); 125 | let Some(msg) = msg else { 126 | break; 127 | }; 128 | print_error(sink.send(msg).await); 129 | }, 130 | subscription = rx_subscription.recv() => { 131 | if let Some(subscription) = subscription { 132 | subscriptions.insert(subscription.id, subscription.tx); 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | impl InnerClient { 140 | pub(crate) async fn connect(url: &str, token: Option<&str>) -> Result { 141 | let (tx_ws_sink, mut rx_ws_sink) = mpsc::channel(1); 142 | let (tx_subscription, mut rx_subscription) = mpsc::channel(1); 143 | let shutdown = Arc::new(Notify::new()); 144 | // Broadcast notifications to all subscribers; a roomy buffer makes it less likely that a subscriber busy in an await (e.g. the callback worker) lags and loses notifications. 145 | // The initial receiver is dropped because there is no subscriber yet.
146 | let (tx_notification, _) = broadcast::channel(1024); 147 | 148 | let inner = InnerClient { 149 | tx_ws_sink, 150 | id: AtomicI32::new(0), 151 | token: token.map(|t| "token:".to_string() + t), 152 | tx_subscription, 153 | tx_notification: tx_notification.clone(), 154 | shutdown: shutdown.clone(), 155 | }; 156 | 157 | async fn connect_ws(url: &str) -> Result { 158 | debug!("connecting to {}", url); 159 | let (ws, res) = tokio_tungstenite::connect_async(url) 160 | .await 161 | .context(error::WebsocketIoSnafu)?; 162 | debug!("connected to {}, {:?}", url, res); 163 | Ok(ws) 164 | } 165 | 166 | let ws = connect_ws(url).await?; 167 | let url = url.to_string(); 168 | // spawn a task to process websocket messages 169 | spawn(async move { 170 | let mut ws = Some(ws); 171 | loop { 172 | if let Some(ws) = ws.take() { 173 | let _ = tx_notification.send(Notification::WebSocketConnected); 174 | 175 | let fut = process_ws( 176 | ws, 177 | &mut rx_ws_sink, 178 | tx_notification.clone(), 179 | &mut rx_subscription, 180 | ); 181 | 182 | select! { 183 | _ = fut => {}, 184 | _ = shutdown.notified() => { 185 | return; 186 | }, 187 | } 188 | 189 | let _ = tx_notification.send(Notification::WebsocketClosed); 190 | } else { 191 | let r = select!
{ 192 | r = connect_ws(&url) => r, 193 | _ = shutdown.notified() => return, 194 | }; 195 | match r { 196 | Ok(ws_) => { 197 | ws.replace(ws_); 198 | } 199 | Err(err) => { 200 | info!("{}", err); 201 | sleep(Duration::from_secs(3)).await; 202 | } 203 | } 204 | } 205 | } 206 | }); 207 | 208 | Ok(inner) 209 | } 210 | /// Next request id; the atomic counter wraps around on overflow. 211 | fn id(&self) -> i32 { 212 | self.id.fetch_add(1, Ordering::Relaxed) 213 | } 214 | /// Await the response for request `id` and decode its `result`; an aria2 `error` object is returned as `Error::Aria2`. 215 | async fn wait_for_id(&self, id: i32, rx: oneshot::Receiver) -> Result 216 | where 217 | T: DeserializeOwned + Send, 218 | { 219 | let res = rx.await.map_err(|err| { 220 | error::WebsocketClosedSnafu { 221 | message: format!("receiving response for id {}: {}", id, err), 222 | } 223 | .build() 224 | })?; 225 | 226 | if let Some(err) = res.error { 227 | return Err(err).context(error::Aria2Snafu); 228 | } 229 | 230 | if let Some(v) = res.result { 231 | Ok(serde_json::from_value::(v).context(error::JsonSnafu)?) 232 | } else { 233 | error::ParseSnafu { 234 | value: format!("{:?}", res), 235 | to: "RpcResponse with result", 236 | } 237 | .fail() 238 | } 239 | } 240 | 241 | /// Send a rpc request to websocket without waiting for response. 242 | pub async fn call(&self, id: i32, method: &str, mut params: Vec) -> Result<()> { 243 | if let Some(ref token) = self.token { 244 | params.insert(0, Value::String(token.clone())) 245 | } 246 | let req = RpcRequest { 247 | id: Some(id), 248 | jsonrpc: "2.0".to_string(), 249 | method: "aria2.".to_string() + method, 250 | params, 251 | }; 252 | self.tx_ws_sink 253 | .send(Message::Text( 254 | serde_json::to_string(&req) 255 | .context(error::JsonSnafu)? 256 | .into(), 257 | )) 258 | .await 259 | .expect("tx_ws_sink: receiver has been dropped"); 260 | Ok(()) 261 | } 262 | 263 | /// Send a rpc request to websocket and wait for corresponding response.
264 | pub async fn call_and_wait(&self, method: &str, params: Vec) -> Result 265 | where 266 | T: DeserializeOwned + Send, 267 | { 268 | let id = self.id(); 269 | let (tx, rx) = oneshot::channel(); 270 | self.tx_subscription 271 | .send(Subscription { id, tx }) 272 | .await 273 | .expect("tx_subscription: receiver has been closed"); 274 | 275 | self.call(id, method, params).await?; 276 | self.wait_for_id::(id, rx).await 277 | } 278 | 279 | /// Subscribe to notifications. 280 | /// 281 | /// Returns an instance of `broadcast::Receiver` which can be used to receive notifications. 282 | pub fn subscribe_notifications(&self) -> broadcast::Receiver { 283 | self.tx_notification.subscribe() 284 | } 285 | } 286 | 287 | impl Client { 288 | /// Create a new `Client` that connects to the given url. 289 | /// 290 | /// # Example 291 | /// 292 | /// ``` 293 | /// use aria2_ws::Client; 294 | /// 295 | /// #[tokio::main] 296 | /// async fn main() { 297 | /// let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None) 298 | /// .await 299 | /// .unwrap(); 300 | /// let gid = client 301 | /// .add_uri( 302 | /// vec!["https://go.dev/dl/go1.17.6.windows-amd64.msi".to_string()], 303 | /// None, 304 | /// None, 305 | /// None, 306 | /// ) 307 | /// .await 308 | /// .unwrap(); 309 | /// client.force_remove(&gid).await.unwrap(); 310 | /// } 311 | /// ``` 312 | pub async fn connect(url: &str, token: Option<&str>) -> Result { 313 | let inner = Arc::new(InnerClient::connect(url, token).await?); 314 | 315 | let weak = Arc::downgrade(&inner); 316 | let rx_notification = inner.subscribe_notifications(); 317 | let (tx_callback, rx_callback) = mpsc::unbounded_channel(); 318 | // hold only a weak reference to `inner` so the worker does not keep the client alive after `Client` is dropped 319 | spawn(callback_worker(weak, rx_notification, rx_callback)); 320 | 321 | Ok(Self { inner, tx_callback }) 322 | } 323 | 324 | pub(crate) fn add_callbacks(&self, gid: String, callbacks: Callbacks) { 325 | if
callbacks.is_empty() { 326 | return; 327 | } 328 | self.tx_callback 329 | .send(TaskCallbacks { gid, callbacks }) 330 | .expect("tx_callback: receiver has been dropped"); 331 | } 332 | } 333 | 334 | impl Deref for Client { 335 | type Target = InnerClient; 336 | 337 | fn deref(&self) -> &Self::Target { 338 | &self.inner 339 | } 340 | } 341 | --------------------------------------------------------------------------------