() {
57 | return no_vapid();
58 | }
59 | return generic_error();
60 | }
61 | };
62 | let qr = qrcode::url_to_svg_qr(&url);
63 |
64 | if airgapped {
65 | index!(format!(
66 | r#"
67 | ⚠️This will configure your server in air gapped mode⚠️
68 | Molly won't be able to update push information if necessary.
69 | You can also keep a screenshot of this QR code in case you need to reconfigure your server without having access to it.
70 |
{intro}
71 | {url}
72 |
73 |
74 | {qr}
75 |
76 | Wish to use with the webserver?
77 | "#,
78 | ))
79 | } else {
80 | index!(format!(
81 | r#"
82 | {intro}
83 | {url}
84 |
85 | {qr}
86 |
87 | Wish to use in airgapped mode?
88 | "#,
89 | ))
90 | }
91 | }
92 |
/// Returns the "VAPID key not found" message, wrapped by the `index!` macro
/// (the macro is defined outside this view).
fn no_vapid() -> String {
    index!("VAPID Key not found. Configure a VAPID key and try again.
")
}
96 |
/// Returns the "URL not found" message, wrapped by the `index!` macro
/// (the macro is defined outside this view).
fn no_url() -> String {
    index!("URL not found. The request seems to be incorrectly formatted.
")
}
100 |
/// Returns the generic error message, wrapped by the `index!` macro
/// (the macro is defined outside this view).
fn generic_error() -> String {
    index!("An error occurred. You should check the server logs.
")
}
104 |
--------------------------------------------------------------------------------
/src/utils.rs:
--------------------------------------------------------------------------------
1 | use eyre::Result;
2 | use rocket::serde::json::json;
3 | use url::Url;
4 |
5 | pub mod post_allowed;
6 |
7 | pub fn anonymize_url(url_in: &str) -> String {
8 | let mut mut_url = url::Url::parse(url_in).unwrap();
9 | mut_url.set_host(Some("fake.domain.tld")).unwrap();
10 | mut_url.into()
11 | }
12 |
13 | pub async fn ping(url: Url) -> Result {
14 | let res = post_allowed::post_allowed(url, &json!({"test":true}), Some("test")).await?;
15 | res.error_for_status_ref()?;
16 | Ok(res)
17 | }
18 |
--------------------------------------------------------------------------------
/src/utils/post_allowed.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 | use eyre::{eyre, Result};
3 | use lazy_static::lazy_static;
4 | use reqwest::dns::Addrs;
5 | use reqwest::{dns::Resolve, redirect::Policy};
6 | use serde::Serialize;
7 | use std::net;
8 | use std::{
9 | fmt::{Display, Formatter},
10 | iter,
11 | net::{IpAddr, Ipv4Addr, SocketAddr},
12 | sync::Arc,
13 | };
14 | use trust_dns_resolver::{lookup_ip::LookupIp, TokioAsyncResolver};
15 | use url::{Host, Url};
16 |
17 | use crate::{config, vapid};
18 |
19 | lazy_static! {
20 | static ref RESOLVER: TokioAsyncResolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
21 | }
22 |
/// Reasons for which a push endpoint is refused.
#[derive(Debug)]
enum Error {
    /// The URL scheme is not one of the accepted schemes.
    SchemeNotAllowed,
    /// The host is missing, cannot be resolved, or is not an allowed address.
    HostNotAllowed,
}

impl Display for Error {
    /// Writes the variant name, i.e. the same text as the `Debug` output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for Error {}
36 |
37 | struct ResolveNothing;
38 |
39 | impl Resolve for ResolveNothing {
40 | fn resolve(&self, _: reqwest::dns::Name) -> reqwest::dns::Resolving {
41 | let addrs = Box::new(iter::once(net::SocketAddr::new(
42 | IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
43 | 0,
44 | ))) as Addrs;
45 | Box::pin(futures_util::future::ready(Ok(addrs)))
46 | }
47 | }
48 |
49 | pub async fn post_allowed(
50 | url: Url,
51 | body: &T,
52 | topic: Option<&str>,
53 | ) -> Result {
54 | let port = match url.port() {
55 | Some(p) => p,
56 | None if url.scheme() == "http" => 80,
57 | None if url.scheme() == "https" => 443,
58 | _ => return Err(eyre!(Error::SchemeNotAllowed)),
59 | };
60 |
61 | let client = if config::is_endpoint_allowed_by_user(&url) {
62 | reqwest::ClientBuilder::new().redirect(Policy::none())
63 | } else {
64 | let resolved_socket_addrs = url
65 | .resolve_allowed()
66 | .await?
67 | .into_iter()
68 | .map(|ip| SocketAddr::new(ip, port))
69 | .collect::>();
70 |
71 | if resolved_socket_addrs.is_empty() {
72 | log::info!(
73 | "Ignoring request to {}: no allowed ip",
74 | url.host_str().unwrap_or("No host")
75 | );
76 | return Err(eyre!(Error::HostNotAllowed));
77 | }
78 |
79 | reqwest::ClientBuilder::new()
80 | .redirect(Policy::none())
81 | .dns_resolver(Arc::new(ResolveNothing))
82 | .resolve_to_addrs(url.host_str().unwrap(), &resolved_socket_addrs)
83 | }
84 | .build()
85 | .unwrap();
86 |
87 | // That's OK to generate a new VAPID header for each request
88 | // It doesn't do too many calculations, and we push at most once per seconde.
89 | let vapid = vapid::get_vapid_header(url.origin()).ok();
90 |
91 | let mut builder = client
92 | .post(url)
93 | .header("TTL", "2592000") // 30 days
94 | .header("Content-Encoding", "aes128gcm") // Fake this encoding to be web push compliant
95 | .header("Urgency", "high");
96 | builder = if let Some(topic) = topic {
97 | builder.header("Topic", topic) // Should override previous push messages with same topic
98 | } else {
99 | builder
100 | };
101 | builder = if let Some(vapid) = vapid {
102 | builder.header("Authorization", vapid)
103 | } else {
104 | builder
105 | };
106 | Ok(builder.json(&body).send().await?)
107 | }
108 |
109 | #[async_trait]
110 | pub trait ResolveAllowed {
111 | async fn resolve_allowed(&self) -> Result>;
112 | }
113 |
114 | #[async_trait]
115 | impl ResolveAllowed for Url {
116 | async fn resolve_allowed(&self) -> Result> {
117 | if ["http", "https"].contains(&self.scheme()) {
118 | self.host()
119 | .ok_or(Error::HostNotAllowed)?
120 | .resolve_allowed()
121 | .await
122 | } else {
123 | Err(eyre!(Error::SchemeNotAllowed))
124 | }
125 | }
126 | }
127 |
128 | #[async_trait]
129 | impl ResolveAllowed for Host<&str> {
130 | async fn resolve_allowed(&self) -> Result> {
131 | match self {
132 | Host::Domain(d) => {
133 | RESOLVER
134 | .lookup_ip(*d)
135 | .await
136 | .map_err(|_| Error::HostNotAllowed)?
137 | .resolve_allowed()
138 | .await
139 | }
140 | Host::Ipv4(ip) if ip_rfc::global_v4(ip) => Ok(vec![IpAddr::V4(*ip)]),
141 | Host::Ipv6(ip) if ip_rfc::global_v6(ip) => Ok(vec![IpAddr::V6(*ip)]),
142 | _ => Err(eyre!(Error::HostNotAllowed)),
143 | }
144 | }
145 | }
146 |
147 | #[async_trait]
148 | impl ResolveAllowed for LookupIp {
149 | async fn resolve_allowed(&self) -> Result> {
150 | Ok(self.iter().filter(ip_rfc::global).collect())
151 | }
152 | }
153 |
#[cfg(test)]
mod tests {
    use rocket::serde::json::serde_json::json;

    use super::*;
    use std::str::FromStr;

    /// Number of allowed addresses for `url`; 0 when the resolution is refused.
    async fn len_from_str(url: &str) -> usize {
        let parsed = Url::from_str(url).unwrap();
        match parsed.resolve_allowed().await {
            Ok(addrs) => addrs.len(),
            Err(_) => 0,
        }
    }

    #[tokio::test]
    async fn test_post() {
        config::load_config(None);
        let endpoint = Url::from_str("https://httpbin.org/post").unwrap();
        post_allowed(endpoint, &json!({"urgent": true}), None)
            .await
            .unwrap();
    }

    // Disabled test:
    //
    // #[tokio::test]
    // async fn test_post_localhost() {
    //     env::set_var("MOLLY_ALLOWED_ENDPOINTS", "[\"http://127.0.0.1:8001\"]");
    //     env::set_var(
    //         "MOLLY_VAPID_PRIVKEY",
    //         "DSqYuWchrB6yIMYJtidvqANeRQic4uWy34afzZRsZnI",
    //     );
    //     config::load_config(None);
    //     post_allowed(
    //         Url::from_str("http://127.0.0.1:8001/test").unwrap(),
    //         &json!({"urgent": true}),
    //         None,
    //     )
    //     .await
    //     .unwrap();
    // }

    #[tokio::test]
    async fn test_not_allowed() {
        config::load_config(None);
        assert_eq!(len_from_str("unix://signal.org").await, 0);
        assert_eq!(len_from_str("http://127.1").await, 0);
        assert_eq!(len_from_str("http://localhost").await, 0);
        assert_eq!(len_from_str("http://[::1]").await, 0);
        assert_eq!(len_from_str("http://10.10.1.1").await, 0);
        assert_eq!(len_from_str("http://[fc01::2]").await, 0);
    }

    #[tokio::test]
    async fn test_allowed() {
        config::load_config(None);
        assert!(len_from_str("http://signal.org").await > 0);
        assert!(len_from_str("http://signal.org:8080").await > 0);
        assert!(len_from_str("https://signal.org").await > 0);
        assert!(len_from_str("http://18.244.114.115").await > 0);
        assert!(len_from_str("http://[2600:9000:2550:ae00:13:5d53:5740:93a1]").await > 0);
    }
}
225 |
--------------------------------------------------------------------------------
/src/vapid.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 | collections::HashMap,
3 | fmt::{Display, Formatter},
4 | ops::Add,
5 | sync::{Arc, Mutex},
6 | time::{Duration, Instant},
7 | };
8 |
9 | use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
10 | use eyre::{eyre, Result};
11 | use jwt_simple::{
12 | self,
13 | algorithms::{ECDSAP256KeyPairLike, ECDSAP256PublicKeyLike, ES256KeyPair},
14 | claims::Claims,
15 | };
16 | use lazy_static::lazy_static;
17 | use openssl::{
18 | ec::{EcGroup, EcKey},
19 | nid::Nid,
20 | };
21 |
22 | use crate::config;
23 |
24 | lazy_static! {
25 | static ref KEY: Option = get_signer_from_conf().ok();
26 | /** Cache of VAPID keys */
27 | static ref VAPID_CACHE: Arc>> = Arc::new(Mutex::new(HashMap::new()));
28 | }
29 |
/// Validity of a generated VAPID token, in seconds (1h15).
const DURATION_VAPID: u64 = 75 * 60;
/// Time a generated header stays in the cache, in seconds (1h).
const DURATION_VAPID_CACHE: u64 = 60 * 60;
32 |
/**
Wrapper containing the signer and the associated public key.
*/
struct SignerWithPubKey {
    // Key pair used to sign the VAPID tokens.
    signer: ES256KeyPair,
    // Uncompressed public key, encoded as unpadded URL-safe base64.
    pubkey: String,
}
40 |
/// Cached VAPID header together with the instant after which it is ignored.
struct VapidCache {
    // Complete header value, as returned to callers.
    header: String,
    // The entry is only served while this instant is in the future.
    expire: Instant,
}
45 |
/// Errors raised by the VAPID helpers.
#[derive(Debug)]
pub enum Error {
    /// The VAPID key is missing or unusable.
    VapidKeyError,
}

impl Display for Error {
    /// Writes the same hint for every error, since `VapidKeyError` is the only variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("VAPID key is probably missing. See https://github.com/mollyim/mollysocket?tab=readme-ov-file#vapid-key")
    }
}

impl std::error::Error for Error {}
59 |
60 | pub fn get_vapid_pubkey() -> Result<&'static str> {
61 | let key = KEY.as_ref().ok_or(Error::VapidKeyError)?;
62 | Ok(&key.pubkey)
63 | }
64 |
65 | /**
66 | Generate VAPID header for origin.
67 | */
68 | pub fn get_vapid_header(origin: url::Origin) -> Result {
69 | let key = KEY.as_ref().ok_or(Error::VapidKeyError)?;
70 | if let Some(h) = get_vapid_header_from_cache(&origin) {
71 | return Ok(h);
72 | }
73 | gen_vapid_header_with_key(origin, key)
74 | }
75 |
76 | /**
77 | Get VAPID header from cache if not expire
78 | */
79 | fn get_vapid_header_from_cache(origin: &url::Origin) -> Option {
80 | let origin_str = origin.unicode_serialization();
81 | let now = Instant::now();
82 | let cache = VAPID_CACHE.lock().unwrap();
83 | if let Some(c) = cache.get(&origin_str) {
84 | if c.expire > now {
85 | log::debug!("Found VAPID from cache");
86 | Some(c.header.clone())
87 | } else {
88 | log::debug!("VAPID from cache has expired");
89 | None
90 | }
91 | } else {
92 | None
93 | }
94 | }
95 |
96 | fn add_vapid_header_to_cache(origin_str: &str, header: &str) {
97 | let mut cache = VAPID_CACHE.lock().unwrap();
98 | cache.insert(
99 | origin_str.into(),
100 | VapidCache {
101 | header: header.into(),
102 | expire: Instant::now().add(Duration::from_secs(DURATION_VAPID_CACHE)),
103 | },
104 | );
105 | }
106 |
107 | fn gen_vapid_header_with_key(origin: url::Origin, key: &SignerWithPubKey) -> Result {
108 | let origin_str = origin.unicode_serialization();
109 | let claims = Claims::create(jwt_simple::prelude::Duration::from_secs(DURATION_VAPID))
110 | .with_audience(&origin_str)
111 | .with_subject("https://github.com/mollyim/mollysocket");
112 | let token = key.signer.sign(claims).unwrap();
113 |
114 | let header = format!("vapid t={},k={}", token.as_str(), &key.pubkey);
115 | add_vapid_header_to_cache(&origin_str, &header);
116 | Ok(header)
117 | }
118 |
119 | /**
120 | Get [SignerWithPubKey] from the config private key.
121 | */
122 | fn get_signer_from_conf() -> Result {
123 | match config::get_vapid_privkey() {
124 | Some(k) => get_signer(k),
125 | None => Err(eyre!(Error::VapidKeyError)),
126 | }
127 | }
128 |
129 | /**
130 | Get [SignerWithPubKey] from the private key.
131 | */
132 | fn get_signer(private_bytes: &str) -> Result {
133 | let private_key_bytes = URL_SAFE_NO_PAD.decode(private_bytes).unwrap();
134 | let size = private_key_bytes.len();
135 | if size != 32 {
136 | if size == 0 {
137 | log::warn!("No VAPID key was provided.")
138 | } else {
139 | log::warn!(
140 | "The private key has an unexpected size: {}, expected 32.",
141 | size
142 | )
143 | }
144 | return Err(eyre!(Error::VapidKeyError));
145 | }
146 | let kp = ES256KeyPair::from_bytes(&private_key_bytes).unwrap();
147 | let pubkey = URL_SAFE_NO_PAD.encode(kp.public_key().public_key().to_bytes_uncompressed());
148 |
149 | log::info!("VAPID public key: {:?}", pubkey);
150 | Ok(SignerWithPubKey { signer: kp, pubkey })
151 | }
152 |
153 | /**
154 | Generate a new VAPID key.
155 | */
156 | pub fn gen_vapid_key() -> String {
157 | let key = EcKey::generate(&EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap());
158 | URL_SAFE_NO_PAD.encode(key.unwrap().private_key().to_vec())
159 | }
160 |
#[cfg(test)]
mod tests {

    use super::*;

    const TEST_PRIVKEY: &str = "DSqYuWchrB6yIMYJtidvqANeRQic4uWy34afzZRsZnI";
    const TEST_PUBKEY: &str =
        "BOniQ9xHBPNY9gnQW4o-16vHqOb40pEIMifyUdFsxAgyzVkFMguxw0QrdbZcq8hRjN2zpeInRvKVPlkzABvuTnI";

    /// [get_signer] derives the expected public key from a known private key.
    #[test]
    fn test_signer_pubkey() {
        let signer = get_signer(TEST_PRIVKEY).unwrap();
        assert_eq!(signer.pubkey, TEST_PUBKEY)
    }

    /// [gen_vapid_key] produces a key in the format [get_signer] accepts.
    #[test]
    fn test_gen_vapid_key() {
        let generated = gen_vapid_key();
        let signer = get_signer(&generated).unwrap();
        assert_eq!(signer.pubkey.len(), 87);
    }

    /// Keys of the wrong size are refused.
    #[test]
    fn test_wrong_vapid() {
        for wrong_key in [TEST_PUBKEY, ""] {
            assert!(get_signer(wrong_key).is_err());
        }
    }

    /// Prints a public key and a header, to verify the signature with another
    /// tool. This must be run with --nocapture:
    /// `cargo test vapid_other_tool -- --nocapture`
    #[test]
    fn test_vapid_other_tool() {
        let signer = get_signer(&gen_vapid_key()).unwrap();
        let pem = signer.signer.public_key().to_pem().unwrap();
        let origin = url::Url::parse("https://example.tld").unwrap().origin();
        println!("PUB: \n{}", pem);
        let header = gen_vapid_header_with_key(origin, &signer).unwrap();
        println!("header: {}", header);
    }

    // The following examples depend on the config initialization and are disabled.
    //
    // /// Test vapid from conf
    // #[test]
    // fn test_vapid_from_conf() {
    //     let key = gen_vapid_key();
    //     env::set_var("MOLLY_VAPID_PRIVKEY", &key);
    //     config::load_config(None);
    //     assert_eq!(
    //         get_signer_from_conf().unwrap().pubkey,
    //         get_signer(&key).unwrap().pubkey
    //     )
    // }
    //
    // /// Test unset vapid from conf
    // //#[test]
    // fn test_no_vapid_from_conf() {
    //     env::remove_var("MOLLY_VAPID_PRIVKEY");
    //     config::load_config(None);
    //     let res = match get_signer_from_conf() {
    //         Ok(_) => false,
    //         Err(_) => true,
    //     };
    //     assert_eq!(res, true);
    // }
}
242 |
--------------------------------------------------------------------------------
/src/ws.rs:
--------------------------------------------------------------------------------
1 | mod proto_signalservice;
2 | mod proto_websocketresources;
3 | mod signalwebsocket;
4 | mod tls;
5 | mod websocket_connection;
6 |
7 | pub use signalwebsocket::SignalWebSocket;
8 |
--------------------------------------------------------------------------------
/src/ws/certs/signal-messenger.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIF2zCCA8OgAwIBAgIUAMHz4g60cIDBpPr1gyZ/JDaaPpcwDQYJKoZIhvcNAQEL
3 | BQAwdTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcT
4 | DU1vdW50YWluIFZpZXcxHjAcBgNVBAoTFVNpZ25hbCBNZXNzZW5nZXIsIExMQzEZ
5 | MBcGA1UEAxMQU2lnbmFsIE1lc3NlbmdlcjAeFw0yMjAxMjYwMDQ1NTFaFw0zMjAx
6 | MjQwMDQ1NTBaMHUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw
7 | FAYDVQQHEw1Nb3VudGFpbiBWaWV3MR4wHAYDVQQKExVTaWduYWwgTWVzc2VuZ2Vy
8 | LCBMTEMxGTAXBgNVBAMTEFNpZ25hbCBNZXNzZW5nZXIwggIiMA0GCSqGSIb3DQEB
9 | AQUAA4ICDwAwggIKAoICAQDEecifxMHHlDhxbERVdErOhGsLO08PUdNkATjZ1kT5
10 | 1uPf5JPiRbus9F4J/GgBQ4ANSAjIDZuFY0WOvG/i0qvxthpW70ocp8IjkiWTNiA8
11 | 1zQNQdCiWbGDU4B1sLi2o4JgJMweSkQFiyDynqWgHpw+KmvytCzRWnvrrptIfE4G
12 | PxNOsAtXFbVH++8JO42IaKRVlbfpe/lUHbjiYmIpQroZPGPY4Oql8KM3o39ObPnT
13 | o1WoM4moyOOZpU3lV1awftvWBx1sbTBL02sQWfHRxgNVF+Pj0fdDMMFdFJobArrL
14 | VfK2Ua+dYN4pV5XIxzVarSRW73CXqQ+2qloPW/ynpa3gRtYeGWV4jl7eD0PmeHpK
15 | OY78idP4H1jfAv0TAVeKpuB5ZFZ2szcySxrQa8d7FIf0kNJe9gIRjbQ+XrvnN+ZZ
16 | vj6d+8uBJq8LfQaFhlVfI0/aIdggScapR7w8oLpvdflUWqcTLeXVNLVrg15cEDwd
17 | lV8PVscT/KT0bfNzKI80qBq8LyRmauAqP0CDjayYGb2UAabnhefgmRY6aBE5mXxd
18 | byAEzzCS3vDxjeTD8v8nbDq+SD6lJi0i7jgwEfNDhe9XK50baK15Udc8Cr/ZlhGM
19 | jNmWqBd0jIpaZm1rzWA0k4VwXtDwpBXSz8oBFshiXs3FD6jHY2IhOR3ppbyd4qRU
20 | pwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAdBgNV
21 | HQ4EFgQUtfNLxuXWS9DlgGuMUMNnW7yx83EwHwYDVR0jBBgwFoAUtfNLxuXWS9Dl
22 | gGuMUMNnW7yx83EwDQYJKoZIhvcNAQELBQADggIBABUeiryS0qjykBN75aoHO9bV
23 | PrrX+DSJIB9V2YzkFVyh/io65QJMG8naWVGOSpVRwUwhZVKh3JVp/miPgzTGAo7z
24 | hrDIoXc+ih7orAMb19qol/2Ha8OZLa75LojJNRbZoCR5C+gM8C+spMLjFf9k3JVx
25 | dajhtRUcR0zYhwsBS7qZ5Me0d6gRXD0ZiSbadMMxSw6KfKk3ePmPb9gX+MRTS63c
26 | 8mLzVYB/3fe/bkpq4RUwzUHvoZf+SUD7NzSQRQQMfvAHlxk11TVNxScYPtxXDyiy
27 | 3Cssl9gWrrWqQ/omuHipoH62J7h8KAYbr6oEIq+Czuenc3eCIBGBBfvCpuFOgckA
28 | XXE4MlBasEU0MO66GrTCgMt9bAmSw3TrRP12+ZUFxYNtqWluRU8JWQ4FCCPcz9pg
29 | MRBOgn4lTxDZG+I47OKNuSRjFEP94cdgxd3H/5BK7WHUz1tAGQ4BgepSXgmjzifF
30 | T5FVTDTl3ZnWUVBXiHYtbOBgLiSIkbqGMCLtrBtFIeQ7RRTb3L+IE9R0UB0cJB3A
31 | Xbf1lVkOcmrdu2h8A32aCwtr5S1fBF1unlG7imPmqJfpOMWa8yIF/KWVm29JAPq8
32 | Lrsybb0z5gg8w7ZblEuB9zOW9M3l60DXuJO6l7g+deV6P96rv2unHS8UlvWiVWDy
33 | 9qfgAJizyy3kqM4lOwBH
34 | -----END CERTIFICATE-----
35 |
--------------------------------------------------------------------------------
/src/ws/proto_websocketresources.rs:
--------------------------------------------------------------------------------
1 | #[allow(clippy::derive_partial_eq_without_eq)]
2 | #[derive(Clone, PartialEq, ::prost::Message)]
3 | pub struct WebSocketRequestMessage {
4 | #[prost(string, optional, tag = "1")]
5 | pub verb: ::core::option::Option<::prost::alloc::string::String>,
6 | #[prost(string, optional, tag = "2")]
7 | pub path: ::core::option::Option<::prost::alloc::string::String>,
8 | #[prost(bytes = "vec", optional, tag = "3")]
9 | pub body: ::core::option::Option<::prost::alloc::vec::Vec>,
10 | #[prost(string, repeated, tag = "5")]
11 | pub headers: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
12 | #[prost(uint64, optional, tag = "4")]
13 | pub id: ::core::option::Option,
14 | }
15 | #[allow(clippy::derive_partial_eq_without_eq)]
16 | #[derive(Clone, PartialEq, ::prost::Message)]
17 | pub struct WebSocketResponseMessage {
18 | #[prost(uint64, optional, tag = "1")]
19 | pub id: ::core::option::Option,
20 | #[prost(uint32, optional, tag = "2")]
21 | pub status: ::core::option::Option,
22 | #[prost(string, optional, tag = "3")]
23 | pub message: ::core::option::Option<::prost::alloc::string::String>,
24 | #[prost(string, repeated, tag = "5")]
25 | pub headers: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
26 | #[prost(bytes = "vec", optional, tag = "4")]
27 | pub body: ::core::option::Option<::prost::alloc::vec::Vec>,
28 | }
29 | #[allow(clippy::derive_partial_eq_without_eq)]
30 | #[derive(Clone, PartialEq, ::prost::Message)]
31 | pub struct WebSocketMessage {
32 | #[prost(enumeration = "web_socket_message::Type", optional, tag = "1")]
33 | pub r#type: ::core::option::Option,
34 | #[prost(message, optional, tag = "2")]
35 | pub request: ::core::option::Option,
36 | #[prost(message, optional, tag = "3")]
37 | pub response: ::core::option::Option,
38 | }
39 | /// Nested message and enum types in `WebSocketMessage`.
40 | pub mod web_socket_message {
41 | #[derive(
42 | Clone,
43 | Copy,
44 | Debug,
45 | PartialEq,
46 | Eq,
47 | Hash,
48 | PartialOrd,
49 | Ord,
50 | ::prost::Enumeration
51 | )]
52 | #[repr(i32)]
53 | pub enum Type {
54 | Unknown = 0,
55 | Request = 1,
56 | Response = 2,
57 | }
58 | impl Type {
59 | /// String value of the enum field names used in the ProtoBuf definition.
60 | ///
61 | /// The values are not transformed in any way and thus are considered stable
62 | /// (if the ProtoBuf definition does not change) and safe for programmatic use.
63 | pub fn as_str_name(&self) -> &'static str {
64 | match self {
65 | Type::Unknown => "UNKNOWN",
66 | Type::Request => "REQUEST",
67 | Type::Response => "RESPONSE",
68 | }
69 | }
70 | /// Creates an enum from field names used in the ProtoBuf definition.
71 | pub fn from_str_name(value: &str) -> ::core::option::Option {
72 | match value {
73 | "UNKNOWN" => Some(Self::Unknown),
74 | "REQUEST" => Some(Self::Request),
75 | "RESPONSE" => Some(Self::Response),
76 | _ => None,
77 | }
78 | }
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/src/ws/signalwebsocket.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 | use eyre::Result;
3 | use futures_channel::mpsc;
4 | use prost::Message;
5 | use rocket::serde::json::serde_json::json;
6 | use std::{
7 | sync::{Arc, Mutex},
8 | time::{Duration, Instant},
9 | };
10 | use tokio::time;
11 | use tokio_tungstenite::tungstenite;
12 |
13 | use super::tls;
14 | use super::websocket_connection::WebSocketConnection;
15 | use super::{
16 | proto_signalservice::Envelope,
17 | proto_websocketresources::{
18 | web_socket_message::Type, WebSocketMessage, WebSocketRequestMessage,
19 | WebSocketResponseMessage,
20 | },
21 | };
22 | use crate::{config, utils::post_allowed::post_allowed};
23 |
24 | const PUSH_TIMEOUT: Duration = Duration::from_secs(1);
25 |
26 | #[derive(Debug)]
27 | pub struct Channels {
28 | ws_tx: Option>,
29 | pub on_message_tx: Option>,
30 | pub on_push_tx: Option>,
31 | pub on_reconnection_tx: Option>,
32 | }
33 |
34 | impl Channels {
35 | fn none() -> Self {
36 | Self {
37 | ws_tx: None,
38 | on_message_tx: None,
39 | on_push_tx: None,
40 | on_reconnection_tx: None,
41 | }
42 | }
43 | }
44 |
45 | #[derive(Debug)]
46 | pub struct SignalWebSocket {
47 | creds: String,
48 | push_endpoint: url::Url,
49 | pub channels: Channels,
50 | push_instant: Arc>,
51 | last_keepalive: Arc>,
52 | }
53 |
54 | #[async_trait(?Send)]
55 | impl WebSocketConnection for SignalWebSocket {
56 | fn get_url(&self) -> &str {
57 | &config::get_ws_endpoint()
58 | }
59 |
60 | fn get_creds(&self) -> &str {
61 | &self.creds
62 | }
63 |
64 | fn get_websocket_tx(&self) -> &Option> {
65 | &self.channels.ws_tx
66 | }
67 |
68 | fn set_websocket_tx(&mut self, tx: Option>) {
69 | self.channels.ws_tx = tx;
70 | }
71 |
72 | fn get_last_keepalive(&self) -> Arc> {
73 | Arc::clone(&self.last_keepalive)
74 | }
75 |
76 | async fn on_message(&self, message: WebSocketMessage) {
77 | if let Some(type_int) = message.r#type {
78 | if let Ok(type_) = Type::try_from(type_int) {
79 | match type_ {
80 | Type::Response => self.on_response(message.response),
81 | Type::Request => self.on_request(message.request).await,
82 | _ => (),
83 | };
84 | }
85 | }
86 | }
87 | }
88 |
89 | impl SignalWebSocket {
90 | pub fn new<'a, 'b: 'a>(
91 | uuid: &str,
92 | device_id: u32,
93 | password: &str,
94 | push_endpoint: &str,
95 | ) -> Result {
96 | let push_endpoint = url::Url::parse(&push_endpoint)?;
97 | Ok(Self {
98 | creds: format!("{}.{}:{}", uuid, device_id, password),
99 | push_endpoint,
100 | channels: Channels::none(),
101 | push_instant: Arc::new(Mutex::new(
102 | Instant::now().checked_sub(PUSH_TIMEOUT).unwrap(),
103 | )),
104 | last_keepalive: Arc::new(Mutex::new(Instant::now())),
105 | })
106 | }
107 |
108 | pub async fn connection_loop(&mut self) -> Result<()> {
109 | let mut count = 0;
110 | loop {
111 | let instant = Instant::now();
112 | {
113 | let mut keepalive = self.last_keepalive.lock().unwrap();
114 | *keepalive = Instant::now();
115 | }
116 | if let Err(e) = self.connect(tls::build_tls_connector()?).await {
117 | if let Some(tungstenite::Error::Http(resp)) = e.downcast_ref::()
118 | {
119 | if resp.status() == 403 {
120 | return Err(e);
121 | }
122 | }
123 | }
124 | if let Some(duration) = Instant::now().checked_duration_since(instant) {
125 | if duration > Duration::from_secs(60) {
126 | count = 0;
127 | }
128 | }
129 | if let Some(tx) = &self.channels.on_reconnection_tx {
130 | let _ = tx.unbounded_send(1);
131 | }
132 | count += 1;
133 | log::info!("Retrying to connect in {}0 seconds.", count);
134 | time::sleep(Duration::from_secs(count * 10)).await;
135 | }
136 | }
137 |
138 | fn on_response(&self, response: Option) {
139 | log::debug!("New response");
140 | if response.is_some() {
141 | let mut keepalive = self.last_keepalive.lock().unwrap();
142 | *keepalive = Instant::now();
143 | }
144 | }
145 |
146 | /**
147 | * That's when we must send a notification
148 | */
149 | async fn on_request(&self, request: Option) {
150 | log::debug!("New request");
151 | if let Some(request) = request {
152 | if let Some(envelope) = self.request_to_envelope(request).await {
153 | if let Some(tx) = &self.channels.on_message_tx {
154 | let _ = tx.unbounded_send(1);
155 | }
156 | if self.waiting_timeout_reached() {
157 | if envelope.urgent() {
158 | self.send_push().await;
159 | }
160 | } else {
161 | log::debug!("The waiting timeout is not reached: the request is ignored.");
162 | }
163 | }
164 | }
165 | }
166 |
167 | /**
168 | * Extract [`Envelope`] from [`request`] and send response to server.
169 | */
170 | async fn request_to_envelope(&self, request: WebSocketRequestMessage) -> Option {
171 | // dbg!(&request.path);
172 | let response = self.create_websocket_response(&request);
173 | // dbg!(&response);
174 | if self.is_signal_service_envelope(&request) {
175 | self.send_response(response).await;
176 | return match request.body {
177 | None => Some(Envelope {
178 | r#type: None,
179 | source_service_id: None,
180 | source_device: None,
181 | destination_service_id: None,
182 | timestamp: None,
183 | content: None,
184 | server_guid: None,
185 | server_timestamp: None,
186 | urgent: Some(false),
187 | updated_pni: None,
188 | story: None,
189 | reporting_token: None,
190 | }),
191 | Some(body) => Envelope::decode(&body[..]).ok(),
192 | };
193 | }
194 | None
195 | }
196 |
197 | fn is_signal_service_envelope(
198 | &self,
199 | WebSocketRequestMessage {
200 | verb,
201 | path,
202 | body: _,
203 | headers: _,
204 | id: _,
205 | }: &WebSocketRequestMessage,
206 | ) -> bool {
207 | if let Some(verb) = verb {
208 | if let Some(path) = path {
209 | return verb.eq("PUT") && path.eq("/api/v1/message");
210 | }
211 | }
212 | false
213 | }
214 |
215 | fn create_websocket_response(
216 | &self,
217 | request: &WebSocketRequestMessage,
218 | ) -> WebSocketResponseMessage {
219 | if self.is_signal_service_envelope(request) {
220 | return WebSocketResponseMessage {
221 | id: request.id,
222 | status: Some(200),
223 | message: Some(String::from("OK")),
224 | headers: Vec::new(),
225 | body: None,
226 | };
227 | }
228 | WebSocketResponseMessage {
229 | id: request.id,
230 | status: Some(400),
231 | message: Some(String::from("Unknown")),
232 | headers: Vec::new(),
233 | body: None,
234 | }
235 | }
236 |
237 | async fn send_push(&self) {
238 | log::debug!("Sending the notification.");
239 | {
240 | let mut instant = self.push_instant.lock().unwrap();
241 | *instant = Instant::now();
242 | }
243 |
244 | let url = self.push_endpoint.clone();
245 | let _ = post_allowed(url, &json!({"urgent": true}), Some("mollysocket")).await;
246 | if let Some(tx) = &self.channels.on_push_tx {
247 | let _ = tx.unbounded_send(1);
248 | }
249 | }
250 |
251 | fn waiting_timeout_reached(&self) -> bool {
252 | let instant = self.push_instant.lock().unwrap();
253 | instant.elapsed() > PUSH_TIMEOUT
254 | }
255 | }
256 |
--------------------------------------------------------------------------------
/src/ws/tls.rs:
--------------------------------------------------------------------------------
1 | use native_tls::{Certificate, TlsConnector};
2 |
3 | pub fn build_tls_connector() -> Result {
4 | let root_ca = include_bytes!("certs/signal-messenger.pem");
5 | let root_ca = Certificate::from_pem(root_ca).unwrap();
6 | let mut builder = TlsConnector::builder();
7 | builder.disable_built_in_roots(true);
8 | builder.add_root_certificate(root_ca);
9 | builder.build()
10 | }
11 |
#[cfg(test)]
mod tests {
    use std::net::TcpStream;

    use super::*;

    /// The handshake succeeds with a server using the bundled root.
    #[test]
    fn connect_trusted_server() {
        let connector = build_tls_connector().unwrap();
        let stream = TcpStream::connect("chat.staging.signal.org:443").unwrap();
        connector.connect("chat.staging.signal.org", stream).unwrap();
    }

    /// The handshake fails with a server using another root.
    #[test]
    fn connect_untrusted_server() {
        let connector = build_tls_connector().unwrap();
        let stream = TcpStream::connect("signal.org:443").unwrap();
        connector.connect("signal.org", stream).unwrap_err();
    }
}
32 |
--------------------------------------------------------------------------------
/src/ws/websocket_connection.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 | use base64::{prelude::BASE64_STANDARD, Engine};
3 | use eyre::Result;
4 | use futures_channel::mpsc;
5 | use futures_util::{pin_mut, select, FutureExt, SinkExt, StreamExt};
6 | use native_tls::TlsConnector;
7 | use prost::Message;
8 | use std::{
9 | sync::{Arc, Mutex},
10 | time::{Duration, Instant, SystemTime, UNIX_EPOCH},
11 | };
12 | use tokio::time;
13 | use tokio_tungstenite::{
14 | tungstenite::{self, ClientRequestBuilder},
15 | Connector::NativeTls,
16 | };
17 |
18 | use super::proto_websocketresources::{
19 | web_socket_message::Type, WebSocketMessage, WebSocketRequestMessage, WebSocketResponseMessage,
20 | };
21 |
22 | const KEEPALIVE: Duration = Duration::from_secs(30);
23 | const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(40);
24 |
25 | #[async_trait(?Send)]
26 | pub trait WebSocketConnection {
27 | fn get_url(&self) -> &str;
28 | /// return "login:password"
29 | fn get_creds(&self) -> &str;
30 | fn get_websocket_tx(&self) -> &Option>;
31 | fn set_websocket_tx(&mut self, tx: Option>);
32 | fn get_last_keepalive(&self) -> Arc>;
33 | async fn on_message(&self, message: WebSocketMessage);
34 |
35 | async fn connect(&mut self, tls_connector: TlsConnector) -> Result<()> {
36 | let request = ClientRequestBuilder::new(self.get_url().parse()?)
37 | .with_header("X-Signal-Agent", "\"OWA\"")
38 | .with_header(
39 | "Authorization",
40 | format!("Basic {}", BASE64_STANDARD.encode(self.get_creds())),
41 | );
42 |
43 | let (ws_stream, _) = tokio_tungstenite::connect_async_tls_with_config(
44 | request,
45 | None,
46 | false,
47 | Some(NativeTls(tls_connector)),
48 | )
49 | .await?;
50 |
51 | log::info!("WebSocket handshake has been successfully completed");
52 |
53 | // Websocket I/O
54 | let (ws_write, ws_read) = ws_stream.split();
55 | // channel to websocket ws_write
56 | let (tx, rx) = mpsc::unbounded();
57 | // other channels: msg, keepalive, abort
58 | let (timer_tx, timer_rx) = mpsc::unbounded();
59 |
60 | // Saving to socket Sender
61 | self.set_websocket_tx(Some(tx));
62 |
63 | // handlers
64 | let to_ws_handle = rx.map(Ok).forward(ws_write).fuse();
65 |
66 | let from_ws_handle = ws_read
67 | .for_each(|message| async {
68 | log::debug!("New message");
69 | if let Ok(message) = message {
70 | self.handle_message(message).await;
71 | }
72 | })
73 | .fuse();
74 |
75 | let from_keepalive_handle = timer_rx
76 | .for_each(|_| async { self.send_keepalive().await })
77 | .fuse();
78 |
79 | let to_keepalive_handle = self.loop_keepalive(timer_tx).fuse();
80 |
81 | pin_mut!(
82 | to_ws_handle,
83 | from_ws_handle,
84 | from_keepalive_handle,
85 | to_keepalive_handle
86 | );
87 |
88 | // handle websocket
89 | select!(
90 | _ = to_ws_handle => log::warn!("Messages finished"),
91 | _ = from_ws_handle => log::warn!("Websocket finished"),
92 | _ = from_keepalive_handle => log::warn!("Keepalive finished"),
93 | _ = to_keepalive_handle => log::warn!("Keepalive finished"),
94 | );
95 | Ok(())
96 | }
97 |
98 | async fn handle_message(&self, message: tungstenite::Message) {
99 | let data = message.into_data();
100 | let ws_message = match WebSocketMessage::decode(data) {
101 | Ok(msg) => msg,
102 | Err(e) => {
103 | log::error!("Failed to decode protobuf: {}", e);
104 | return;
105 | }
106 | };
107 | self.on_message(ws_message).await;
108 | }
109 |
110 | async fn send_message(&self, message: WebSocketMessage) {
111 | if let Some(mut tx) = self.get_websocket_tx().as_ref() {
112 | let bytes = message.encode_to_vec();
113 | tx.send(tungstenite::Message::binary(bytes)).await.unwrap();
114 | }
115 | }
116 |
117 | async fn send_response(&self, response: WebSocketResponseMessage) {
118 | let message = WebSocketMessage {
119 | r#type: Some(Type::Response as i32),
120 | response: Some(response),
121 | request: None,
122 | };
123 | self.send_message(message).await;
124 | }
125 |
126 | async fn send_keepalive(&self) {
127 | log::debug!("send_keepalive");
128 | let message = WebSocketMessage {
129 | r#type: Some(Type::Request as i32),
130 | response: None,
131 | request: Some(WebSocketRequestMessage {
132 | verb: Some(String::from("GET")),
133 | path: Some(String::from("/v1/keepalive")),
134 | body: None,
135 | headers: Vec::new(),
136 | id: Some(
137 | SystemTime::now()
138 | .duration_since(UNIX_EPOCH)
139 | .unwrap()
140 | .as_millis() as u64,
141 | ),
142 | }),
143 | };
144 | self.send_message(message).await;
145 | }
146 |
147 | async fn loop_keepalive(&self, timer_tx: mpsc::UnboundedSender) {
148 | // Get the ref of last_keepalive
149 | let last_keepalive = self.get_last_keepalive();
150 | loop {
151 | // read last_keepalive
152 | if last_keepalive.lock().unwrap().elapsed() > KEEPALIVE_TIMEOUT {
153 | log::warn!("Did not receive the last keepalive: aborting.");
154 | break;
155 | }
156 | time::sleep(KEEPALIVE).await;
157 | log::debug!("Sending Keepalive");
158 | timer_tx.unbounded_send(true).unwrap();
159 | }
160 | }
161 | }
162 |
--------------------------------------------------------------------------------