├── .gitignore ├── dist ├── arch │ ├── modbus-relay.sysusers │ ├── modbus-relay.service │ └── PKGBUILD └── debian │ ├── cargo-config.toml │ ├── package │ └── modbus-relay.service │ └── maintainer-scripts │ └── postinst ├── src ├── connection │ ├── stats │ │ ├── mod.rs │ │ ├── ip.rs │ │ ├── client.rs │ │ └── connection.rs │ ├── events.rs │ ├── guard.rs │ ├── backoff_strategy.rs │ ├── manager.rs │ └── mod.rs ├── config │ ├── types │ │ ├── mod.rs │ │ ├── stop_bits.rs │ │ ├── parity.rs │ │ ├── rts_type.rs │ │ └── data_bits.rs │ ├── mod.rs │ ├── tcp.rs │ ├── http.rs │ ├── stats.rs │ ├── backoff.rs │ ├── connection.rs │ ├── logging.rs │ ├── rtu.rs │ └── relay.rs ├── errors │ ├── init.rs │ ├── config.rs │ ├── backoff.rs │ ├── kinds │ │ ├── mod.rs │ │ ├── frame_size.rs │ │ ├── frame_format.rs │ │ ├── system_error.rs │ │ ├── client_error.rs │ │ ├── frame_error.rs │ │ ├── serial_error.rs │ │ └── protocol_error.rs │ ├── io_operation.rs │ ├── frame.rs │ ├── mod.rs │ ├── rts.rs │ ├── connection.rs │ ├── transport.rs │ └── relay.rs ├── utils.rs ├── lib.rs ├── main.rs ├── http_api.rs ├── stats_manager.rs ├── rtu_transport.rs ├── modbus.rs └── modbus_relay.rs ├── config ├── development.yaml ├── production.yaml ├── default.yaml └── config.example.yaml ├── Cross.toml ├── LICENSE-MIT ├── Cargo.toml ├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── README.md ├── TODO.md └── LICENSE-APACHE /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | # Local configuration overrides 3 | config/local.yaml 4 | logs 5 | -------------------------------------------------------------------------------- /dist/arch/modbus-relay.sysusers: -------------------------------------------------------------------------------- 1 | # Create modbus-relay system user and add to uucp group 2 | u modbus-relay - "Modbus TCP to RTU relay service" /nonexistent /usr/sbin/nologin 3 | m modbus-relay uucp 4 | 
-------------------------------------------------------------------------------- /src/connection/stats/mod.rs: -------------------------------------------------------------------------------- 1 | mod client; 2 | mod connection; 3 | mod ip; 4 | 5 | pub use client::Stats as ClientStats; 6 | pub use connection::Stats as ConnectionStats; 7 | pub use ip::Stats as IpStats; 8 | -------------------------------------------------------------------------------- /src/config/types/mod.rs: -------------------------------------------------------------------------------- 1 | mod data_bits; 2 | mod parity; 3 | mod rts_type; 4 | mod stop_bits; 5 | 6 | pub use data_bits::*; 7 | pub use parity::*; 8 | pub use rts_type::*; 9 | pub use stop_bits::*; 10 | -------------------------------------------------------------------------------- /dist/debian/cargo-config.toml: -------------------------------------------------------------------------------- 1 | [target.armv7-unknown-linux-gnueabihf] 2 | strip = { path = "arm-linux-gnueabihf-strip" } 3 | 4 | [target.aarch64-unknown-linux-gnu] 5 | strip = { path = "aarch64-linux-gnu-strip" } 6 | -------------------------------------------------------------------------------- /src/errors/init.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum InitializationError { 5 | #[error("Logging initialization error: {0}")] 6 | Logging(String), 7 | } 8 | 9 | impl InitializationError { 10 | pub fn logging(msg: impl Into) -> Self { 11 | Self::Logging(msg.into()) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/errors/config.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum ConfigValidationError { 5 | #[error("Configuration error: {0}")] 6 | Config(String), 7 | } 8 | 9 | impl ConfigValidationError { 10 | 
pub fn config(details: impl Into) -> Self { 11 | Self::Config(details.into()) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/errors/backoff.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum BackoffError { 5 | #[error("Maximum retries exceeded")] 6 | MaxRetriesExceeded, 7 | 8 | #[error("Invalid backoff configuration: {0}")] 9 | InvalidConfig(String), 10 | 11 | #[error("Retry attempt failed: {0}")] 12 | RetryFailed(String), 13 | } 14 | -------------------------------------------------------------------------------- /config/development.yaml: -------------------------------------------------------------------------------- 1 | # Development configuration for modbus-relay 2 | 3 | logging: 4 | # More verbose logging in development 5 | level: "debug" 6 | format: "pretty" 7 | 8 | rtu: 9 | # Common development device path 10 | device: "/dev/ttyAMA0" 11 | # Lower baud rate for testing 12 | baud_rate: 9600 13 | 14 | # Longer timeouts for debugging 15 | transaction_timeout: "10s" 16 | serial_timeout: "2s" 17 | -------------------------------------------------------------------------------- /src/connection/stats/ip.rs: -------------------------------------------------------------------------------- 1 | use std::time::SystemTime; 2 | 3 | use serde::Serialize; 4 | 5 | /// Stats for a single IP address 6 | #[derive(Debug, Clone, Serialize)] 7 | pub struct Stats { 8 | pub active_connections: usize, 9 | pub total_requests: u64, 10 | pub total_errors: u64, 11 | pub last_active: SystemTime, 12 | pub last_error: Option, 13 | pub avg_response_time_ms: u64, 14 | } 15 | -------------------------------------------------------------------------------- /dist/arch/modbus-relay.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=Modbus TCP to RTU relay service 3 | 
After=network.target 4 | Documentation=https://github.com/aljen/modbus-relay 5 | 6 | [Service] 7 | Type=simple 8 | User=modbus-relay 9 | Group=uucp 10 | ExecStart=/usr/bin/modbus-relay --config /etc/modbus-relay/config.yaml 11 | Restart=always 12 | RestartSec=5 13 | StandardOutput=journal 14 | StandardError=journal 15 | 16 | [Install] 17 | WantedBy=multi-user.target 18 | -------------------------------------------------------------------------------- /dist/debian/package/modbus-relay.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=Modbus TCP to RTU relay service 3 | After=network.target 4 | Documentation=https://github.com/aljen/modbus-relay 5 | 6 | [Service] 7 | Type=simple 8 | User=modbus-relay 9 | Group=dialout 10 | ExecStart=/usr/bin/modbus-relay --config /etc/modbus-relay/config.yaml 11 | Restart=always 12 | RestartSec=5 13 | StandardOutput=journal 14 | StandardError=journal 15 | 16 | [Install] 17 | WantedBy=multi-user.target 18 | -------------------------------------------------------------------------------- /src/errors/kinds/mod.rs: -------------------------------------------------------------------------------- 1 | mod client_error; 2 | mod frame_error; 3 | mod frame_format; 4 | mod frame_size; 5 | mod protocol_error; 6 | mod serial_error; 7 | mod system_error; 8 | 9 | pub use client_error::ClientErrorKind; 10 | pub use frame_error::FrameErrorKind; 11 | pub use frame_format::FrameFormatKind; 12 | pub use frame_size::FrameSizeKind; 13 | pub use protocol_error::ProtocolErrorKind; 14 | pub use serial_error::SerialErrorKind; 15 | pub use system_error::SystemErrorKind; 16 | -------------------------------------------------------------------------------- /dist/debian/maintainer-scripts/postinst: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | 4 | # Create modbus-relay system user if it doesn't exist 5 | if ! 
getent passwd modbus-relay >/dev/null; then 6 | useradd --system --no-create-home \ 7 | --home-dir /nonexistent \ 8 | --shell /usr/sbin/nologin \ 9 | modbus-relay 10 | fi 11 | 12 | # Add modbus-relay user to dialout group for serial port access 13 | usermod -aG dialout modbus-relay || true 14 | 15 | #DEBHELPER# 16 | 17 | exit 0 18 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use std::sync::atomic::{AtomicU64, Ordering}; 2 | 3 | static REQUEST_ID: AtomicU64 = AtomicU64::new(1); 4 | 5 | /// Generate a unique request ID 6 | pub fn generate_request_id() -> u64 { 7 | REQUEST_ID.fetch_add(1, Ordering::SeqCst) 8 | } 9 | 10 | #[cfg(test)] 11 | mod tests { 12 | use super::*; 13 | 14 | #[test] 15 | fn test_generate_request_id() { 16 | let id1 = generate_request_id(); 17 | let id2 = generate_request_id(); 18 | assert!(id2 > id1); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/errors/kinds/frame_size.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum FrameSizeKind { 3 | TooShort, 4 | TooLong, 5 | BufferOverflow, 6 | } 7 | 8 | impl std::fmt::Display for FrameSizeKind { 9 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 10 | match self { 11 | Self::TooShort => write!(f, "Frame too short"), 12 | Self::TooLong => write!(f, "Frame too long"), 13 | Self::BufferOverflow => write!(f, "Buffer overflow"), 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/config/mod.rs: -------------------------------------------------------------------------------- 1 | mod backoff; 2 | mod connection; 3 | mod http; 4 | mod logging; 5 | mod relay; 6 | mod rtu; 7 | mod stats; 8 | mod tcp; 9 | mod types; 10 | 11 | pub use backoff::Config 
as BackoffConfig; 12 | pub use connection::Config as ConnectionConfig; 13 | pub use http::Config as HttpConfig; 14 | pub use logging::Config as LoggingConfig; 15 | pub use relay::Config as RelayConfig; 16 | pub use rtu::Config as RtuConfig; 17 | pub use stats::Config as StatsConfig; 18 | pub use tcp::Config as TcpConfig; 19 | pub use types::{DataBits, Parity, RtsType, StopBits}; 20 | -------------------------------------------------------------------------------- /Cross.toml: -------------------------------------------------------------------------------- 1 | [build.env] 2 | passthrough = [ 3 | "PKG_CONFIG_ALLOW_CROSS", 4 | "PKG_CONFIG_PATH", 5 | "PKG_CONFIG_SYSROOT_DIR", 6 | "PKG_CONFIG_LIBDIR", 7 | ] 8 | 9 | [target.aarch64-unknown-linux-gnu] 10 | pre-build = [ 11 | "dpkg --add-architecture arm64", 12 | "apt-get update", 13 | "apt-get install --assume-yes libudev-dev:arm64 pkg-config", 14 | ] 15 | [target.armv7-unknown-linux-gnueabihf] 16 | pre-build = [ 17 | "dpkg --add-architecture armhf", 18 | "apt-get update", 19 | "apt-get install --assume-yes libudev-dev:armhf pkg-config", 20 | ] 21 | -------------------------------------------------------------------------------- /src/config/tcp.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Clone, Serialize, Deserialize)] 6 | #[serde(deny_unknown_fields)] 7 | pub struct Config { 8 | pub bind_addr: String, 9 | pub bind_port: u16, 10 | #[serde(with = "humantime_serde")] 11 | pub keep_alive: Duration, 12 | } 13 | 14 | impl Default for Config { 15 | fn default() -> Self { 16 | Self { 17 | bind_addr: "0.0.0.0".to_string(), 18 | bind_port: 5000, 19 | keep_alive: Duration::from_secs(60), 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/errors/kinds/frame_format.rs: 
-------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum FrameFormatKind { 3 | InvalidHeader, 4 | InvalidFormat, 5 | UnexpectedResponse, 6 | } 7 | 8 | impl std::fmt::Display for FrameFormatKind { 9 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 10 | match self { 11 | Self::InvalidHeader => write!(f, "Invalid frame header"), 12 | Self::InvalidFormat => write!(f, "Invalid frame format"), 13 | Self::UnexpectedResponse => write!(f, "Unexpected response"), 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/config/http.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Debug, Clone, Serialize, Deserialize)] 4 | #[serde(deny_unknown_fields)] 5 | pub struct Config { 6 | /// Enable HTTP API 7 | pub enabled: bool, 8 | /// HTTP server address 9 | pub bind_addr: String, 10 | /// HTTP server port 11 | pub bind_port: u16, 12 | /// Enable metrics collection 13 | pub metrics_enabled: bool, 14 | } 15 | 16 | impl Default for Config { 17 | fn default() -> Self { 18 | Self { 19 | enabled: true, 20 | bind_addr: "127.0.0.1".to_string(), 21 | bind_port: 8081, 22 | metrics_enabled: true, 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/errors/io_operation.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum IoOperation { 3 | Read, 4 | Write, 5 | Flush, 6 | Configure, 7 | Control, 8 | Listen, 9 | } 10 | 11 | impl std::fmt::Display for IoOperation { 12 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 13 | match self { 14 | Self::Read => write!(f, "read"), 15 | Self::Write => write!(f, "write"), 16 | Self::Flush => write!(f, "flush"), 17 | Self::Configure => 
write!(f, "configure"), 18 | Self::Control => write!(f, "control"), 19 | Self::Listen => write!(f, "listen"), 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/errors/kinds/system_error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum SystemErrorKind { 3 | ResourceAllocation, 4 | PermissionDenied, 5 | FileSystem, 6 | Network, 7 | Other, 8 | } 9 | 10 | impl std::fmt::Display for SystemErrorKind { 11 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 12 | match self { 13 | Self::ResourceAllocation => write!(f, "Resource allocation error"), 14 | Self::PermissionDenied => write!(f, "Permission denied"), 15 | Self::FileSystem => write!(f, "Filesystem error"), 16 | Self::Network => write!(f, "Network error"), 17 | Self::Other => write!(f, "Other system error"), 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/errors/frame.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use super::{FrameFormatKind, FrameSizeKind}; 4 | 5 | #[derive(Error, Debug)] 6 | pub enum FrameError { 7 | #[error("Frame size error: {kind} - {details}")] 8 | Size { 9 | kind: FrameSizeKind, 10 | details: String, 11 | frame_data: Option>, 12 | }, 13 | 14 | #[error("Frame format error: {kind} - {details}")] 15 | Format { 16 | kind: FrameFormatKind, 17 | details: String, 18 | frame_data: Option>, 19 | }, 20 | 21 | #[error("CRC error: calculated={calculated:04X}, received={received:04X}, frame={frame_hex}")] 22 | Crc { 23 | calculated: u16, 24 | received: u16, 25 | frame_hex: String, 26 | }, 27 | } 28 | -------------------------------------------------------------------------------- /src/errors/mod.rs: -------------------------------------------------------------------------------- 1 | mod backoff; 2 
| mod config; 3 | mod connection; 4 | mod frame; 5 | mod init; 6 | mod io_operation; 7 | mod kinds; 8 | mod relay; 9 | mod rts; 10 | mod transport; 11 | 12 | pub use kinds::ClientErrorKind; 13 | pub use kinds::FrameErrorKind; 14 | pub use kinds::FrameFormatKind; 15 | pub use kinds::FrameSizeKind; 16 | pub use kinds::ProtocolErrorKind; 17 | pub use kinds::SerialErrorKind; 18 | pub use kinds::SystemErrorKind; 19 | 20 | pub use backoff::BackoffError; 21 | pub use config::ConfigValidationError; 22 | pub use connection::ConnectionError; 23 | pub use frame::FrameError; 24 | pub use init::InitializationError; 25 | pub use io_operation::IoOperation; 26 | pub use relay::RelayError; 27 | pub use rts::RtsError; 28 | pub use transport::TransportError; 29 | -------------------------------------------------------------------------------- /config/production.yaml: -------------------------------------------------------------------------------- 1 | # Production configuration for modbus-relay 2 | 3 | tcp: 4 | # Listen on all interfaces in production 5 | bind_addr: "0.0.0.0" 6 | bind_port: 502 7 | 8 | rtu: 9 | # Common production device path 10 | device: "/dev/ttyAMA0" 11 | # Higher baud rate for production 12 | baud_rate: 115200 13 | # Enable RTS for better flow control 14 | rts_type: "down" 15 | rts_delay_us: 3500 16 | 17 | # Shorter timeouts in production 18 | transaction_timeout: "3s" 19 | serial_timeout: "500ms" 20 | # Larger frame size for better throughput 21 | max_frame_size: 512 22 | 23 | http: 24 | # Listen on localhost only 25 | bind_addr: "127.0.0.1" 26 | bind_port: 8080 27 | 28 | logging: 29 | # Less verbose logging in production 30 | level: "warn" 31 | format: "json" 32 | -------------------------------------------------------------------------------- /src/config/stats.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Clone, Serialize, 
Deserialize)] 6 | #[serde(deny_unknown_fields)] 7 | pub struct Config { 8 | #[serde(with = "humantime_serde")] 9 | pub cleanup_interval: Duration, 10 | #[serde(with = "humantime_serde")] 11 | pub idle_timeout: Duration, 12 | #[serde(with = "humantime_serde")] 13 | pub error_timeout: Duration, 14 | pub max_events_per_second: u32, 15 | } 16 | 17 | impl Default for Config { 18 | fn default() -> Self { 19 | Self { 20 | cleanup_interval: Duration::from_secs(60), 21 | idle_timeout: Duration::from_secs(300), 22 | error_timeout: Duration::from_secs(300), 23 | max_events_per_second: 10000, 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/config/types/stop_bits.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] 4 | #[serde(rename_all = "lowercase")] 5 | pub enum StopBits { 6 | #[default] 7 | One, 8 | Two, 9 | } 10 | 11 | impl From for serialport::StopBits { 12 | fn from(stop_bits: StopBits) -> Self { 13 | match stop_bits { 14 | StopBits::One => serialport::StopBits::One, 15 | StopBits::Two => serialport::StopBits::Two, 16 | } 17 | } 18 | } 19 | 20 | impl std::fmt::Display for StopBits { 21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 22 | match self { 23 | StopBits::One => write!(f, "1"), 24 | StopBits::Two => write!(f, "2"), 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/errors/rts.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum RtsError { 5 | #[error("Failed to set RTS signal: {0}")] 6 | SignalError(String), 7 | 8 | #[error("RTS timing error: {0}")] 9 | TimingError(String), 10 | 11 | #[error("RTS configuration error: {0}")] 12 | ConfigError(String), 13 | 14 | 
#[error("RTS system error: {0}")] 15 | SystemError(#[from] std::io::Error), 16 | } 17 | 18 | impl RtsError { 19 | pub fn signal(details: impl Into) -> Self { 20 | RtsError::SignalError(details.into()) 21 | } 22 | pub fn timing(details: impl Into) -> Self { 23 | RtsError::TimingError(details.into()) 24 | } 25 | pub fn config(details: impl Into) -> Self { 26 | RtsError::ConfigError(details.into()) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/errors/kinds/client_error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum ClientErrorKind { 3 | ConnectionLost, 4 | Timeout, 5 | InvalidRequest, 6 | TooManyRequests, 7 | TooManyConnections, 8 | WriteError, 9 | } 10 | 11 | impl std::fmt::Display for ClientErrorKind { 12 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 13 | match self { 14 | Self::ConnectionLost => write!(f, "Connection lost"), 15 | Self::Timeout => write!(f, "Timeout"), 16 | Self::InvalidRequest => write!(f, "Invalid request"), 17 | Self::TooManyRequests => write!(f, "Too many requests"), 18 | Self::TooManyConnections => write!(f, "Too many connections"), 19 | Self::WriteError => write!(f, "Write error"), 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/connection/events.rs: -------------------------------------------------------------------------------- 1 | use super::{ConnectionStats, stats::ClientStats}; 2 | use std::net::SocketAddr; 3 | use tokio::sync::oneshot; 4 | 5 | #[derive(Debug)] 6 | pub enum StatEvent { 7 | /// Client connected from address 8 | ClientConnected(SocketAddr), 9 | /// Client disconnected from address 10 | ClientDisconnected(SocketAddr), 11 | /// Request processed with success/failure and duration 12 | RequestProcessed { 13 | addr: SocketAddr, 14 | success: bool, 15 | duration_ms: u64, 16 | }, 17 | 
/// Query stats for specific address 18 | QueryStats { 19 | addr: SocketAddr, 20 | response_tx: oneshot::Sender, 21 | }, 22 | /// Query global connection stats 23 | QueryConnectionStats { 24 | response_tx: oneshot::Sender, 25 | }, 26 | } 27 | -------------------------------------------------------------------------------- /src/config/backoff.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Clone, Serialize, Deserialize)] 6 | #[serde(deny_unknown_fields)] 7 | pub struct Config { 8 | /// Initial wait time 9 | #[serde(with = "humantime_serde")] 10 | pub initial_interval: Duration, 11 | /// Maximum wait time 12 | #[serde(with = "humantime_serde")] 13 | pub max_interval: Duration, 14 | /// Multiplier for each subsequent attempt 15 | pub multiplier: f64, 16 | /// Maximum number of attempts 17 | pub max_retries: u64, 18 | } 19 | 20 | impl Default for Config { 21 | fn default() -> Self { 22 | Self { 23 | initial_interval: Duration::from_millis(100), 24 | max_interval: Duration::from_secs(30), 25 | multiplier: 2.0, 26 | max_retries: 5, 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/config/types/parity.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] 4 | #[serde(rename_all = "lowercase")] 5 | pub enum Parity { 6 | #[default] 7 | None, 8 | Odd, 9 | Even, 10 | } 11 | 12 | impl From for serialport::Parity { 13 | fn from(parity: Parity) -> Self { 14 | match parity { 15 | Parity::None => serialport::Parity::None, 16 | Parity::Odd => serialport::Parity::Odd, 17 | Parity::Even => serialport::Parity::Even, 18 | } 19 | } 20 | } 21 | 22 | impl std::fmt::Display for Parity { 23 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { 24 | match self { 25 | Parity::None => write!(f, "none"), 26 | Parity::Odd => write!(f, "odd"), 27 | Parity::Even => write!(f, "even"), 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/errors/kinds/frame_error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum FrameErrorKind { 3 | TooShort, 4 | TooLong, 5 | InvalidFormat, 6 | InvalidUnitId, 7 | InvalidHeader, 8 | InvalidCrc, 9 | UnexpectedResponse, 10 | } 11 | 12 | impl std::fmt::Display for FrameErrorKind { 13 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 14 | match self { 15 | Self::TooShort => write!(f, "Frame too short"), 16 | Self::TooLong => write!(f, "Frame too long"), 17 | Self::InvalidFormat => write!(f, "Invalid frame format"), 18 | Self::InvalidUnitId => write!(f, "Invalid unit ID"), 19 | Self::InvalidHeader => write!(f, "Invalid frame header"), 20 | Self::InvalidCrc => write!(f, "Invalid frame CRC"), 21 | Self::UnexpectedResponse => write!(f, "Unexpected response"), 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/connection/stats/client.rs: -------------------------------------------------------------------------------- 1 | use std::time::SystemTime; 2 | 3 | use serde::Serialize; 4 | 5 | /// Stats for a single client 6 | #[derive(Debug, Clone, Serialize)] 7 | pub struct Stats { 8 | /// Number of active connections from this address 9 | pub active_connections: usize, 10 | /// Total number of requests 11 | pub total_requests: u64, 12 | /// Total number of errors 13 | pub total_errors: u64, 14 | /// Last activity 15 | pub last_active: SystemTime, 16 | /// Timestamp of the last error 17 | pub last_error: Option, 18 | /// Average response time 19 | pub avg_response_time_ms: u64, 20 | } 21 | 22 | impl Default for Stats { 23 | fn default() -> Self 
{ 24 | Self { 25 | active_connections: 0, 26 | total_requests: 0, 27 | total_errors: 0, 28 | last_active: SystemTime::now(), 29 | last_error: None, 30 | avg_response_time_ms: 0, 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/config/types/rts_type.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] 4 | #[serde(rename_all = "lowercase")] 5 | pub enum RtsType { 6 | /// RTS disabled 7 | None, 8 | /// RTS = High during transmission 9 | Up, 10 | /// RTS = LOW during transmission 11 | #[default] 12 | Down, 13 | } 14 | 15 | impl RtsType { 16 | pub fn to_signal_level(&self, is_transmitting: bool) -> bool { 17 | match self { 18 | RtsType::None => false, 19 | RtsType::Up => is_transmitting, 20 | RtsType::Down => !is_transmitting, 21 | } 22 | } 23 | } 24 | 25 | impl std::fmt::Display for RtsType { 26 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 27 | match self { 28 | RtsType::None => write!(f, "none"), 29 | RtsType::Up => write!(f, "up"), 30 | RtsType::Down => write!(f, "down"), 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod config; 2 | pub mod connection; 3 | pub mod errors; 4 | pub mod http_api; 5 | pub mod modbus; 6 | pub mod modbus_relay; 7 | pub mod rtu_transport; 8 | pub mod stats_manager; 9 | mod utils; 10 | 11 | pub use config::{ 12 | ConnectionConfig, HttpConfig, LoggingConfig, RelayConfig, RtuConfig, StatsConfig, TcpConfig, 13 | }; 14 | pub use config::{DataBits, Parity, RtsType, StopBits}; 15 | pub use connection::BackoffStrategy; 16 | pub use connection::{ClientStats, ConnectionStats, IpStats}; 17 | pub use connection::{ConnectionGuard, ConnectionManager}; 18 
| pub use errors::{ 19 | BackoffError, ClientErrorKind, ConfigValidationError, ConnectionError, FrameErrorKind, 20 | IoOperation, ProtocolErrorKind, RelayError, RtsError, SerialErrorKind, TransportError, 21 | }; 22 | pub use http_api::start_http_server; 23 | pub use modbus::{ModbusProcessor, guess_response_size}; 24 | pub use modbus_relay::ModbusRelay; 25 | pub use rtu_transport::RtuTransport; 26 | pub use stats_manager::StatsManager; 27 | -------------------------------------------------------------------------------- /src/connection/guard.rs: -------------------------------------------------------------------------------- 1 | use std::{net::SocketAddr, sync::Arc}; 2 | use tokio::sync::OwnedSemaphorePermit; 3 | use tracing::{trace, warn}; 4 | 5 | use crate::connection::StatEvent; 6 | 7 | use super::ConnectionManager; 8 | 9 | /// RAII guard for the connection 10 | #[derive(Debug)] 11 | pub struct ConnectionGuard { 12 | pub manager: Arc, 13 | pub addr: SocketAddr, 14 | pub _global_permit: OwnedSemaphorePermit, 15 | pub _per_ip_permit: Option, 16 | } 17 | 18 | impl Drop for ConnectionGuard { 19 | fn drop(&mut self) { 20 | trace!("Dropping connection guard for {}", self.addr); 21 | 22 | if let Err(e) = self 23 | .manager 24 | .stats_tx() 25 | .try_send(StatEvent::ClientDisconnected(self.addr)) 26 | { 27 | warn!("Failed to send disconnect event: {}", e); 28 | } 29 | 30 | self.manager.decrease_connection_count(self.addr); 31 | 32 | trace!("Connection guard dropped for {}", self.addr); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/errors/kinds/serial_error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum SerialErrorKind { 3 | OpenFailed, 4 | ReadFailed, 5 | WriteFailed, 6 | ConfigurationFailed, 7 | Disconnected, 8 | BufferOverflow, 9 | ParityError, 10 | FramingError, 11 | } 12 | 13 | impl std::fmt::Display for 
SerialErrorKind { 14 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 15 | match self { 16 | Self::OpenFailed => write!(f, "Failed to open port"), 17 | Self::ReadFailed => write!(f, "Failed to read from port"), 18 | Self::WriteFailed => write!(f, "Failed to write to port"), 19 | Self::ConfigurationFailed => write!(f, "Failed to configure port"), 20 | Self::Disconnected => write!(f, "Port disconnected"), 21 | Self::BufferOverflow => write!(f, "Buffer overflow"), 22 | Self::ParityError => write!(f, "Parity error"), 23 | Self::FramingError => write!(f, "Framing error"), 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024-2025 Artur Wyszyński 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/config/types/data_bits.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] 4 | pub struct DataBits(u8); 5 | 6 | impl DataBits { 7 | pub fn new(bits: u8) -> Option { 8 | match bits { 9 | 5..=8 => Some(Self(bits)), 10 | _ => None, 11 | } 12 | } 13 | 14 | pub fn get(&self) -> u8 { 15 | self.0 16 | } 17 | } 18 | 19 | impl Default for DataBits { 20 | fn default() -> Self { 21 | Self(8) 22 | } 23 | } 24 | 25 | impl From for serialport::DataBits { 26 | fn from(data_bits: DataBits) -> Self { 27 | match data_bits.0 { 28 | 5 => serialport::DataBits::Five, 29 | 6 => serialport::DataBits::Six, 30 | 7 => serialport::DataBits::Seven, 31 | 8 => serialport::DataBits::Eight, 32 | _ => unreachable!("DataBits constructor ensures valid values"), 33 | } 34 | } 35 | } 36 | 37 | impl std::fmt::Display for DataBits { 38 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 39 | write!(f, "{}", self.0) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/errors/connection.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use super::BackoffError; 4 | 5 | #[derive(Error, Debug)] 6 | pub enum ConnectionError { 7 | #[error("Connection limit exceeded: {0}")] 8 | LimitExceeded(String), 9 | 10 | #[error("Connection timed out: {0}")] 11 | Timeout(String), 12 | 13 | #[error("Invalid connection state: {0}")] 14 | InvalidState(String), 15 | 16 | #[error("Connection rejected: {0}")] 17 | Rejected(String), 18 | 19 | #[error("Connection disconnected")] 20 | Disconnected, 21 | 22 | #[error("Backoff error: {0}")] 23 | Backoff(#[from] BackoffError), 24 | } 25 | 26 | impl ConnectionError { 27 | pub fn limit_exceeded(details: impl Into) -> 
Self { 28 | ConnectionError::LimitExceeded(details.into()) 29 | } 30 | 31 | pub fn timeout(details: impl Into) -> Self { 32 | ConnectionError::Timeout(details.into()) 33 | } 34 | 35 | pub fn invalid_state(details: impl Into) -> Self { 36 | ConnectionError::InvalidState(details.into()) 37 | } 38 | 39 | pub fn rejected(details: impl Into) -> Self { 40 | ConnectionError::Rejected(details.into()) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/connection/backoff_strategy.rs: -------------------------------------------------------------------------------- 1 | use std::time::{Duration, Instant}; 2 | 3 | use crate::config::BackoffConfig; 4 | 5 | /// Helper for implementing backoff strategy 6 | pub struct BackoffStrategy { 7 | config: BackoffConfig, 8 | current_attempt: usize, 9 | last_attempt: Option, 10 | } 11 | 12 | impl BackoffStrategy { 13 | pub fn new(config: BackoffConfig) -> Self { 14 | Self { 15 | config, 16 | current_attempt: 0, 17 | last_attempt: None, 18 | } 19 | } 20 | 21 | pub fn next_backoff(&mut self) -> Option { 22 | if self.current_attempt >= self.config.max_retries as usize { 23 | return None; 24 | } 25 | 26 | let interval = self.config.initial_interval.as_secs_f64() 27 | * self.config.multiplier.powi(self.current_attempt as i32); 28 | 29 | let interval = 30 | Duration::from_secs_f64(interval.min(self.config.max_interval.as_secs_f64())); 31 | 32 | self.current_attempt += 1; 33 | self.last_attempt = Some(Instant::now()); 34 | Some(interval) 35 | } 36 | 37 | pub fn reset(&mut self) { 38 | self.current_attempt = 0; 39 | self.last_attempt = None; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /dist/arch/PKGBUILD: -------------------------------------------------------------------------------- 1 | # Maintainer: Artur Wyszyński 2 | pkgname=modbus-relay 3 | pkgver=2025.11.0 4 | pkgrel=1 5 | pkgdesc="A high performance Modbus TCP to RTU relay" 6 | 
arch=('x86_64' 'aarch64' 'armv7h') 7 | url="https://github.com/aljen/modbus-relay" 8 | license=('MIT' 'Apache-2.0') # SPDX license identifiers 9 | depends=('systemd-libs') 10 | makedepends=('cargo' 'git') 11 | backup=('etc/modbus-relay/config.yaml') 12 | options=('!strip' '!debug') 13 | 14 | build() { 15 | cd "$srcdir/$pkgname" 16 | cargo build --release 17 | } 18 | 19 | package() { 20 | cd "$srcdir/$pkgname" 21 | 22 | # Binary 23 | install -Dm755 "target/release/$pkgname" "$pkgdir/usr/bin/$pkgname" 24 | 25 | # Config 26 | install -Dm644 "config/config.example.yaml" "$pkgdir/etc/$pkgname/config.yaml" 27 | 28 | # Systemd service 29 | install -Dm644 "dist/arch/modbus-relay.service" "$pkgdir/usr/lib/systemd/system/$pkgname.service" 30 | 31 | # Systemd sysusers 32 | install -Dm644 "dist/arch/modbus-relay.sysusers" "$pkgdir/usr/lib/sysusers.d/$pkgname.conf" 33 | 34 | # Documentation 35 | install -Dm644 "README.md" "$pkgdir/usr/share/doc/$pkgname/README.md" 36 | install -Dm644 "LICENSE-MIT" "$pkgdir/usr/share/licenses/$pkgname/LICENSE-MIT" 37 | install -Dm644 "LICENSE-APACHE" "$pkgdir/usr/share/licenses/$pkgname/LICENSE-APACHE" 38 | } 39 | -------------------------------------------------------------------------------- /src/config/connection.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use super::BackoffConfig; 6 | 7 | /// Configuration for managing connections 8 | #[derive(Debug, Clone, Serialize, Deserialize)] 9 | #[serde(deny_unknown_fields)] 10 | pub struct Config { 11 | /// Maximum number of concurrent connections 12 | pub max_connections: u64, 13 | /// Time after which an idle connection will be closed 14 | #[serde(with = "humantime_serde")] 15 | pub idle_timeout: Duration, 16 | /// Time after which a connection with errors will be closed 17 | #[serde(with = "humantime_serde")] 18 | pub error_timeout: Duration, 19 | /// Timeout for establishing a connection 20 | #[serde(with
= "humantime_serde")] 21 | pub connect_timeout: Duration, 22 | /// Limits for specific IP addresses 23 | pub per_ip_limits: Option, 24 | /// Parameters for backoff strategy 25 | pub backoff: BackoffConfig, 26 | } 27 | 28 | impl Default for Config { 29 | fn default() -> Self { 30 | Self { 31 | max_connections: 100, 32 | idle_timeout: Duration::from_secs(60), 33 | error_timeout: Duration::from_secs(300), 34 | connect_timeout: Duration::from_secs(5), 35 | per_ip_limits: Some(10), 36 | backoff: BackoffConfig { 37 | initial_interval: Duration::from_millis(100), 38 | max_interval: Duration::from_secs(30), 39 | multiplier: 2.0, 40 | max_retries: 5, 41 | }, 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/config/logging.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use tracing::level_filters::LevelFilter; 3 | 4 | #[derive(Debug, Clone, Serialize, Deserialize)] 5 | #[serde(deny_unknown_fields)] 6 | pub struct Config { 7 | /// Directory to store log files 8 | pub log_dir: String, 9 | 10 | /// Enable trace-level logging for frame contents 11 | pub trace_frames: bool, 12 | 13 | /// Minimum log level for console output 14 | pub level: String, 15 | 16 | /// Log format (pretty or json) 17 | pub format: String, 18 | 19 | /// Whether to include source code location in logs 20 | pub include_location: bool, 21 | 22 | /// Whether to include thread IDs in logs 23 | pub thread_ids: bool, 24 | 25 | /// Whether to include thread names in logs 26 | pub thread_names: bool, 27 | } 28 | 29 | impl Default for Config { 30 | fn default() -> Self { 31 | Self { 32 | log_dir: "logs".to_string(), 33 | trace_frames: false, 34 | level: "info".to_string(), 35 | format: "pretty".to_string(), 36 | include_location: false, 37 | thread_ids: false, 38 | thread_names: false, 39 | } 40 | } 41 | } 42 | 43 | impl Config { 44 | pub fn get_level_filter(&self) ->
LevelFilter { 45 | match self.level.to_lowercase().as_str() { 46 | "error" => LevelFilter::ERROR, 47 | "warn" => LevelFilter::WARN, 48 | "info" => LevelFilter::INFO, 49 | "debug" => LevelFilter::DEBUG, 50 | "trace" => LevelFilter::TRACE, 51 | _ => LevelFilter::INFO, // Fallback to INFO if invalid 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/config/rtu.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use super::{DataBits, Parity, RtsType, StopBits}; 6 | 7 | #[derive(Debug, Clone, Serialize, Deserialize)] 8 | #[serde(deny_unknown_fields)] 9 | pub struct Config { 10 | pub device: String, 11 | pub baud_rate: u32, 12 | pub data_bits: DataBits, 13 | pub parity: Parity, 14 | pub stop_bits: StopBits, 15 | 16 | /// Flow control settings for the serial port 17 | pub rts_type: RtsType, 18 | pub rts_delay_us: u64, 19 | 20 | /// Whether to flush the serial port after writing 21 | pub flush_after_write: bool, 22 | 23 | /// Timeout for the entire transaction (request + response) 24 | #[serde(with = "humantime_serde")] 25 | pub transaction_timeout: Duration, 26 | 27 | /// Timeout for individual read/write operations on serial port 28 | #[serde(with = "humantime_serde")] 29 | pub serial_timeout: Duration, 30 | 31 | /// Maximum size of the request/response buffer 32 | pub max_frame_size: u64, 33 | } 34 | 35 | impl Default for Config { 36 | fn default() -> Self { 37 | Self { 38 | device: "/dev/ttyAMA0".to_string(), 39 | baud_rate: 9600, 40 | data_bits: DataBits::default(), 41 | parity: Parity::default(), 42 | stop_bits: StopBits::default(), 43 | rts_type: RtsType::default(), 44 | rts_delay_us: 3500, 45 | flush_after_write: true, 46 | transaction_timeout: Duration::from_secs(5), 47 | serial_timeout: Duration::from_secs(1), 48 | max_frame_size: 256, 49 | } 50 | } 51 | } 52 | 53 | impl Config { 54 | pub
fn serial_port_info(&self) -> String { 55 | format!( 56 | "{} ({} baud, {} data bits, {} parity, {} stop bits)", 57 | self.device, self.baud_rate, self.data_bits, self.parity, self.stop_bits 58 | ) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | # Default configuration for modbus-relay 2 | 3 | tcp: 4 | # TCP server bind address 5 | bind_addr: "127.0.0.1" 6 | # TCP server port 7 | bind_port: 502 8 | # TCP keepalive probe interval (e.g. "60s", "2m") 9 | # This is how often the server will check if client connections are still alive 10 | keep_alive: "60s" 11 | 12 | rtu: 13 | # Serial device path 14 | device: "/dev/ttyAMA0" 15 | # Baud rate 16 | baud_rate: 9600 17 | # Data bits (5-8) 18 | data_bits: 8 19 | # Parity (none, odd, even) 20 | parity: "none" 21 | # Stop bits ("one", "two") 22 | stop_bits: "one" 23 | # Whether to flush after write 24 | flush_after_write: true 25 | 26 | # Optional RTS configuration ("up", "down", "none") 27 | rts_type: "down" 28 | rts_delay_us: 3500 29 | 30 | # Transaction timeout 31 | transaction_timeout: "5s" 32 | # Serial port timeout 33 | serial_timeout: "1s" 34 | # Maximum frame size 35 | max_frame_size: 256 36 | 37 | http: 38 | # Enabled 39 | enabled: true 40 | # HTTP API bind address 41 | bind_addr: "127.0.0.1" 42 | # HTTP API port 43 | bind_port: 8080 44 | # Metrics enabled 45 | metrics_enabled: true 46 | 47 | logging: 48 | # Directory to store log files 49 | log_dir: "logs" 50 | # Trace modbus frames 51 | trace_frames: false 52 | # Log level (trace, debug, info, warn, error) 53 | level: "trace" 54 | # Log format (pretty, json) 55 | format: "pretty" 56 | # Use file:line 57 | include_location: false 58 | # Log thread ids 59 | thread_ids: false 60 | # Log thread names 61 | thread_names: true 62 | 63 | connection: 64 | # Maximum number of concurrent connections 65 | max_connections:
100 66 | # Time after which an idle connection will be closed 67 | idle_timeout: "60s" 68 | # Time after which a connection with errors will be closed 69 | error_timeout: "300s" 70 | # Timeout for establishing a connection 71 | connect_timeout: "5s" 72 | # Optional per IP limits 73 | per_ip_limits: 10 74 | # Parameters for backoff strategy 75 | backoff: 76 | # Initial wait time 77 | initial_interval: "100ms" 78 | # Maximum wait time 79 | max_interval: "30s" 80 | # Multiplier for each subsequent attempt 81 | multiplier: 2.0 82 | # Maximum number of attempts 83 | max_retries: 5 84 | 85 | -------------------------------------------------------------------------------- /config/config.example.yaml: -------------------------------------------------------------------------------- 1 | # Default configuration for modbus-relay 2 | 3 | tcp: 4 | # TCP server bind address 5 | bind_addr: "127.0.0.1" 6 | # TCP server port 7 | bind_port: 502 8 | # TCP keepalive probe interval (e.g. "60s", "2m") 9 | # This is how often the server will check if client connections are still alive 10 | keep_alive: "60s" 11 | 12 | rtu: 13 | # Serial device path 14 | device: "/dev/ttyAMA0" 15 | # Baud rate 16 | baud_rate: 9600 17 | # Data bits (5-8) 18 | data_bits: 8 19 | # Parity (none, odd, even) 20 | parity: "none" 21 | # Stop bits ("one", "two") 22 | stop_bits: "one" 23 | # Whether to flush after write 24 | flush_after_write: true 25 | 26 | # Optional RTS configuration ("up", "down", "none") 27 | rts_type: "down" 28 | rts_delay_us: 3500 29 | 30 | # Transaction timeout 31 | transaction_timeout: "5s" 32 | # Serial port timeout 33 | serial_timeout: "1s" 34 | # Maximum frame size 35 | max_frame_size: 256 36 | 37 | http: 38 | # Enabled 39 | enabled: true 40 | # HTTP API bind address 41 | bind_addr: "127.0.0.1" 42 | # HTTP API port 43 | bind_port: 8080 44 | # Metrics enabled 45 | metrics_enabled: true 46 | 47 | logging: 48 | # Directory to store log files 49 | log_dir: "logs" 50 | # Trace modbus frames 51
| trace_frames: false 52 | # Log level (trace, debug, info, warn, error) 53 | level: "trace" 54 | # Log format (pretty, json) 55 | format: "pretty" 56 | # Use file:line 57 | include_location: false 58 | # Log thread ids 59 | thread_ids: false 60 | # Log thread names 61 | thread_names: true 62 | 63 | connection: 64 | # Maximum number of concurrent connections 65 | max_connections: 100 66 | # Time after which an idle connection will be closed 67 | idle_timeout: "60s" 68 | # Time after which a connection with errors will be closed 69 | error_timeout: "300s" 70 | # Timeout for establishing a connection 71 | connect_timeout: "5s" 72 | # Optional per IP limits 73 | per_ip_limits: 10 74 | # Parameters for backoff strategy 75 | backoff: 76 | # Initial wait time 77 | initial_interval: "100ms" 78 | # Maximum wait time 79 | max_interval: "30s" 80 | # Multiplier for each subsequent attempt 81 | multiplier: 2.0 82 | # Maximum number of attempts 83 | max_retries: 5 84 | 85 | -------------------------------------------------------------------------------- /src/errors/kinds/protocol_error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 2 | pub enum ProtocolErrorKind { 3 | InvalidFunction, 4 | InvalidDataAddress, 5 | InvalidDataValue, 6 | ServerFailure, 7 | Acknowledge, 8 | ServerBusy, 9 | GatewayPathUnavailable, 10 | GatewayTargetFailedToRespond, 11 | InvalidProtocolId, 12 | InvalidTransactionId, 13 | InvalidUnitId, 14 | InvalidPdu, 15 | } 16 | 17 | impl std::fmt::Display for ProtocolErrorKind { 18 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 19 | match self { 20 | Self::InvalidFunction => write!(f, "Invalid function code"), 21 | Self::InvalidDataAddress => write!(f, "Invalid data address"), 22 | Self::InvalidDataValue => write!(f, "Invalid data value"), 23 | Self::ServerFailure => write!(f, "Server device failure"), 24 | Self::Acknowledge => write!(f, "Acknowledge"), 25
| Self::ServerBusy => write!(f, "Server device busy"), 26 | Self::GatewayPathUnavailable => write!(f, "Gateway path unavailable"), 27 | Self::GatewayTargetFailedToRespond => { 28 | write!(f, "Gateway target device failed to respond") 29 | } 30 | Self::InvalidProtocolId => write!(f, "Invalid protocol ID"), 31 | Self::InvalidTransactionId => write!(f, "Invalid transaction ID"), 32 | Self::InvalidUnitId => write!(f, "Invalid unit ID"), 33 | Self::InvalidPdu => write!(f, "Invalid PDU format"), 34 | } 35 | } 36 | } 37 | 38 | impl ProtocolErrorKind { 39 | pub fn to_exception_code(&self) -> u8 { 40 | match self { 41 | Self::InvalidFunction => 0x01, 42 | Self::InvalidDataAddress => 0x02, 43 | Self::InvalidDataValue => 0x03, 44 | Self::ServerFailure => 0x04, 45 | Self::Acknowledge => 0x05, 46 | Self::ServerBusy => 0x06, 47 | Self::GatewayPathUnavailable => 0x0A, 48 | Self::GatewayTargetFailedToRespond => 0x0B, 49 | Self::InvalidProtocolId | Self::InvalidTransactionId | Self::InvalidUnitId | Self::InvalidPdu => 0x04, // No dedicated Modbus exception code; listed explicitly so new variants must be mapped, reported as server device failure 50 | } 51 | } 52 | 53 | pub fn from_exception_code(code: u8) -> Option { 54 | match code { 55 | 0x01 => Some(Self::InvalidFunction), 56 | 0x02 => Some(Self::InvalidDataAddress), 57 | 0x03 => Some(Self::InvalidDataValue), 58 | 0x04 => Some(Self::ServerFailure), 59 | 0x05 => Some(Self::Acknowledge), 60 | 0x06 => Some(Self::ServerBusy), 61 | 0x0A => Some(Self::GatewayPathUnavailable), 62 | 0x0B => Some(Self::GatewayTargetFailedToRespond), 63 | _ => None, 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/connection/stats/connection.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | net::SocketAddr, 4 | time::{Duration, SystemTime}, 5 | }; 6 | 7 | use serde::Serialize; 8 | 9 | use super::{ClientStats, IpStats}; 10 | 11 | #[derive(Debug, Serialize)] 12 | pub struct Stats { 13 | pub total_connections: u64, 14 | pub active_connections: usize, 15 | pub total_requests: u64, 16 | pub
total_errors: u64, 17 | pub requests_per_second: f64, 18 | pub avg_response_time_ms: u64, 19 | pub per_ip_stats: HashMap, 20 | } 21 | 22 | impl Stats { 23 | pub fn from_client_stats(stats: &HashMap) -> Self { 24 | let mut total_active = 0; 25 | let mut total_requests = 0; 26 | let mut total_errors = 0; 27 | let mut total_response_time = 0u64; 28 | let mut response_time_count = 0; 29 | let mut per_ip = HashMap::new(); 30 | 31 | // Calculate totals and build per-IP stats 32 | for (addr, client) in stats { 33 | total_active += client.active_connections; 34 | total_requests += client.total_requests; 35 | total_errors += client.total_errors; 36 | 37 | if client.avg_response_time_ms > 0 { 38 | total_response_time += client.avg_response_time_ms; 39 | response_time_count += 1; 40 | } 41 | 42 | per_ip.insert( 43 | *addr, 44 | IpStats { 45 | active_connections: client.active_connections, 46 | total_requests: client.total_requests, 47 | total_errors: client.total_errors, 48 | last_active: client.last_active, 49 | last_error: client.last_error, 50 | avg_response_time_ms: client.avg_response_time_ms, 51 | }, 52 | ); 53 | } 54 | 55 | Self { 56 | total_connections: total_active as u64, // NOTE(review): same value as active_connections; no lifetime connection count is aggregated here 57 | active_connections: total_active, 58 | total_requests, 59 | total_errors, 60 | requests_per_second: Self::calculate_requests_per_second(stats), 61 | avg_response_time_ms: if response_time_count > 0 { 62 | total_response_time / response_time_count 63 | } else { 64 | 0 65 | }, 66 | per_ip_stats: per_ip, 67 | } 68 | } 69 | 70 | fn calculate_requests_per_second(stats: &HashMap) -> f64 { 71 | let now = SystemTime::now(); 72 | let window = Duration::from_secs(60); 73 | let mut recent_requests = 0; 74 | 75 | for client in stats.values() { 76 | if let Ok(duration) = now.duration_since(client.last_active) 77 | && duration <= window 78 | { 79 | recent_requests += client.total_requests as usize; // NOTE(review): adds the client's whole total_requests, not only requests inside the 60 s window 80 | } 81 | } 82 | 83 | recent_requests as f64 / window.as_secs_f64() 84 | } 85 | } 86 |
-------------------------------------------------------------------------------- /src/errors/transport.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | use thiserror::Error; 3 | use tokio::time::error::Elapsed; 4 | 5 | use super::{IoOperation, RtsError, SerialErrorKind}; 6 | 7 | #[derive(Error, Debug)] 8 | pub enum TransportError { 9 | #[error("Serial port error: {kind} on {port} - {details}")] 10 | Serial { 11 | kind: SerialErrorKind, 12 | port: String, 13 | details: String, 14 | #[source] 15 | source: Option, 16 | }, 17 | 18 | #[error("Network error: {0}")] 19 | Network(std::io::Error), 20 | 21 | #[error("I/O error: {operation} failed on {details}")] 22 | Io { 23 | operation: IoOperation, 24 | details: String, 25 | #[source] 26 | source: std::io::Error, 27 | }, 28 | 29 | #[error("Transaction timeout after {elapsed:?}, limit was {limit:?}")] 30 | Timeout { 31 | elapsed: Duration, 32 | limit: Duration, 33 | #[source] 34 | source: Elapsed, 35 | }, 36 | 37 | #[error("No response received after {attempts} attempts over {elapsed:?}")] 38 | NoResponse { attempts: u8, elapsed: Duration }, 39 | 40 | #[error("RTS error: {0}")] 41 | Rts(#[from] RtsError), 42 | } 43 | 44 | impl From for TransportError { 45 | fn from(err: serialport::Error) -> Self { 46 | match err.kind { 47 | serialport::ErrorKind::NoDevice => TransportError::Serial { 48 | kind: SerialErrorKind::OpenFailed, 49 | port: err.to_string(), // NOTE(review): this is the serialport error's display text, not a device path 50 | details: "Device not found".into(), 51 | source: Some(err), 52 | }, 53 | serialport::ErrorKind::InvalidInput => TransportError::Serial { 54 | kind: SerialErrorKind::ConfigurationFailed, 55 | port: err.to_string(), 56 | details: "Invalid configuration".into(), 57 | source: Some(err), 58 | }, 59 | serialport::ErrorKind::Io(io_err) => TransportError::Io { 60 | operation: match io_err { 61 | std::io::ErrorKind::NotFound => IoOperation::Configure, 62 | std::io::ErrorKind::PermissionDenied =>
IoOperation::Configure, 63 | std::io::ErrorKind::TimedOut => IoOperation::Read, 64 | std::io::ErrorKind::WriteZero => IoOperation::Write, 65 | _ => IoOperation::Control, 66 | }, 67 | details: io_err.to_string(), 68 | source: std::io::Error::new(io_err, err.description), 69 | }, 70 | _ => TransportError::Serial { 71 | kind: SerialErrorKind::OpenFailed, 72 | port: err.to_string(), 73 | details: err.to_string(), 74 | source: Some(err), 75 | }, 76 | } 77 | } 78 | } 79 | 80 | impl From for TransportError { 81 | fn from(err: std::io::Error) -> Self { 82 | TransportError::Io { 83 | operation: match err.kind() { 84 | std::io::ErrorKind::TimedOut => IoOperation::Read, 85 | std::io::ErrorKind::WouldBlock => IoOperation::Read, 86 | std::io::ErrorKind::WriteZero => IoOperation::Write, 87 | std::io::ErrorKind::Interrupted => IoOperation::Control, 88 | _ => IoOperation::Control, 89 | }, 90 | details: err.to_string(), 91 | source: err, 92 | } 93 | } 94 | } 95 | 96 | impl From for TransportError { 97 | fn from(err: Elapsed) -> Self { 98 | TransportError::Timeout { 99 | elapsed: Duration::from_secs(1), // Elapsed does not expose a duration(), so a fixed placeholder value is used 100 | limit: Duration::from_secs(1), // TODO(aljen): Pass the actual limit from configuration 101 | source: err, 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "modbus-relay" 3 | version = "2025.11.0" 4 | edition = "2024" 5 | authors = ["Artur Wyszyński "] 6 | description = "A high performance Modbus TCP to RTU relay" 7 | homepage = "https://github.com/aljen/modbus-relay" 8 | repository = "https://github.com/aljen/modbus-relay" 9 | documentation = "https://docs.rs/modbus-relay" 10 | readme = "README.md" 11 | license = "MIT OR Apache-2.0" 12 | rust-version = "1.91" 13 | keywords = ["hardware", "modbus", "client", "server", "relay"] 14 | categories =
["embedded", "hardware-support", "network-programming"] 15 | 16 | exclude = [ # NOTE: Cargo ignores `exclude` when `include` is also set 17 | ".github/**/*", 18 | "docs/**/*", 19 | "tests/**/*", 20 | "examples/**/*", 21 | "benches/**/*", 22 | ] 23 | 24 | include = [ 25 | "src/**/*", 26 | "LICENSE*", 27 | "README.md", 28 | "CHANGELOG.md", 29 | "config/config.example.yaml", 30 | ] 31 | 32 | [badges] 33 | maintenance = { status = "actively-developed" } 34 | 35 | [dependencies] 36 | axum = "0.8.6" 37 | clap = { version = "4.5.51", features = ["derive"] } 38 | config = "0.15.18" 39 | futures = "0.3.31" 40 | hex = "0.4.3" 41 | humantime-serde = "1.1.1" 42 | libc = "0.2.177" 43 | serde = { version = "1.0.228", features = ["derive"] } 44 | serde_yaml = "0.9.34" 45 | serialport = "4.8.1" 46 | socket2 = "0.6.1" 47 | thiserror = "2.0.17" 48 | time = { version = "0.3.44", features = ["local-offset"] } 49 | tokio = { version = "1.48.0", features = ["full"] } 50 | tracing = "0.1.41" 51 | tracing-appender = "0.2.3" 52 | tracing-subscriber = { version = "0.3.20", features = [ 53 | "env-filter", 54 | "json", 55 | "time", 56 | ] } 57 | 58 | [dev-dependencies] 59 | tower = "0.5.2" 60 | tempfile = "3.23.0" 61 | serial_test = "3.2.0" 62 | 63 | [profile.release] 64 | lto = true # Link Time Optimization 65 | codegen-units = 1 # Single codegen unit: slower builds, more cross-function optimization 66 | opt-level = 3 # Maximum optimization 67 | panic = 'abort' # Remove panic unwinding code 68 | strip = true # Remove debug symbols 69 | 70 | [profile.dev] 71 | opt-level = 0 # No optimizations for faster compilation 72 | debug = true # Full debug info 73 | 74 | # TODO: 75 | # debug-logging - includes debug logging 76 | # metrics - includes metrics for Prometheus 77 | # tls - adds TLS support 78 | 79 | [[bin]] 80 | name = "modbus-relay" 81 | path = "src/main.rs" 82 | 83 | [lib] 84 | name = "modbus_relay" 85 | path = "src/lib.rs" 86 | 87 | [package.metadata.docs.rs] 88 | all-features = true 89 | rustdoc-args = ["--cfg", "docsrs"] 90 | 91 | [package.metadata.release] 92 | sign-commit
= true 93 | sign-tag = true 94 | pre-release-commit-message = "chore: release {{version}}" 95 | tag-message = "release: {{version}}" 96 | tag-name = "v{{version}}" 97 | 98 | [package.metadata.deb] 99 | maintainer = "Artur Wyszyński " 100 | copyright = "2024-2025, Artur Wyszyński " 101 | extended-description = """ 102 | High-performance Modbus TCP to RTU relay service. 103 | Supports various temperature sensors and provides efficient data relay capabilities. 104 | """ 105 | depends = "$auto" 106 | section = "net" 107 | priority = "optional" 108 | assets = [ 109 | [ 110 | "target/release/modbus-relay", 111 | "usr/bin/", 112 | "755", 113 | ], 114 | [ 115 | "config/config.example.yaml", 116 | "etc/modbus-relay/config.yaml", 117 | "644", 118 | ], 119 | [ 120 | "dist/debian/package/modbus-relay.service", 121 | "lib/systemd/system/modbus-relay.service", 122 | "644", 123 | ], 124 | [ 125 | "LICENSE-MIT", 126 | "usr/share/doc/modbus-relay/LICENSE-MIT", 127 | "644", 128 | ], 129 | [ 130 | "LICENSE-APACHE", 131 | "usr/share/doc/modbus-relay/LICENSE-APACHE", 132 | "644", 133 | ], 134 | ] 135 | conf-files = ["/etc/modbus-relay/config.yaml"] 136 | systemd-units = { unit-name = "modbus-relay" } 137 | maintainer-scripts = "dist/debian/maintainer-scripts" 138 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | workflow_call: 9 | 10 | env: 11 | CARGO_TERM_COLOR: always 12 | CARGO_INCREMENTAL: 0 13 | CARGO_NET_RETRY: 10 14 | RUSTUP_MAX_RETRIES: 10 15 | RUST_BACKTRACE: short 16 | PKG_CONFIG_ALLOW_CROSS: 1 17 | 18 | jobs: 19 | check: 20 | name: Check 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - name: Install system dependencies 26 | run: | 27 | sudo apt-get update 28 | sudo apt-get install -y libudev-dev
pkg-config 29 | 30 | - name: Install Rust toolchain 31 | uses: dtolnay/rust-toolchain@stable 32 | with: 33 | components: rustfmt, clippy 34 | 35 | - name: Cache cargo 36 | uses: actions/cache@v3 37 | with: 38 | path: | 39 | ~/.cargo/registry 40 | ~/.cargo/git 41 | target 42 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 43 | 44 | - name: Check formatting 45 | run: cargo fmt --all -- --check 46 | 47 | - name: Check clippy 48 | run: cargo clippy -- -D warnings 49 | 50 | - name: Run tests 51 | run: cargo test --all-features 52 | 53 | build-linux: 54 | name: Build Linux 55 | needs: check 56 | runs-on: ubuntu-latest 57 | strategy: 58 | matrix: 59 | target: 60 | - x86_64-unknown-linux-gnu 61 | - aarch64-unknown-linux-gnu 62 | - armv7-unknown-linux-gnueabihf 63 | 64 | steps: 65 | - uses: actions/checkout@v4 66 | 67 | - name: Install Rust toolchain 68 | uses: dtolnay/rust-toolchain@stable 69 | with: 70 | targets: ${{ matrix.target }} 71 | 72 | - name: Install cross 73 | uses: taiki-e/install-action@cross 74 | 75 | - name: Cache cargo 76 | uses: actions/cache@v3 77 | with: 78 | path: | 79 | ~/.cargo/registry 80 | ~/.cargo/git 81 | target 82 | key: ${{ runner.os }}-cargo-${{ matrix.target }}-${{ hashFiles('**/Cargo.lock') }} 83 | restore-keys: | 84 | ${{ runner.os }}-cargo-${{ matrix.target }}- 85 | 86 | - name: Cache cross 87 | uses: actions/cache@v3 88 | if: matrix.target != 'x86_64-unknown-linux-gnu' 89 | with: 90 | path: ~/.cargo/.cross 91 | key: ${{ runner.os }}-cross-${{ matrix.target }}-${{ hashFiles('Cross.toml') }} 92 | restore-keys: | 93 | ${{ runner.os }}-cross-${{ matrix.target }}- 94 | 95 | - name: Install target specific dependencies 96 | if: matrix.target == 'x86_64-unknown-linux-gnu' 97 | run: | 98 | sudo apt-get update 99 | sudo apt-get install -y libudev-dev dpkg-dev 100 | 101 | - name: Set library path 102 | if: matrix.target == 'x86_64-unknown-linux-gnu' 103 | run: | 104 | echo "LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
>> $GITHUB_ENV 105 | echo "LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LIBRARY_PATH" >> $GITHUB_ENV 106 | 107 | - name: Build 108 | env: 109 | PKG_CONFIG_ALLOW_CROSS: "1" 110 | run: | 111 | if [ "${{ matrix.target }}" = "x86_64-unknown-linux-gnu" ]; then 112 | # Native build uses standard system paths 113 | cargo build --release --target ${{ matrix.target }} 114 | else 115 | # Cross compilation requires special paths 116 | if [ "${{ matrix.target }}" = "armv7-unknown-linux-gnueabihf" ]; then 117 | PKG_PATH="/usr/lib/arm-linux-gnueabihf/pkgconfig" 118 | elif [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then 119 | PKG_PATH="/usr/lib/aarch64-linux-gnu/pkgconfig" 120 | fi 121 | PKG_CONFIG_PATH="$PKG_PATH" \ 122 | PKG_CONFIG_SYSROOT_DIR="/usr" \ 123 | PKG_CONFIG_LIBDIR="$PKG_PATH" \ 124 | cross build --release --target ${{ matrix.target }} 125 | fi 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # modbus-relay 4 | 5 | 🚀 High-performance Modbus TCP to RTU relay written in Rust 6 | 7 | [![Crates.io](https://img.shields.io/crates/v/modbus-relay.svg)](https://crates.io/crates/modbus-relay) 8 | [![Documentation](https://docs.rs/modbus-relay/badge.svg)](https://docs.rs/modbus-relay) 9 | [![License](https://img.shields.io/badge/license-MIT%2FApache--2.0-blue.svg)](LICENSE-MIT) 10 | [![Build Status](https://github.com/aljen/modbus-relay/workflows/CI/badge.svg)](https://github.com/aljen/modbus-relay/actions) 11 | 12 | [Features](#features) • 13 | [Installation](#installation) • 14 | [Usage](#usage) • 15 | [Configuration](#configuration) • 16 | [Monitoring](#monitoring) • 17 | [Contributing](#contributing) 18 | 19 |
20 | 21 | ## 🌟 Features 22 | 23 | - 🔄 Transparent TCP to RTU protocol conversion 24 | - 🚀 High-performance asynchronous I/O with Tokio 25 | - 🔧 Advanced RS485 support with configurable RTS control 26 | - 🛡️ Robust error handling and connection management 27 | - ⚡ Zero-copy buffer handling for optimal performance 28 | - 📝 Structured logging with multiple output formats 29 | - 🔌 Connection pooling with per-IP limits 30 | - 🔄 Automatic reconnection with configurable backoff 31 | - 🎯 Comprehensive test suite 32 | - 📊 Built-in metrics and monitoring via HTTP API 33 | 34 | ## 🚀 Quick Start 35 | 36 | ### Installation 37 | 38 | ```bash 39 | # Install from crates.io 40 | cargo install modbus-relay 41 | 42 | # Or build from source 43 | git clone https://github.com/aljen/modbus-relay 44 | cd modbus-relay 45 | cargo build --release 46 | ``` 47 | 48 | ### Basic Usage 49 | 50 | ```bash 51 | # Generate default configuration 52 | modbus-relay --dump-default-config > config.yaml 53 | 54 | # Run with custom config 55 | modbus-relay -c /path/to/config.yaml 56 | 57 | # Run with default settings 58 | modbus-relay 59 | ``` 60 | 61 | ## ⚙️ Configuration 62 | 63 | Configuration is managed through YAML files. 
Here's an example (`config.yaml`); see `config/config.example.yaml` for every available option: 64 | 65 | ```yaml 66 | tcp: 67 | bind_addr: "0.0.0.0" 68 | bind_port: 502 69 | 70 | rtu: 71 | device: "/dev/ttyUSB0" 72 | baud_rate: 9600 73 | data_bits: 8 74 | parity: "none" 75 | stop_bits: "one" # Options: one, two 76 | flush_after_write: true 77 | rts_type: "none" # Options: none, up, down 78 | rts_delay_us: 0 79 | transaction_timeout: "1s" 80 | serial_timeout: "100ms" 81 | max_frame_size: 256 82 | 83 | http: 84 | enabled: true 85 | bind_addr: "127.0.0.1" 86 | bind_port: 8080 87 | metrics_enabled: true 88 | 89 | connection: 90 | max_connections: 100 91 | idle_timeout: "60s" 92 | connect_timeout: "5s" 93 | per_ip_limits: 10 94 | backoff: 95 | initial_interval: "100ms" 96 | max_interval: "30s" 97 | multiplier: 2.0 98 | max_retries: 5 99 | 100 | logging: 101 | trace_frames: false 102 | level: "info" 103 | format: "pretty" # Options: pretty, json 104 | include_location: false 105 | ``` 106 | 107 | ## 📊 Monitoring 108 | 109 | The HTTP API provides basic monitoring endpoints: 110 | 111 | - `GET /health` - Health check endpoint 112 | - `GET /status` - Detailed status information 113 | 114 | Planned monitoring features: 115 | - Prometheus metrics support 116 | - OpenTelemetry integration 117 | - Advanced connection statistics 118 | - Detailed performance metrics 119 | 120 | ## 🔍 Examples 121 | 122 | ### Industrial Automation Setup 123 | 124 | ![modbus_relay.png](docs/modbus_relay.png) 125 | 126 | Example setup running on a Raspberry Pi with multiple Modbus RTU devices connected via RS485.
127 | 128 | ## 🛠️ Tech Stack 129 | 130 | - [tokio](https://tokio.rs) - Asynchronous runtime 131 | - [serialport](https://docs.rs/serialport) - Serial port access 132 | - [tracing](https://docs.rs/tracing) - Structured logging 133 | - [config](https://docs.rs/config) - Configuration management 134 | - [axum](https://docs.rs/axum) - HTTP server framework 135 | 136 | ### Coming Soon 137 | - Prometheus metrics integration 138 | - OpenTelemetry support 139 | 140 | ## 📚 Documentation 141 | 142 | - [API Documentation](https://docs.rs/modbus-relay) 143 | - [Configuration Guide](docs/configuration.md) 144 | - [Metrics Reference](docs/metrics.md) 145 | - [Troubleshooting Guide](docs/troubleshooting.md) 146 | 147 | ## 🤝 Contributing 148 | 149 | Contributions are welcome! Please check out our: 150 | - [Contributing Guidelines](CONTRIBUTING.md) 151 | - [Code of Conduct](CODE_OF_CONDUCT.md) 152 | 153 | ## 📄 License 154 | 155 | This project is licensed under either of 156 | 157 | * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 158 | * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 159 | 160 | at your option.
161 | -------------------------------------------------------------------------------- /src/errors/relay.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use super::{ 4 | BackoffError, ClientErrorKind, ConfigValidationError, ConnectionError, FrameError, 5 | FrameErrorKind, FrameFormatKind, FrameSizeKind, InitializationError, ProtocolErrorKind, 6 | RtsError, TransportError, 7 | }; 8 | 9 | #[derive(Error, Debug)] 10 | pub enum RelayError { 11 | #[error("Transport error: {0}")] 12 | Transport(#[from] TransportError), 13 | 14 | #[error("Protocol error: {kind} - {details}")] 15 | Protocol { 16 | kind: ProtocolErrorKind, 17 | details: String, 18 | source: Option>, 19 | }, 20 | 21 | #[error("Configuration error: {0}")] 22 | Config(#[from] ConfigValidationError), 23 | 24 | #[error("Frame error: {0}")] 25 | Frame(#[from] FrameError), 26 | 27 | #[error("Connection error: {0}")] 28 | Connection(#[from] ConnectionError), 29 | 30 | #[error("Client error: {kind} from {client_addr} - {details}")] 31 | Client { 32 | kind: ClientErrorKind, 33 | client_addr: std::net::SocketAddr, 34 | details: String, 35 | }, 36 | 37 | #[error("Initialization error: {0}")] 38 | Init(#[from] InitializationError), 39 | } 40 | 41 | impl RelayError { 42 | pub fn protocol(kind: ProtocolErrorKind, details: impl Into) -> Self { 43 | RelayError::Protocol { 44 | kind, 45 | details: details.into(), 46 | source: None, 47 | } 48 | } 49 | 50 | pub fn protocol_with_source( 51 | kind: ProtocolErrorKind, 52 | details: impl Into, 53 | source: impl Into>, 54 | ) -> Self { 55 | RelayError::Protocol { 56 | kind, 57 | details: details.into(), 58 | source: Some(source.into()), 59 | } 60 | } 61 | 62 | pub fn connection(kind: ConnectionError) -> Self { 63 | RelayError::Connection(kind) 64 | } 65 | 66 | pub fn config(kind: ConfigValidationError) -> Self { 67 | RelayError::Config(kind) 68 | } 69 | 70 | pub fn frame( 71 | kind: FrameErrorKind, 72 | details:
impl Into, 73 | frame_data: Option>, 74 | ) -> Self { 75 | let details = details.into(); 76 | match kind { 77 | FrameErrorKind::TooShort | FrameErrorKind::TooLong => { 78 | RelayError::Frame(FrameError::Size { 79 | kind: match kind { 80 | FrameErrorKind::TooShort => FrameSizeKind::TooShort, 81 | FrameErrorKind::TooLong => FrameSizeKind::TooLong, 82 | _ => unreachable!(), 83 | }, 84 | details, 85 | frame_data, 86 | }) 87 | } 88 | FrameErrorKind::InvalidFormat 89 | | FrameErrorKind::InvalidUnitId 90 | | FrameErrorKind::InvalidHeader 91 | | FrameErrorKind::UnexpectedResponse => RelayError::Frame(FrameError::Format { 92 | kind: match kind { 93 | FrameErrorKind::InvalidFormat | FrameErrorKind::InvalidUnitId => FrameFormatKind::InvalidFormat, // InvalidUnitId previously fell through to unreachable!() and panicked; the caller's `details` still describe the unit-ID problem. NOTE(review): switch to a dedicated FrameFormatKind variant if one is added 94 | FrameErrorKind::InvalidHeader => FrameFormatKind::InvalidHeader, 95 | FrameErrorKind::UnexpectedResponse => FrameFormatKind::UnexpectedResponse, 96 | _ => unreachable!(), // Every kind accepted by the enclosing arm is mapped above 97 | }, 98 | details, 99 | frame_data, 100 | }), 101 | FrameErrorKind::InvalidCrc => { 102 | if let Some(frame_data) = frame_data { 103 | let frame_hex = hex::encode(&frame_data); 104 | RelayError::Frame(FrameError::Crc { 105 | calculated: 0, // TODO(aljen): pass actual values 106 | received: 0, // TODO(aljen): pass actual values 107 | frame_hex, 108 | }) 109 | } else { 110 | RelayError::Frame(FrameError::Format { 111 | kind: FrameFormatKind::InvalidFormat, 112 | details, 113 | frame_data: None, 114 | }) 115 | } 116 | } 117 | } 118 | } 119 | 120 | pub fn client( 121 | kind: ClientErrorKind, 122 | client_addr: std::net::SocketAddr, 123 | details: impl Into, 124 | ) -> Self { 125 | RelayError::Client { 126 | kind, 127 | client_addr, 128 | details: details.into(), 129 | } 130 | } 131 | } 132 | 133 | impl From for RelayError { 134 | fn from(err: BackoffError) -> Self { 135 | RelayError::Connection(ConnectionError::Backoff(err)) 136 | } 137 | } 138 | 139 | impl From for RelayError { 140 | fn from(err: RtsError) -> Self { 141 | RelayError::Transport(TransportError::Rts(err)) 142 | } 143 | } 144 |
145 | impl From for RelayError { 146 | fn from(err: config::ConfigError) -> Self { 147 | Self::Config(ConfigValidationError::config(err.to_string())) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /src/connection/manager.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; 2 | 3 | use tokio::sync::{Mutex, Semaphore, mpsc, oneshot}; 4 | use tracing::error; 5 | 6 | use crate::{ConnectionError, RelayError, config::ConnectionConfig}; 7 | 8 | use super::{ConnectionGuard, ConnectionStats, StatEvent}; 9 | 10 | /// TCP connection management 11 | #[derive(Debug)] 12 | pub struct Manager { 13 | /// Connection limit per IP 14 | per_ip_semaphores: Arc>>>, 15 | /// Global connection limit 16 | global_semaphore: Arc, 17 | /// Active connections counter per IP 18 | active_connections: Arc>>, 19 | /// Configuration 20 | config: ConnectionConfig, 21 | /// Stats event sender 22 | stats_tx: mpsc::Sender, 23 | } 24 | 25 | impl Manager { 26 | pub fn new(config: ConnectionConfig, stats_tx: mpsc::Sender) -> Self { 27 | Self { 28 | per_ip_semaphores: Arc::new(Mutex::new(HashMap::new())), 29 | global_semaphore: Arc::new(Semaphore::new(config.max_connections as usize)), 30 | active_connections: Arc::new(Mutex::new(HashMap::new())), 31 | config, 32 | stats_tx, 33 | } 34 | } 35 | 36 | /// Attempt to establish a new connection 37 | pub async fn accept_connection( 38 | self: &Arc, 39 | addr: SocketAddr, 40 | ) -> Result { 41 | // Check per IP limit if enabled 42 | let per_ip_permit = if let Some(per_ip_limit) = self.config.per_ip_limits { 43 | let mut semaphores = self.per_ip_semaphores.lock().await; 44 | 45 | let semaphore = semaphores 46 | .entry(addr) 47 | .or_insert_with(|| Arc::new(Semaphore::new(per_ip_limit as usize))); 48 | 49 | Some(semaphore.clone().try_acquire_owned().map_err(|_| { 50 | 
RelayError::Connection(ConnectionError::limit_exceeded(format!( 51 | "Per-IP limit ({}) reached for {}", 52 | per_ip_limit, addr 53 | ))) 54 | })?) 55 | } else { 56 | None 57 | }; 58 | 59 | // Check if the global limit has been exceeded 60 | let global_permit = self 61 | .global_semaphore 62 | .clone() 63 | .try_acquire_owned() 64 | .map_err(|_| { 65 | RelayError::Connection(ConnectionError::limit_exceeded( 66 | "Global connection limit reached", 67 | )) 68 | })?; 69 | 70 | // Increment active connections counter 71 | { 72 | let mut active_conns = self.active_connections.lock().await; 73 | let conn_count = active_conns.entry(addr).or_default(); 74 | *conn_count = conn_count.saturating_add(1); 75 | } 76 | 77 | // Notify stats manager about new connection 78 | if let Err(e) = self.stats_tx.send(StatEvent::ClientConnected(addr)).await { 79 | error!("Failed to send connection event to stats manager: {}", e); 80 | } 81 | 82 | Ok(ConnectionGuard { 83 | manager: Arc::clone(self), 84 | addr, 85 | _global_permit: global_permit, 86 | _per_ip_permit: per_ip_permit, 87 | }) 88 | } 89 | 90 | pub async fn get_connection_count(&self, addr: &SocketAddr) -> usize { 91 | self.active_connections 92 | .lock() 93 | .await 94 | .get(addr) 95 | .copied() 96 | .unwrap_or(0) 97 | } 98 | 99 | pub async fn get_total_connections(&self) -> usize { 100 | self.active_connections.lock().await.values().sum() 101 | } 102 | 103 | /// Updates statistics for a given request 104 | pub async fn record_request(&self, addr: SocketAddr, success: bool, duration: Duration) { 105 | if let Err(e) = self 106 | .stats_tx 107 | .send(StatEvent::RequestProcessed { 108 | addr, 109 | success, 110 | duration_ms: duration.as_millis() as u64, 111 | }) 112 | .await 113 | { 114 | error!("Failed to send request stats: {}", e); 115 | } 116 | } 117 | 118 | /// Gets complete connection statistics 119 | pub async fn get_stats(&self) -> Result { 120 | let (tx, rx) = oneshot::channel(); 121 | 122 | self.stats_tx 123 | 
.send(StatEvent::QueryConnectionStats { response_tx: tx }) 124 | .await 125 | .map_err(|_| { 126 | RelayError::Connection(ConnectionError::invalid_state( 127 | "Failed to query connection stats", 128 | )) 129 | })?; 130 | 131 | rx.await.map_err(|_| { 132 | RelayError::Connection(ConnectionError::invalid_state( 133 | "Failed to receive connection stats", 134 | )) 135 | }) 136 | } 137 | 138 | /// Cleans up idle connections 139 | pub(crate) async fn cleanup_idle_connections(&self) -> Result<(), RelayError> { 140 | // Cleanup is now handled by StatsManager, we just need to sync our active connections 141 | let (tx, rx) = oneshot::channel(); 142 | 143 | self.stats_tx 144 | .send(StatEvent::QueryConnectionStats { response_tx: tx }) 145 | .await 146 | .map_err(|_| { 147 | RelayError::Connection(ConnectionError::invalid_state( 148 | "Failed to query stats for cleanup", 149 | )) 150 | })?; 151 | 152 | let stats = rx.await.map_err(|_| { 153 | RelayError::Connection(ConnectionError::invalid_state( 154 | "Failed to receive stats for cleanup", 155 | )) 156 | })?; 157 | 158 | let mut active_conns = self.active_connections.lock().await; 159 | active_conns.retain(|addr, count| { 160 | if let Some(ip_stats) = stats.per_ip_stats.get(addr) { 161 | ip_stats.active_connections > 0 162 | } else { 163 | // If no stats exist, connection is considered inactive 164 | *count == 0 165 | } 166 | }); 167 | 168 | Ok(()) 169 | } 170 | 171 | pub(crate) fn decrease_connection_count(&self, addr: SocketAddr) { 172 | let mut active_conns = self 173 | .active_connections 174 | .try_lock() 175 | .expect("Failed to lock active_connections in guard drop"); 176 | 177 | if let Some(count) = active_conns.get_mut(&addr) { 178 | *count = count.saturating_sub(1); 179 | if *count == 0 { 180 | active_conns.remove(&addr); 181 | } 182 | } 183 | } 184 | 185 | pub fn stats_tx(&self) -> mpsc::Sender { 186 | self.stats_tx.clone() 187 | } 188 | } 189 | 
-------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- /* main.rs: binary entry point. Parses CLI arguments, loads RelayConfig (from --config FILE, otherwise RelayConfig::new()), initialises logging to stdout and a daily-rolled file, runs ModbusRelay, and triggers its shutdown on SIGTERM/SIGINT. */ 1 | use std::path::PathBuf; 2 | use std::process; 3 | use std::sync::Arc; 4 | 5 | use clap::Parser; 6 | use time::UtcOffset; 7 | use tracing::{error, info}; 8 | use tracing_appender::{non_blocking, rolling}; 9 | use tracing_subscriber::{ 10 | EnvFilter, Layer, Registry, fmt::time::OffsetTime, layer::SubscriberExt, 11 | util::SubscriberInitExt, 12 | }; 13 | 14 | use modbus_relay::{ 15 | ModbusRelay, RelayConfig, RelayError, TransportError, errors::InitializationError, 16 | }; 17 | 18 | #[derive(Parser)] 19 | #[command(author, version, about, long_about = None)] 20 | struct Cli { 21 | #[command(flatten)] 22 | common: CommonArgs, 23 | } 24 | 25 | #[derive(clap::Args)] 26 | #[group(multiple = false)] /* NOTE(review): with multiple = false, clap treats the arguments of this group as mutually exclusive, so --config and --debug cannot be given together. `debug` is also never read in this file. Confirm both are intended. */ 27 | struct CommonArgs { 28 | /// Path to config file 29 | #[arg(short, long, value_name = "FILE")] 30 | config: Option, 31 | 32 | /// Run in debug mode 33 | #[arg(short, long)] 34 | debug: bool, 35 | } 36 | 37 | /* setup_logging(): installs a global tracing subscriber with two fmt layers, one writing to stdout and one writing to a daily-rolled modbus-relay.log in logging.log_dir, both through non_blocking writers. Each layer gets its own EnvFilter whose default directive is the configured level, plus trace-level directives for modbus_relay::protocol and modbus_relay::transport when logging.trace_frames is set. Returns the two writer guards (stdout first, file second), which main() keeps bound for the rest of the run. */ pub fn setup_logging( 38 | config: &RelayConfig, 39 | ) -> Result<(impl Drop + use<>, impl Drop + use<>), RelayError> { 40 | let timer = OffsetTime::new( 41 | UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC), 42 | time::format_description::well_known::Rfc3339, 43 | ); 44 | 45 | let log_dir = PathBuf::from(&config.logging.log_dir); 46 | let include_location = config.logging.include_location; 47 | let thread_ids = config.logging.thread_ids; 48 | let thread_names = config.logging.thread_names; 49 | 50 | /* NOTE(review): the io::Error is discarded and the process exits from inside this Result-returning function. Returning RelayError::Init would keep the cause and let the caller decide. */ std::fs::create_dir_all(&log_dir).unwrap_or_else(|_| { 51 | eprintln!("Failed to create log directory {}", log_dir.display()); 52 | process::exit(1); 53 | }); 54 | 55 | // Non-blocking stdout 56 | let (stdout_writer, stdout_guard) = non_blocking(std::io::stdout()); 57 | 58 | // Rotating log writer 59 | let file_appender = rolling::daily(log_dir, "modbus-relay.log"); 60 | let
(file_writer, file_guard) = non_blocking(file_appender); 61 | 62 | // Environment-based filter 63 | let mut stdout_env_filter = EnvFilter::builder() 64 | .with_default_directive(config.logging.get_level_filter().into()) 65 | .from_env_lossy(); 66 | 67 | let mut file_env_filter = EnvFilter::builder() 68 | .with_default_directive(config.logging.get_level_filter().into()) 69 | .from_env_lossy(); 70 | 71 | if config.logging.trace_frames { 72 | stdout_env_filter = stdout_env_filter 73 | .add_directive("modbus_relay::protocol=trace".parse().unwrap()) 74 | .add_directive("modbus_relay::transport=trace".parse().unwrap()); 75 | 76 | file_env_filter = file_env_filter 77 | .add_directive("modbus_relay::protocol=trace".parse().unwrap()) 78 | .add_directive("modbus_relay::transport=trace".parse().unwrap()); 79 | } 80 | 81 | // Log layer for stdout 82 | let stdout_layer = tracing_subscriber::fmt::layer() 83 | .with_writer(stdout_writer) 84 | .with_target(false) 85 | .with_thread_ids(thread_ids) 86 | .with_thread_names(thread_names) 87 | .with_file(include_location) 88 | .with_line_number(include_location) 89 | .with_level(true) 90 | .with_timer(timer.clone()) 91 | .with_filter(stdout_env_filter); 92 | 93 | // Log layer for file 94 | let file_layer = tracing_subscriber::fmt::layer() 95 | .with_writer(file_writer) 96 | .with_target(false) 97 | .with_thread_ids(thread_ids) 98 | .with_thread_names(thread_names) 99 | .with_file(include_location) 100 | .with_line_number(include_location) 101 | .with_level(true) 102 | .with_timer(timer) 103 | .with_filter(file_env_filter); 104 | 105 | // Combine all layers 106 | Registry::default() 107 | .with(stdout_layer) 108 | .with(file_layer) 109 | .try_init() 110 | .map_err(|e| { 111 | RelayError::Init(InitializationError::logging(format!( 112 | "Failed to initialize logging: {}", 113 | e 114 | ))) 115 | })?; 116 | 117 | Ok((stdout_guard, file_guard)) 118 | } 119 | 120 | /* main(): loads the configuration, sets up logging, then runs the relay. Any failure is printed or logged and 
the process exits with status 1; a serial-port-related transport error additionally logs platform-specific hints. */ #[tokio::main] 121 | async fn main() { 122 | let cli = Cli::parse(); 123 |
124 | // Load configuration 125 | let config = if let Some(config_path) = &cli.common.config { 126 | RelayConfig::from_file(config_path.clone()) 127 | } else { 128 | RelayConfig::new() 129 | }; 130 | 131 | let config = match config { 132 | Ok(config) => config, 133 | Err(e) => { 134 | eprintln!("Failed to load configuration: {:#}", e); 135 | process::exit(1); 136 | } 137 | }; 138 | 139 | // Setup logging based on configuration 140 | /* The guards are bound to named variables (not `_`), so they are not dropped at the end of this statement. */ let (_stdout_guard, _file_guard) = match setup_logging(&config) { 141 | Ok(guards) => guards, 142 | Err(e) => { 143 | eprintln!("Failed to setup logging: {:#}", e); 144 | process::exit(1); 145 | } 146 | }; 147 | 148 | info!("Starting Modbus Relay..."); 149 | 150 | if let Err(e) = run(config).await { 151 | error!("Fatal error: {:#}", e); 152 | /* NOTE(review): the hints below are shown only when the error downcasts to RelayError::Transport(TransportError::Io) and its details text contains the words `serial port`. That is a message-text match, and it silently stops working if the wording changes. */ if let Some(RelayError::Transport(TransportError::Io { details, .. })) = 153 | e.downcast_ref::() 154 | && details.contains("serial port") 155 | { 156 | error!( 157 | "Hint: Make sure the configured serial port exists and you have permission to access it" 158 | ); 159 | #[cfg(target_os = "macos")] 160 | error!( 161 | "Hint: On macOS, you might need to install the driver from https://www.silabs.com/developers/usb-to-uart-bridge-vcp-drivers" 162 | ); 163 | #[cfg(target_os = "linux")] 164 | error!( 165 | "Hint: On Linux, you might need to add your user to the dialout group: sudo usermod -a -G dialout $USER" 166 | ); 167 | } 168 | /* NOTE(review): std::process::exit does not run destructors, so _stdout_guard/_file_guard are never dropped on this path. If dropping them is what flushes the non_blocking writers (as with tracing_appender's WorkerGuard — confirm), the fatal error and hints logged just above may be lost. Consider dropping the guards explicitly before exiting. */ process::exit(1); 169 | } 170 | } 171 | 172 | /* run(): builds the relay, spawns a task that waits for SIGTERM or SIGINT and then calls relay.shutdown(), runs the relay, and finally awaits that shutdown task. NOTE(review): tokio::signal::unix makes this binary Unix-only. Also, if relay.run() returns Ok before any signal has arrived, shutdown_task.await below keeps waiting until one does. Confirm that relay.run() only returns after shutdown has been requested. */ async fn run(config: RelayConfig) -> Result<(), Box> { 173 | let relay = Arc::new(ModbusRelay::new(config)?); 174 | 175 | let relay_clone = Arc::clone(&relay); 176 | 177 | let shutdown_task = tokio::spawn(async move { 178 | let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) 179 | .expect("Failed to create SIGTERM signal handler"); 180 | let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) 181 | .expect("Failed to create SIGINT signal handler"); 182 |
tokio::select! { 183 | _ = sigterm.recv() => info!("Received SIGTERM"), 184 | _ = sigint.recv() => info!("Received SIGINT"), 185 | } 186 | 187 | if let Err(e) = relay_clone.shutdown().await { 188 | error!("Error during shutdown: {}", e); 189 | } 190 | }); 191 | 192 | relay.run().await?; 193 | 194 | info!("Waiting for shutdown to complete..."); 195 | 196 | shutdown_task.await?; 197 | 198 | info!("Modbus Relay stopped"); 199 | 200 | Ok(()) 201 | } 202 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # Modbus Relay - TODO List 2 | 3 | ## 1. Error Handling [DONE] 4 | 5 | - [x] Enhanced error types and hierarchy 6 | - [x] Proper error conversion implementations 7 | - [x] Error recovery strategies (connection retries with backoff) 8 | - [x] Context-aware error reporting (detailed error types with context) 9 | - [x] Error metrics and monitoring (via ConnectionManager stats) 10 | - [x] Custom error middleware for better logging (structured errors with tracing) 11 | 12 | ## 2. Connection Management [MOSTLY DONE] 13 | 14 | - [x] Maximum connections limit (global and per-IP) 15 | - [x] Connection backoff strategy 16 | - [x] Basic connection handling 17 | - [x] Connection stats tracking 18 | - [x] Proper error handling in connection management 19 | - [x] Fail-fast behavior for connection limits 20 | - [x] Connection reuse optimization 21 | - [ ] TCP connection pooling 22 | - [ ] Advanced connection timeouts and keep-alive 23 | - [ ] Enhanced health checks 24 | 25 | ## 3. Protocol Handling [DONE] 26 | 27 | - [x] Separation of concerns (TCP vs Modbus logic) 28 | - [x] ModbusProcessor implementation 29 | - [x] Frame handling 30 | - [x] Protocol error handling 31 | - [x] RTU-TCP conversion 32 | - [x] RTS control (configurable) 33 | - [x] Frame validation and CRC checking 34 | - [x] Advanced error handling for protocol errors 35 | 36 | ## 4.
Testing [IN PROGRESS] 37 | 38 | - [x] Basic unit tests for error handling 39 | - [x] Connection management tests 40 | - [x] Proper separation of test responsibilities 41 | - [x] Error handling tests 42 | - [ ] ModbusProcessor tests 43 | - [ ] TCP connection handling tests 44 | - [ ] Integration tests 45 | - [ ] Complete unit test coverage 46 | - [ ] Property-based testing 47 | - [ ] Fuzz testing for protocol handling 48 | - [ ] Benchmark tests 49 | - [ ] Load tests 50 | - [ ] Chaos testing 51 | 52 | ## 5. Performance Optimization [IN PROGRESS] 53 | 54 | - [x] Efficient frame processing 55 | - [x] Optimized error handling 56 | - [x] Smart buffer sizing 57 | - [ ] Buffer pooling 58 | - [ ] Zero-copy frame handling 59 | - [ ] Batch request processing 60 | - [ ] Response caching for read-only registers 61 | - [ ] Configurable thread/task pool 62 | - [ ] Memory usage optimization 63 | 64 | ## 6. Monitoring & Metrics [MOSTLY DONE] 65 | 66 | - [x] Basic connection stats 67 | - [x] Error rate tracking 68 | - [x] Connection tracking 69 | - [x] Detailed error reporting 70 | - [x] Request/response timing metrics 71 | - [x] Frame statistics 72 | - [ ] Prometheus metrics integration 73 | - [ ] System resource usage monitoring 74 | - [ ] Alerting integration 75 | 76 | ## 7. Reliability Features [MOSTLY DONE] 77 | 78 | - [x] Basic rate limiting (per-IP limits) 79 | - [x] Connection backoff 80 | - [x] Basic error recovery 81 | - [x] Proper error propagation 82 | - [x] Advanced backpressure handling 83 | - [x] RTS control with timing configuration 84 | - [ ] Circuit breaker for RTU device 85 | - [ ] Automatic reconnection 86 | - [ ] Request retry mechanism 87 | - [ ] Request prioritization 88 | 89 | ## 8. 
Configuration [MOSTLY DONE] 90 | 91 | - [x] Basic configuration validation 92 | - [x] JSON config support 93 | - [x] Enhanced config validation with detailed errors 94 | - [x] Environment variable support 95 | - [x] Feature flags (RTS support) 96 | - [x] Serial port configuration 97 | - [x] TCP configuration 98 | - [x] Timing configuration 99 | - [ ] Dynamic configuration reloading 100 | - [ ] YAML/TOML support 101 | - [ ] Secrets management 102 | 103 | ## 9. Security [PARTIALLY DONE] 104 | 105 | - [x] Basic rate limiting 106 | - [x] Request validation 107 | - [x] Frame validation 108 | - [ ] TLS support for TCP connections 109 | - [ ] Authentication/Authorization 110 | - [ ] Enhanced rate limiting 111 | - [ ] IP whitelisting 112 | - [ ] Security headers 113 | - [ ] Audit logging 114 | 115 | ## 10. Logging & Debugging [MOSTLY DONE] 116 | 117 | - [x] Basic structured logging 118 | - [x] Debug protocol traces 119 | - [x] Detailed error logging 120 | - [x] Log context propagation 121 | - [x] Request/response tracing 122 | - [ ] Log rotation 123 | - [ ] Request ID tracking 124 | - [ ] Performance profiling 125 | - [ ] Diagnostic endpoints 126 | - [ ] Audit trail 127 | 128 | ## 11. Documentation [IN PROGRESS] 129 | 130 | - [x] Error handling documentation 131 | - [x] Configuration documentation 132 | - [x] Basic usage examples 133 | - [ ] API documentation 134 | - [ ] Configuration guide 135 | - [ ] Deployment guide 136 | - [ ] Performance tuning guide 137 | - [ ] Security best practices 138 | - [ ] Troubleshooting guide 139 | - [ ] Architecture documentation 140 | - [ ] Contributing guidelines 141 | 142 | ## 12. Administrative Features 143 | 144 | - [ ] Admin API 145 | - [ ] Statistics endpoint 146 | - [ ] Configuration management endpoint 147 | - [ ] Connection management 148 | - [ ] Log level control 149 | - [ ] Feature flag management 150 | - [ ] Health check endpoints 151 | - [ ] Metrics endpoints 152 | 153 | ## 13. 
Development Tools 154 | 155 | - [ ] Development environment setup 156 | - [ ] CI/CD pipeline 157 | - [ ] Release automation 158 | - [ ] Docker support 159 | - [ ] Kubernetes manifests 160 | - [ ] Development workflow documentation 161 | - [ ] Test data generators 162 | - [ ] Protocol simulators 163 | 164 | ## 14. Protocol Enhancements [PARTIALLY DONE] 165 | 166 | - [x] Support for basic Modbus function codes 167 | - [x] Proper error reporting 168 | - [x] Protocol separation of concerns 169 | - [x] Protocol conformance 170 | - [ ] Support for all Modbus function codes 171 | - [ ] Protocol conformance testing 172 | - [ ] Custom function code handling 173 | - [ ] Protocol extensions 174 | - [ ] Protocol version negotiation 175 | 176 | ## 15. Operational Features [PARTIALLY DONE] 177 | 178 | - [x] Basic graceful shutdown 179 | - [x] Error recovery mechanisms 180 | - [x] Configurable timeouts 181 | - [ ] Enhanced hot reload 182 | - [ ] Backup/Restore functionality 183 | - [ ] Data persistence (if needed) 184 | - [ ] Migration tools 185 | - [ ] Maintenance mode 186 | - [ ] Resource cleanup 187 | 188 | ## 16. Integration [PLANNED] 189 | 190 | - [ ] OpenTelemetry integration 191 | - [ ] Metrics export 192 | - [ ] Log aggregation 193 | - [ ] Alert manager integration 194 | - [ ] Service discovery 195 | - [ ] Load balancer integration 196 | - [ ] Monitoring system integration 197 | - [ ] Centralized logging 198 | 199 | ## Next Priority Tasks 200 | 201 | 1. Complete Testing: 202 | - Add ModbusProcessor tests 203 | - Add TCP connection handling tests 204 | - Implement integration tests 205 | - Add performance benchmarks 206 | 207 | 2. Enhance Monitoring: 208 | - Implement Prometheus metrics 209 | - Add system resource monitoring 210 | - Implement alerting system 211 | - Add OpenTelemetry support 212 | 213 | 3. Security Improvements: 214 | - Add TLS support 215 | - Implement authentication 216 | - Add IP whitelisting 217 | - Implement audit logging 218 | 219 | 4. 
Documentation & Development: 220 | - Complete API documentation 221 | - Add deployment guide 222 | - Add performance tuning guide 223 | - Set up CI/CD pipeline 224 | - Add Docker support 225 | 226 | Each feature should be implemented with: 227 | - Clear documentation 228 | - Tests (with proper responsibility separation) 229 | - Metrics 230 | - Configuration options 231 | - Error handling 232 | - Logging 233 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | env: 9 | CARGO_TERM_COLOR: always 10 | CARGO_INCREMENTAL: 0 11 | CARGO_NET_RETRY: 10 12 | RUSTUP_MAX_RETRIES: 10 13 | RUST_BACKTRACE: short 14 | PKG_CONFIG_ALLOW_CROSS: 1 15 | 16 | jobs: 17 | # First run the same checks as in CI 18 | check: 19 | name: Check 20 | uses: ./.github/workflows/ci.yml 21 | 22 | build-deb: 23 | name: Build Debian Package 24 | needs: check 25 | runs-on: ubuntu-latest 26 | strategy: 27 | matrix: 28 | target: 29 | - armv7-unknown-linux-gnueabihf 30 | - aarch64-unknown-linux-gnu 31 | - x86_64-unknown-linux-gnu 32 | include: 33 | - target: armv7-unknown-linux-gnueabihf 34 | arch: armhf 35 | - target: aarch64-unknown-linux-gnu 36 | arch: arm64 37 | - target: x86_64-unknown-linux-gnu 38 | arch: amd64 39 | 40 | steps: 41 | - uses: actions/checkout@v4 42 | 43 | - name: Install common dependencies 44 | run: | 45 | sudo apt-get update 46 | sudo apt-get install -y pkg-config 47 | 48 | - name: Install target specific dependencies 49 | run: | 50 | case ${{ matrix.target }} in 51 | "x86_64-unknown-linux-gnu") 52 | sudo apt-get update 53 | sudo apt-get install -y libudev-dev dpkg-dev 54 | ;; 55 | "armv7-unknown-linux-gnueabihf") 56 | sudo apt-get update 57 | sudo apt-get install -y binutils-arm-linux-gnueabihf crossbuild-essential-armhf 58 | ;; 59 | "aarch64-unknown-linux-gnu") 60 | sudo 
apt-get update 61 | sudo apt-get install -y binutils-aarch64-linux-gnu crossbuild-essential-arm64 62 | ;; 63 | esac 64 | 65 | - name: Set library path 66 | if: matrix.target == 'x86_64-unknown-linux-gnu' 67 | run: | 68 | echo "LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH" >> $GITHUB_ENV 69 | echo "LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LIBRARY_PATH" >> $GITHUB_ENV 70 | 71 | - name: Install Rust toolchain 72 | uses: dtolnay/rust-toolchain@stable 73 | with: 74 | targets: ${{ matrix.target }} 75 | 76 | - name: Install cross 77 | uses: taiki-e/install-action@cross 78 | 79 | - name: Cache cargo 80 | uses: actions/cache@v3 81 | with: 82 | path: | 83 | ~/.cargo/registry 84 | ~/.cargo/git 85 | target 86 | key: ${{ runner.os }}-cargo-${{ matrix.target }}-${{ hashFiles('**/Cargo.lock') }} 87 | restore-keys: | 88 | ${{ runner.os }}-cargo-${{ matrix.target }}- 89 | 90 | - name: Cache cross 91 | if: matrix.target != 'x86_64-unknown-linux-gnu' 92 | uses: actions/cache@v3 93 | with: 94 | path: ~/.cargo/.cross 95 | key: ${{ runner.os }}-cross-${{ matrix.target }}-${{ hashFiles('Cross.toml') }} 96 | restore-keys: | 97 | ${{ runner.os }}-cross-${{ matrix.target }}- 98 | 99 | - name: Install cargo-deb 100 | run: cargo install cargo-deb 101 | 102 | - name: Build and package 103 | env: 104 | PKG_CONFIG_ALLOW_CROSS: "1" 105 | run: | 106 | if [ "${{ matrix.target }}" = "x86_64-unknown-linux-gnu" ]; then 107 | # Native build 108 | cargo build --release --target ${{ matrix.target }} 109 | cargo deb --target ${{ matrix.target }} 110 | else 111 | # Cross compilation requires special paths 112 | if [ "${{ matrix.target }}" = "armv7-unknown-linux-gnueabihf" ]; then 113 | PKG_PATH="/usr/lib/arm-linux-gnueabihf/pkgconfig" 114 | LD_PATH="/usr/arm-linux-gnueabihf/lib" 115 | elif [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then 116 | PKG_PATH="/usr/lib/aarch64-linux-gnu/pkgconfig" 117 | LD_PATH="/usr/aarch64-linux-gnu/lib" 118 | fi 119 | 120 | 
PKG_CONFIG_PATH="$PKG_PATH" \ 121 | PKG_CONFIG_SYSROOT_DIR="/usr" \ 122 | PKG_CONFIG_LIBDIR="$PKG_PATH" \ 123 | cross build --release --target ${{ matrix.target }} 124 | 125 | # Setup cargo config for cross-compilation strip 126 | mkdir -p .cargo 127 | cp dist/debian/cargo-config.toml .cargo/config.toml 128 | 129 | LD_LIBRARY_PATH="$LD_PATH" \ 130 | cargo deb --no-build --target ${{ matrix.target }} 131 | fi 132 | 133 | - name: Upload artifacts 134 | uses: actions/upload-artifact@v4 135 | with: 136 | name: deb-${{ matrix.arch }} 137 | path: target/${{ matrix.target }}/debian/*.deb 138 | 139 | build-arch: 140 | name: Build Arch Package 141 | needs: check 142 | runs-on: ubuntu-latest 143 | container: 144 | image: archlinux:base-devel 145 | options: --privileged 146 | 147 | steps: 148 | - uses: actions/checkout@v4 149 | 150 | - name: Install dependencies 151 | run: | 152 | # Update system 153 | pacman -Syu --noconfirm 154 | 155 | # Install common dependencies 156 | pacman -S --noconfirm \ 157 | git \ 158 | rustup \ 159 | cargo \ 160 | pkg-config \ 161 | systemd \ 162 | systemd-libs 163 | 164 | - name: Setup Rust 165 | run: | 166 | rustup default stable 167 | rustup target add x86_64-unknown-linux-gnu 168 | 169 | - name: Install Rust toolchain 170 | uses: dtolnay/rust-toolchain@stable 171 | with: 172 | targets: x86_64-unknown-linux-gnu 173 | 174 | - name: Cache cargo 175 | uses: actions/cache@v3 176 | with: 177 | path: | 178 | ~/.cargo/registry 179 | ~/.cargo/git 180 | target 181 | key: ${{ runner.os }}-arch-cargo-x86_64-${{ hashFiles('**/Cargo.lock') }} 182 | restore-keys: | 183 | ${{ runner.os }}-arch-cargo-x86_64- 184 | 185 | - name: Prepare PKGBUILD 186 | run: | 187 | VERSION=$(grep -m1 'version =' Cargo.toml | cut -d '"' -f2) 188 | sed -i "s/pkgver=.*/pkgver=$VERSION/" dist/arch/PKGBUILD 189 | mkdir -p /tmp/pkg/{cargo,src} 190 | cp -r . 
/tmp/pkg/src/modbus-relay 191 | cp dist/arch/PKGBUILD /tmp/pkg 192 | 193 | - name: Build package 194 | run: | 195 | cd /tmp/pkg 196 | chown -R nobody:nobody . 197 | CARGO_HOME=/tmp/pkg/cargo runuser -p -u nobody -g root -- makepkg -s --noconfirm 198 | 199 | - name: Upload artifacts 200 | uses: actions/upload-artifact@v4 201 | with: 202 | name: pkg-x86_64 203 | path: /tmp/pkg/*.pkg.tar.zst 204 | 205 | create-release: 206 | name: Create Release 207 | needs: [build-deb, build-arch] 208 | runs-on: ubuntu-latest 209 | permissions: 210 | contents: write 211 | steps: 212 | - uses: actions/checkout@v4 213 | 214 | - name: Download artifacts 215 | uses: actions/download-artifact@v4 216 | with: 217 | path: artifacts 218 | 219 | - name: List files 220 | run: | 221 | ls -R artifacts/ 222 | echo "Files to be uploaded:" 223 | find artifacts -type f -name "*.deb" -o -name "*.pkg.tar.zst" 224 | 225 | - name: Create Release 226 | uses: softprops/action-gh-release@v1 227 | with: 228 | files: | 229 | artifacts/**/*.deb 230 | artifacts/**/*.pkg.tar.zst 231 | generate_release_notes: true 232 | -------------------------------------------------------------------------------- /src/http_api.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::SystemTime}; 2 | 3 | use axum::{Json, Router, extract::State, http::StatusCode, response::IntoResponse, routing::get}; 4 | use serde::Serialize; 5 | use tokio::sync::{broadcast, oneshot}; 6 | use tracing::info; 7 | 8 | use crate::{ConnectionManager, connection::StatEvent}; 9 | 10 | #[derive(Debug, Serialize)] 11 | struct HealthResponse { 12 | status: &'static str, 13 | tcp_connections: u32, 14 | rtu_status: &'static str, 15 | } 16 | 17 | #[derive(Debug, Serialize)] 18 | struct IpStatsResponse { 19 | active_connections: usize, 20 | total_requests: u64, 21 | total_errors: u64, 22 | avg_response_time_ms: u64, 23 | last_active: SystemTime, 24 | last_error: 
Option, 25 | } 26 | /// JSON body of `GET /stats`: service-wide totals plus a per-address breakdown. 27 | #[derive(Debug, Serialize)] 28 | struct StatsResponse { 29 | // Basic stats 30 | total_connections: u64, 31 | active_connections: u32, 32 | total_requests: u64, 33 | total_errors: u64, 34 | requests_per_second: f64, 35 | avg_response_time_ms: u64, 36 | 37 | // Stats per IP 38 | per_ip_stats: HashMap, 39 | } 40 | /// State shared by the HTTP handlers; its `stats_tx()` sender is used to query connection statistics. 41 | type ApiState = Arc; 42 | /// `GET /health`: 200 with the active TCP connection count, or 500 with `status: "error"` when the stats query fails. 43 | async fn health_handler(State(state): State) -> impl IntoResponse { 44 | let (tx, rx) = oneshot::channel(); 45 | 46 | if (state 47 | .stats_tx() 48 | .send(StatEvent::QueryConnectionStats { response_tx: tx }) 49 | .await) 50 | .is_err() 51 | { 52 | return ( 53 | StatusCode::INTERNAL_SERVER_ERROR, 54 | Json(HealthResponse { 55 | status: "error", 56 | tcp_connections: 0, 57 | rtu_status: "unknown", 58 | }), 59 | ); 60 | } 61 | 62 | match rx.await { 63 | Ok(stats) => { 64 | ( 65 | StatusCode::OK, 66 | Json(HealthResponse { 67 | status: "ok", 68 | tcp_connections: u32::try_from(stats.active_connections).unwrap_or(u32::MAX), // clamp instead of silently truncating 69 | rtu_status: "ok", // TODO(aljen): Implement RTU status check 70 | }), 71 | ) 72 | } 73 | Err(_) => ( 74 | StatusCode::INTERNAL_SERVER_ERROR, 75 | Json(HealthResponse { 76 | status: "error", 77 | tcp_connections: 0, 78 | rtu_status: "unknown", 79 | }), 80 | ), 81 | } 82 | } 83 | /// `GET /stats`: 200 with aggregate and per-address statistics, or 500 with an all-zero body when the stats query fails. 84 | async fn stats_handler(State(state): State) -> impl IntoResponse { 85 | let (tx, rx) = oneshot::channel(); 86 | 87 | if (state 88 | .stats_tx() 89 | .send(StatEvent::QueryConnectionStats { response_tx: tx }) 90 | .await) 91 | .is_err() 92 | { 93 | return ( 94 | StatusCode::INTERNAL_SERVER_ERROR, 95 | Json(StatsResponse { 96 | total_connections: 0, 97 | active_connections: 0, 98 | total_requests: 0, 99 | total_errors: 0, 100 | requests_per_second: 0.0, 101 | avg_response_time_ms: 0, 102 | per_ip_stats: HashMap::new(), 103 | }), 104 | ); 105 | } 106 | 107 | match rx.await { 108 | Ok(stats) => { 109 | let per_ip_stats = stats 110 | .per_ip_stats 111 | .into_iter() 112 | .map(|(addr, ip_stats)| { 113 | ( 114 | addr, 115 |
IpStatsResponse { 116 | active_connections: ip_stats.active_connections, 117 | total_requests: ip_stats.total_requests, 118 | total_errors: ip_stats.total_errors, 119 | avg_response_time_ms: ip_stats.avg_response_time_ms, 120 | last_active: ip_stats.last_active, 121 | last_error: ip_stats.last_error, 122 | }, 123 | ) 124 | }) 125 | .collect(); 126 | 127 | ( 128 | StatusCode::OK, 129 | Json(StatsResponse { 130 | total_connections: stats.total_connections, 131 | active_connections: u32::try_from(stats.active_connections).unwrap_or(u32::MAX), // clamp instead of silently truncating 132 | total_requests: stats.total_requests, 133 | total_errors: stats.total_errors, 134 | requests_per_second: stats.requests_per_second, 135 | avg_response_time_ms: stats.avg_response_time_ms, 136 | per_ip_stats, 137 | }), 138 | ) 139 | } 140 | Err(_) => ( 141 | StatusCode::INTERNAL_SERVER_ERROR, 142 | Json(StatsResponse { 143 | total_connections: 0, 144 | active_connections: 0, 145 | total_requests: 0, 146 | total_errors: 0, 147 | requests_per_second: 0.0, 148 | avg_response_time_ms: 0, 149 | per_ip_stats: HashMap::new(), 150 | }), 151 | ), 152 | } 153 | } 154 | /// Serves `/health` and `/stats` on `address:port` until `shutdown_rx.recv()` completes (message, lag or closed channel), then shuts down gracefully. 155 | pub async fn start_http_server( 156 | address: String, 157 | port: u16, 158 | manager: Arc, 159 | mut shutdown_rx: broadcast::Receiver<()>, 160 | ) -> Result<(), Box> { 161 | let app = Router::new() 162 | .route("/health", get(health_handler)) 163 | .route("/stats", get(stats_handler)) 164 | .with_state(manager); 165 | 166 | let addr = format!("{}:{}", address, port); 167 | let listener = tokio::net::TcpListener::bind(&addr).await?; 168 | 169 | info!("HTTP server listening on {}", addr); 170 | 171 | axum::serve(listener, app) 172 | .with_graceful_shutdown(async move { 173 | let _ = shutdown_rx.recv().await; 174 | info!("HTTP server shutting down"); 175 | }) 176 | .await?; 177 | 178 | info!("HTTP server shutdown complete"); 179 | 180 | Ok(()) 181 | } 182 | 183 | #[cfg(test)] 184 | mod tests { 185 | use crate::{ConnectionConfig, StatsManager}; 186 | 187 | use super::*; 188 | use
axum::body::Body; 189 | use axum::http::Request; 190 | use tokio::sync::Mutex; 191 | use tower::ServiceExt; 192 | 193 | #[tokio::test] 194 | async fn test_health_endpoint() { 195 | // Create a test stats manager 196 | let config = ConnectionConfig::default(); 197 | let stats_config = crate::StatsConfig::default(); 198 | let (stats_manager, stats_tx) = StatsManager::new(stats_config); 199 | let stats_manager = Arc::new(Mutex::new(stats_manager)); 200 | 201 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 202 | 203 | let stats_handle = tokio::spawn({ 204 | async move { 205 | let mut stats_manager = stats_manager.lock().await; 206 | stats_manager.run(shutdown_rx).await; 207 | } 208 | }); 209 | 210 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 211 | 212 | // Build test app 213 | let app = Router::new() 214 | .route("/health", get(health_handler)) 215 | .with_state(manager); 216 | 217 | // Create test request 218 | let req = Request::builder() 219 | .uri("/health") 220 | .body(Body::empty()) 221 | .unwrap(); 222 | 223 | // Get response 224 | let response = app.oneshot(req).await.unwrap(); 225 | 226 | assert_eq!(response.status(), StatusCode::OK); 227 | 228 | shutdown_tx.send(true).unwrap(); 229 | stats_handle.await.unwrap(); 230 | } 231 | 232 | #[tokio::test] 233 | async fn test_stats_endpoint() { 234 | let config = ConnectionConfig::default(); 235 | let stats_config = crate::StatsConfig::default(); 236 | let (stats_manager, stats_tx) = StatsManager::new(stats_config); 237 | let stats_manager = Arc::new(Mutex::new(stats_manager)); 238 | 239 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 240 | 241 | let stats_handle = tokio::spawn({ 242 | async move { 243 | let mut stats_manager = stats_manager.lock().await; 244 | stats_manager.run(shutdown_rx).await; 245 | } 246 | }); 247 | 248 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 249 | 250 | let app = Router::new() 251 | .route("/stats",
get(stats_handler)) 252 | .with_state(manager); 253 | 254 | let req = Request::builder() 255 | .uri("/stats") 256 | .body(Body::empty()) 257 | .unwrap(); 258 | 259 | let response = app.oneshot(req).await.unwrap(); 260 | 261 | assert_eq!(response.status(), StatusCode::OK); 262 | 263 | shutdown_tx.send(true).unwrap(); 264 | stats_handle.await.unwrap(); 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /src/stats_manager.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::SystemTime}; 2 | 3 | use tokio::sync::{Mutex, mpsc}; 4 | use tracing::{debug, info, warn}; 5 | 6 | use crate::{ClientStats, ConnectionStats, config::StatsConfig, connection::StatEvent}; 7 | /// Aggregates per-client statistics from `StatEvent`s received over a bounded mpsc channel. 8 | pub struct StatsManager { 9 | stats: Arc>>, 10 | event_rx: mpsc::Receiver, 11 | config: StatsConfig, 12 | total_connections: u64, // NOTE(review): incremented on ClientConnected but never read in this file; QueryConnectionStats is built from the per-client map only - confirm intended 13 | } 14 | 15 | impl StatsManager { 16 | pub fn new(config: StatsConfig) -> (Self, mpsc::Sender) { 17 | let (tx, rx) = mpsc::channel((config.max_events_per_second as usize).max(1)); // tokio's bounded mpsc panics on a zero capacity 18 | 19 | let manager = Self { 20 | stats: Arc::new(Mutex::new(HashMap::new())), 21 | event_rx: rx, 22 | config, 23 | total_connections: 0, 24 | }; 25 | 26 | (manager, tx) 27 | } 28 | /// Handles events and periodic cleanups until `shutdown_rx` reports any change (queued events are drained first) or its sender is dropped. 29 | pub async fn run(&mut self, mut shutdown_rx: tokio::sync::watch::Receiver) { 30 | let mut cleanup_interval = tokio::time::interval(self.config.cleanup_interval); 31 | 32 | loop { 33 | tokio::select!
{ 34 | shutdown = shutdown_rx.changed() => { 35 | match shutdown { 36 | Ok(_) => { 37 | info!("Stats manager shutting down"); 38 | // Ensure all events are processed before shutting down 39 | while let Ok(event) = self.event_rx.try_recv() { 40 | self.handle_event(event).await; 41 | } 42 | break; 43 | } 44 | Err(e) => { 45 | warn!("Shutdown channel closed: {}", e); 46 | break; 47 | } 48 | } 49 | } 50 | 51 | Some(event) = self.event_rx.recv() => { 52 | self.handle_event(event).await; 53 | } 54 | 55 | _ = cleanup_interval.tick() => { 56 | self.cleanup_idle_stats().await; 57 | } 58 | } 59 | } 60 | 61 | info!("Stats manager shutdown complete"); 62 | } 63 | /// Applies one event to the per-client map, or answers a stats query through its oneshot sender. 64 | async fn handle_event(&mut self, event: StatEvent) { 65 | let mut stats = self.stats.lock().await; 66 | 67 | match event { 68 | StatEvent::ClientConnected(addr) => { 69 | let client_stats = stats.entry(addr).or_default(); 70 | client_stats.active_connections = client_stats.active_connections.saturating_add(1); 71 | client_stats.last_active = SystemTime::now(); 72 | self.total_connections = self.total_connections.saturating_add(1); 73 | debug!("Client connected from {}", addr); 74 | } 75 | 76 | StatEvent::ClientDisconnected(addr) => { 77 | if let Some(client_stats) = stats.get_mut(&addr) { 78 | client_stats.active_connections = 79 | client_stats.active_connections.saturating_sub(1); 80 | client_stats.last_active = SystemTime::now(); 81 | debug!("Client disconnected from {}", addr); 82 | } 83 | } 84 | 85 | StatEvent::RequestProcessed { 86 | addr, 87 | success, 88 | duration_ms, 89 | } => { 90 | let client_stats = stats.entry(addr).or_default(); 91 | client_stats.total_requests = client_stats.total_requests.saturating_add(1); 92 | 93 | if !success { 94 | client_stats.total_errors = client_stats.total_errors.saturating_add(1); 95 | client_stats.last_error = Some(SystemTime::now()); 96 | } 97 | 98 | // Update average response time using exponential moving average 99 | const ALPHA: f64 = 0.1; // Smoothing factor 100 | // NOTE(review): an average of 0 doubles as "no sample yet", so a genuine 0 ms average restarts from the next sample
101 | if client_stats.avg_response_time_ms == 0 { 102 | client_stats.avg_response_time_ms = duration_ms; 103 | } else { 104 | let current_avg = client_stats.avg_response_time_ms as f64; 105 | client_stats.avg_response_time_ms = 106 | (current_avg + ALPHA * (duration_ms as f64 - current_avg)) as u64; 107 | } 108 | 109 | client_stats.last_active = SystemTime::now(); 110 | } 111 | 112 | StatEvent::QueryStats { addr, response_tx } => { 113 | if let Some(stats) = stats.get(&addr) 114 | && response_tx.send(stats.clone()).is_err() 115 | { 116 | warn!("Failed to send stats for {}", addr); 117 | } 118 | } 119 | 120 | StatEvent::QueryConnectionStats { response_tx } => { 121 | let conn_stats = ConnectionStats::from_client_stats(&stats); 122 | if response_tx.send(conn_stats).is_err() { 123 | warn!("Failed to send connection stats"); 124 | } 125 | } 126 | } 127 | } 128 | /// Drops per-client entries with no activity within `idle_timeout` and no error within `error_timeout`. 129 | async fn cleanup_idle_stats(&self) { 130 | let mut stats = self.stats.lock().await; 131 | let now = SystemTime::now(); 132 | 133 | stats.retain(|addr, client_stats| { 134 | // True if the client was active within idle_timeout (a last_active later than now also counts) 135 | let recently_active = now 136 | .duration_since(client_stats.last_active) 137 | .map(|idle_time| idle_time <= self.config.idle_timeout) 138 | .unwrap_or(true); 139 | 140 | // True if the most recent error happened within error_timeout 141 | let has_recent_error = client_stats 142 | .last_error 143 | .and_then(|last_error| now.duration_since(last_error).ok()) 144 | .map(|error_time| error_time <= self.config.error_timeout) 145 | .unwrap_or(false); 146 | // NOTE(review): entries with active_connections > 0 are dropped too once quiet past idle_timeout; a later RequestProcessed recreates them via or_default(), presumably with active_connections = 0 - confirm intended 147 | let should_retain = recently_active || has_recent_error; 148 | 149 | if !should_retain { 150 | debug!( 151 | "Cleaning up stats for {}: {} connections, {} requests, {} errors", 152 | addr, 153 | client_stats.active_connections, 154 | client_stats.total_requests, 155 | client_stats.total_errors 156 | ); 157 | } 158 | 159 | should_retain 160 | }); 161 | } 162 | } 163 | 164 | #[cfg(test)] 165 | mod tests { 166 | use std::time::Duration; 167 | 168 |
use super::*; 169 | use tokio::{sync::oneshot, time::sleep}; 170 | 171 | #[tokio::test] 172 | async fn test_client_lifecycle() { 173 | let config = StatsConfig::default(); 174 | let (mut manager, tx) = StatsManager::new(config); 175 | let addr = "127.0.0.1:8080".parse().unwrap(); 176 | 177 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 178 | let manager_handle = tokio::spawn(async move { 179 | manager.run(shutdown_rx).await; 180 | }); 181 | 182 | // Test connection 183 | tx.send(StatEvent::ClientConnected(addr)).await.unwrap(); 184 | 185 | // Test successful request 186 | tx.send(StatEvent::RequestProcessed { 187 | addr, 188 | success: true, 189 | duration_ms: Duration::from_millis(100).as_millis() as u64, 190 | }) 191 | .await 192 | .unwrap(); 193 | 194 | // Test failed request 195 | tx.send(StatEvent::RequestProcessed { 196 | addr, 197 | success: false, 198 | duration_ms: Duration::from_millis(150).as_millis() as u64, 199 | }) 200 | .await 201 | .unwrap(); 202 | 203 | sleep(Duration::from_millis(100)).await; 204 | 205 | // Query per-client stats 206 | let (response_tx, response_rx) = oneshot::channel(); 207 | tx.send(StatEvent::QueryStats { addr, response_tx }) 208 | .await 209 | .unwrap(); 210 | 211 | let stats = response_rx.await.unwrap(); 212 | assert_eq!(stats.active_connections, 1); 213 | assert_eq!(stats.total_requests, 2); 214 | assert_eq!(stats.total_errors, 1); 215 | 216 | // Query global stats 217 | let (response_tx, response_rx) = oneshot::channel(); 218 | tx.send(StatEvent::QueryConnectionStats { response_tx }) 219 | .await 220 | .unwrap(); 221 | 222 | let conn_stats = response_rx.await.unwrap(); 223 | assert_eq!(conn_stats.total_requests, 2); 224 | assert_eq!(conn_stats.total_errors, 1); 225 | 226 | // Cleanup 227 | shutdown_tx.send(true).unwrap(); 228 | manager_handle.await.unwrap(); 229 | } 230 | 231 | #[tokio::test] 232 | async fn test_cleanup_idle_stats() { 233 | let mut config = StatsConfig::default(); 234 |
config.idle_timeout = Duration::from_millis(100); 235 | let (mut manager, tx) = StatsManager::new(config); 236 | let addr = "127.0.0.1:8080".parse().unwrap(); 237 | 238 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 239 | let manager_handle = tokio::spawn(async move { 240 | manager.run(shutdown_rx).await; 241 | }); 242 | 243 | // Add client and disconnect 244 | tx.send(StatEvent::ClientConnected(addr)).await.unwrap(); 245 | tx.send(StatEvent::ClientDisconnected(addr)).await.unwrap(); 246 | 247 | // Wait for idle timeout 248 | sleep(Duration::from_millis(200)).await; 249 | 250 | // Query stats - should be cleaned up 251 | let (response_tx, response_rx) = oneshot::channel(); 252 | tx.send(StatEvent::QueryConnectionStats { response_tx }) 253 | .await 254 | .unwrap(); 255 | 256 | let conn_stats = response_rx.await.unwrap(); 257 | assert_eq!(conn_stats.active_connections, 0); 258 | 259 | shutdown_tx.send(true).unwrap(); 260 | manager_handle.await.unwrap(); 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /src/connection/mod.rs: -------------------------------------------------------------------------------- 1 | mod backoff_strategy; 2 | mod events; 3 | mod guard; 4 | mod manager; 5 | mod stats; 6 | 7 | pub use backoff_strategy::BackoffStrategy; 8 | pub use events::StatEvent; 9 | pub use guard::ConnectionGuard; 10 | pub use manager::Manager as ConnectionManager; 11 | pub use stats::ClientStats; 12 | pub use stats::ConnectionStats; 13 | pub use stats::IpStats; 14 | 15 | #[cfg(test)] 16 | mod tests { 17 | use tokio::{ 18 | sync::{Mutex, mpsc}, 19 | time::sleep, 20 | }; 21 | 22 | use crate::{ 23 | ConnectionError, RelayError, StatsConfig, StatsManager, 24 | config::{BackoffConfig, ConnectionConfig}, 25 | }; 26 | 27 | use super::*; 28 | use std::{ 29 | collections::HashMap, 30 | net::{IpAddr, Ipv4Addr, SocketAddr}, 31 | sync::Arc, 32 | time::Duration, 33 | }; 34 | 35 | #[tokio::test] 36 | async fn
test_connection_limits() { 37 | let config = ConnectionConfig { 38 | max_connections: 2, 39 | per_ip_limits: Some(1), 40 | idle_timeout: Duration::from_secs(60), 41 | error_timeout: Duration::from_secs(300), 42 | connect_timeout: Duration::from_secs(5), 43 | backoff: BackoffConfig::default(), 44 | }; 45 | 46 | let (stats_tx, _) = mpsc::channel(100); 47 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 48 | let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); 49 | 50 | // First connection should succeed 51 | let conn1 = manager.accept_connection(addr1).await; 52 | assert!(conn1.is_ok(), "First connection should succeed"); 53 | 54 | // Second connection from same IP should fail immediately (per-IP limit) 55 | let conn2 = manager.accept_connection(addr1).await; 56 | match conn2 { 57 | Err(RelayError::Connection(ConnectionError::LimitExceeded(msg))) => { 58 | assert!( 59 | msg.contains("127.0.0.1:1234"), 60 | "Wrong IP in error message: {}", 61 | msg 62 | ); 63 | return; // <-- Return here after checking error 64 | } 65 | other => panic!("Expected LimitExceeded error, got: {:?}", other), 66 | } 67 | } 68 | 69 | #[tokio::test] 70 | async fn test_connection_stats_after_limit() { 71 | let config = ConnectionConfig { 72 | max_connections: 1, 73 | per_ip_limits: Some(1), 74 | ..Default::default() 75 | }; 76 | 77 | let stats_config = StatsConfig::default(); 78 | 79 | let (stats_manager, stats_tx) = StatsManager::new(stats_config); 80 | let stats_manager = Arc::new(Mutex::new(stats_manager)); 81 | 82 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 83 | 84 | let stats_handle = tokio::spawn({ 85 | async move { 86 | let mut stats_manager = stats_manager.lock().await; 87 | stats_manager.run(shutdown_rx).await; 88 | } 89 | }); 90 | 91 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 92 | 93 | let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); 94 | 95 | // First connection
succeeds 96 | let conn = manager.accept_connection(addr).await.unwrap(); 97 | 98 | // Second connection fails 99 | let _err = manager.accept_connection(addr).await.unwrap_err(); 100 | 101 | // Check stats 102 | let stats = manager.get_stats().await.unwrap(); 103 | 104 | assert_eq!( 105 | stats.active_connections, 1, 106 | "Should have one active connection" 107 | ); 108 | assert_eq!( 109 | stats.total_connections, 1, 110 | "Should have one total connection" 111 | ); 112 | 113 | // Cleanup 114 | drop(conn); 115 | 116 | shutdown_tx.send(true).unwrap(); 117 | stats_handle.await.unwrap(); 118 | } 119 | 120 | #[tokio::test] 121 | async fn test_idle_connection_cleanup() { 122 | let config = ConnectionConfig { 123 | idle_timeout: Duration::from_millis(100), 124 | ..Default::default() 125 | }; 126 | 127 | let stats_config = StatsConfig { 128 | cleanup_interval: config.idle_timeout, 129 | idle_timeout: config.idle_timeout, 130 | error_timeout: config.error_timeout, 131 | max_events_per_second: 10000, // TODO(aljen): Make configurable 132 | }; 133 | 134 | let (stats_manager, stats_tx) = StatsManager::new(stats_config); 135 | let stats_manager = Arc::new(Mutex::new(stats_manager)); 136 | 137 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 138 | 139 | let stats_handle = tokio::spawn({ 140 | async move { 141 | let mut stats_manager = stats_manager.lock().await; 142 | stats_manager.run(shutdown_rx).await; 143 | } 144 | }); 145 | 146 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 147 | let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); 148 | 149 | // Create a connection 150 | let _conn = manager.accept_connection(addr).await.unwrap(); 151 | 152 | // Verify connection is active 153 | let stats = manager.get_stats().await.unwrap(); 154 | assert_eq!(stats.active_connections, 1); 155 | 156 | // Wait for connection to become idle 157 | sleep(Duration::from_millis(200)).await; 158 | 159 | // Cleanup should work 160 |
assert!(manager.cleanup_idle_connections().await.is_ok()); 161 | 162 | // Verify connection was cleaned up 163 | let stats = manager.get_stats().await.unwrap(); 164 | assert_eq!(stats.active_connections, 0); 165 | 166 | shutdown_tx.send(true).unwrap(); 167 | stats_handle.await.unwrap(); 168 | } 169 | 170 | #[tokio::test] 171 | async fn test_connection_guard_cleanup() { 172 | let config = ConnectionConfig::default(); 173 | 174 | let stats_config = StatsConfig { 175 | cleanup_interval: config.idle_timeout, 176 | idle_timeout: config.idle_timeout, 177 | error_timeout: config.error_timeout, 178 | max_events_per_second: 10000, // TODO(aljen): Make configurable 179 | }; 180 | 181 | let (stats_manager, stats_tx) = StatsManager::new(stats_config); 182 | let stats_manager = Arc::new(Mutex::new(stats_manager)); 183 | 184 | let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); 185 | 186 | let stats_handle = tokio::spawn({ 187 | async move { 188 | let mut stats_manager = stats_manager.lock().await; 189 | stats_manager.run(shutdown_rx).await; 190 | } 191 | }); 192 | 193 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 194 | 195 | let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); 196 | 197 | { 198 | let guard = manager.accept_connection(addr).await.unwrap(); 199 | let stats = manager.get_stats().await.unwrap(); 200 | assert_eq!(stats.active_connections, 1); 201 | 202 | // Guard should clean up when dropped 203 | drop(guard); 204 | } 205 | 206 | // Wait a bit for async cleanup 207 | sleep(Duration::from_millis(50)).await; 208 | 209 | let stats = manager.get_stats().await.unwrap(); 210 | assert_eq!(stats.active_connections, 0); 211 | 212 | shutdown_tx.send(true).unwrap(); 213 | stats_handle.await.unwrap(); 214 | } 215 | 216 | #[tokio::test] 217 | async fn test_backoff_strategy() { 218 | let config = BackoffConfig { 219 | initial_interval: Duration::from_millis(100), 220 | max_interval: Duration::from_secs(1), 221 |
multiplier: 2.0, 222 | max_retries: 3, 223 | }; 224 | 225 | let mut strategy = BackoffStrategy::new(config); 226 | 227 | // The first attempts should return increasing values 228 | assert_eq!(strategy.next_backoff().unwrap().as_millis(), 100); 229 | assert_eq!(strategy.next_backoff().unwrap().as_millis(), 200); 230 | assert_eq!(strategy.next_backoff().unwrap().as_millis(), 400); 231 | 232 | // After exhausting attempts, it should return None 233 | assert!(strategy.next_backoff().is_none()); 234 | 235 | // After reset, it should start from the beginning 236 | strategy.reset(); 237 | assert_eq!(strategy.next_backoff().unwrap().as_millis(), 100); 238 | } 239 | 240 | #[tokio::test] 241 | async fn test_connection_lifecycle() { 242 | let config = ConnectionConfig::default(); 243 | let (stats_tx, mut stats_rx) = mpsc::channel(100); 244 | let manager = Arc::new(ConnectionManager::new(config, stats_tx)); 245 | 246 | // Handle stats events in background 247 | tokio::spawn(async move { 248 | while let Some(event) = stats_rx.recv().await { 249 | match event { 250 | StatEvent::QueryConnectionStats { response_tx } => { 251 | let _ = response_tx.send(ConnectionStats { 252 | total_connections: 1, 253 | active_connections: 1, 254 | total_requests: 0, 255 | total_errors: 0, 256 | requests_per_second: 0.0, 257 | avg_response_time_ms: 0, 258 | per_ip_stats: HashMap::new(), 259 | }); 260 | } 261 | _ => {} 262 | } 263 | } 264 | }); 265 | 266 | let addr = "127.0.0.1:8080".parse().unwrap(); 267 | 268 | // Test connection acceptance 269 | let guard = manager.accept_connection(addr).await.unwrap(); 270 | assert_eq!(manager.get_connection_count(&addr).await, 1); 271 | 272 | // Test statistics 273 | let stats = manager.get_stats().await.unwrap(); 274 | assert_eq!(stats.active_connections, 1); 275 | 276 | // Test connection cleanup 277 | drop(guard); 278 | sleep(Duration::from_millis(100)).await; 279 | assert_eq!(manager.get_connection_count(&addr).await, 0); 280 | } 281 | } 282 |
-------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024-2025 Artur Wyszyński 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/rtu_transport.rs: -------------------------------------------------------------------------------- 1 | use std::time::{Duration, Instant}; 2 | 3 | #[cfg(any(target_os = "linux", target_os = "macos"))] 4 | use std::os::unix::io::AsRawFd; 5 | 6 | #[cfg(any(target_os = "linux", target_os = "macos"))] 7 | use libc::{TIOCM_RTS, TIOCMGET, TIOCMSET}; 8 | 9 | #[cfg(any(target_os = "linux", target_os = "macos"))] 10 | use serialport::TTYPort; 11 | 12 | use serialport::SerialPort; 13 | use tokio::sync::Mutex; 14 | use tracing::{info, trace}; 15 | 16 | use crate::{RtsError, RtsType}; 17 | 18 | use crate::{FrameErrorKind, IoOperation, RelayError, RtuConfig, TransportError}; 19 | /// Serial-port transport configured from `RtuConfig`; on Linux/macOS it also keeps the port's raw fd for direct RTS, flush and close control. 20 | pub struct RtuTransport { 21 | port: Mutex>, 22 | config: RtuConfig, 23 | trace_frames: bool, 24 | // Raw fd of the native TTYPort, used by set_rts, tc_flush and close (Linux/macOS only). 25 | #[cfg(any(target_os = "linux", target_os = "macos"))] 26 | raw_fd: i32, 27 | } 28 | 29 | impl RtuTransport { 30 | pub fn new(config: &RtuConfig, trace_frames: bool) -> Result { 31 | info!("Opening serial port {}", config.serial_port_info()); 32 | // Opens the port with the configured data bits, parity, stop bits and timeout, and no flow control. 33 | // Explicitly open as
TTYPort on Unix 34 | #[cfg(any(target_os = "linux", target_os = "macos"))] 35 | let tty_port: TTYPort = serialport::new(&config.device, config.baud_rate) 36 | .data_bits(config.data_bits.into()) 37 | .parity(config.parity.into()) 38 | .stop_bits(config.stop_bits.into()) 39 | .timeout(config.serial_timeout) 40 | .flow_control(serialport::FlowControl::None) 41 | .open_native() 42 | .map_err(|e| TransportError::Io { 43 | operation: IoOperation::Configure, 44 | details: format!("serial port {}", config.device), 45 | source: std::io::Error::other(e.description), 46 | })?; 47 | 48 | #[cfg(any(target_os = "linux", target_os = "macos"))] 49 | let raw_fd = tty_port.as_raw_fd(); 50 | 51 | #[cfg(any(target_os = "linux", target_os = "macos"))] 52 | let port: Box = Box::new(tty_port); 53 | // NOTE(review): this branch reads config.rtu_device / config.rtu_baud_rate while the branch above reads config.device / config.baud_rate - presumably stale field names; confirm against RtuConfig before building for other targets. 54 | #[cfg(not(any(target_os = "linux", target_os = "macos")))] 55 | let port = serialport::new(&config.rtu_device, config.rtu_baud_rate) 56 | .data_bits(config.data_bits.into()) 57 | .parity(config.parity.into()) 58 | .stop_bits(config.stop_bits.into()) 59 | .timeout(config.serial_timeout) 60 | .flow_control(serialport::FlowControl::None) 61 | .open() 62 | .map_err(|e| TransportError::Io { 63 | operation: IoOperation::Configure, 64 | details: format!("serial port {}", config.rtu_device), 65 | source: std::io::Error::new(std::io::ErrorKind::Other, e.description), 66 | })?; 67 | 68 | Ok(Self { 69 | port: Mutex::new(port), 70 | config: config.clone(), 71 | trace_frames, 72 | #[cfg(any(target_os = "linux", target_os = "macos"))] 73 | raw_fd, 74 | }) 75 | } 76 | /// Clears both serial buffers, then (under cfg(unix)) closes the raw fd with libc::close. 77 | pub async fn close(&self) -> Result<(), TransportError> { 78 | let port = self.port.lock().await; 79 | port.clear(serialport::ClearBuffer::All) 80 | .map_err(|e| TransportError::Io { 81 | operation: IoOperation::Flush, 82 | details: "Failed to clear buffers".to_string(), 83 | source: std::io::Error::other(e.description), 84 | })?; 85 | // NOTE(review): self.port still holds the TTYPort that owns raw_fd and presumably closes it again on drop (double close; the fd number may be reused by then). Also raw_fd only exists on linux/macos, so cfg(unix) here would not compile on other Unix targets - confirm. 86 | #[cfg(unix)] 87 | unsafe { 88 | if libc::close(self.raw_fd) != 0 { 89 | return
Err(TransportError::Io { 90 | operation: IoOperation::Control, 91 | details: "Failed to close serial port".to_string(), 92 | source: std::io::Error::last_os_error(), 93 | }); 94 | } 95 | } 96 | 97 | Ok(()) 98 | } 99 | /// Drives RTS HIGH (`on == true`) or LOW by reading and rewriting the modem-control bits via TIOCMGET/TIOCMSET; emits a trace event when `trace_frames` is set. 100 | fn set_rts(&self, on: bool, trace_frames: bool) -> Result<(), TransportError> { 101 | let rts_span = tracing::info_span!( 102 | "rts_control", 103 | signal = if on { "HIGH" } else { "LOW" }, 104 | delay_us = self.config.rts_delay_us, 105 | ); 106 | let _enter = rts_span.enter(); 107 | // rts_delay_us is only recorded on the span; this function itself does not wait. 108 | unsafe { 109 | let mut flags = 0i32; 110 | // NOTE(review): set_rts has no cfg gate, but raw_fd and TIOCM_* exist only on linux/macos, so as written it would not compile for other targets - confirm. 111 | // Get current flags 112 | if libc::ioctl(self.raw_fd, TIOCMGET, &mut flags) < 0 { 113 | let err = std::io::Error::last_os_error(); 114 | return Err(TransportError::Rts(RtsError::signal(format!( 115 | "Failed to get RTS flags: {} (errno: {})", 116 | err, 117 | err.raw_os_error().unwrap_or(-1) 118 | )))); 119 | } 120 | 121 | // Modify RTS flag 122 | if on { 123 | flags |= TIOCM_RTS; // Set RTS HIGH 124 | } else { 125 | flags &= !TIOCM_RTS; // Set RTS LOW 126 | } 127 | 128 | // Set new flags 129 | if libc::ioctl(self.raw_fd, TIOCMSET, &flags) < 0 { 130 | let err = std::io::Error::last_os_error(); 131 | return Err(TransportError::Rts(RtsError::signal(format!( 132 | "Failed to set RTS flags: {} (errno: {})", 133 | err, 134 | err.raw_os_error().unwrap_or(-1) 135 | )))); 136 | } 137 | 138 | if trace_frames { 139 | trace!("RTS set to {}", if on { "HIGH" } else { "LOW" }); 140 | } 141 | } 142 | 143 | Ok(()) 144 | } 145 | /// Discards unread input and untransmitted output on the serial line (tcflush with TCIOFLUSH). 146 | #[cfg(any(target_os = "linux", target_os = "macos"))] 147 | fn tc_flush(&self) -> Result<(), TransportError> { 148 | unsafe { 149 | if libc::tcflush(self.raw_fd, libc::TCIOFLUSH) != 0 { 150 | return Err(TransportError::Io { 151 | operation: IoOperation::Flush, 152 | details: format!( 153 | "Failed to flush serial port: {}", 154 | std::io::Error::last_os_error() 155 | ), 156 | source: std::io::Error::last_os_error(), // NOTE(review): errno is read a second time after format!; capturing last_os_error() once would guarantee details and source agree 157 | }); 158 | } 159 | } 160 | Ok(()) 161 | } 162 | 163 | pub async fn
transaction( 164 | &self, 165 | request: &[u8], 166 | response: &mut [u8], 167 | ) -> Result { 168 | if request.len() > self.config.max_frame_size as usize { 169 | return Err(RelayError::frame( 170 | FrameErrorKind::TooLong, 171 | format!("Request frame too long: {} bytes", request.len()), 172 | Some(request.to_vec()), 173 | )); 174 | } 175 | 176 | let expected_size = response.len(); 177 | 178 | if self.trace_frames { 179 | trace!("TX: {} bytes: {:02X?}", request.len(), request); 180 | trace!("Expected response size: {} bytes", expected_size); 181 | } 182 | 183 | let transaction_start = Instant::now(); 184 | 185 | let result = tokio::time::timeout(self.config.transaction_timeout, async { 186 | let mut port = self.port.lock().await; 187 | 188 | if self.config.rts_type != RtsType::None { 189 | if self.trace_frames { 190 | trace!("RTS -> TX mode"); 191 | } 192 | 193 | self.set_rts( 194 | self.config.rts_type.to_signal_level(true), 195 | self.trace_frames, 196 | )?; 197 | 198 | if self.config.rts_delay_us > 0 { 199 | if self.trace_frames { 200 | trace!("RTS -> TX mode [waiting]"); 201 | } 202 | tokio::time::sleep(Duration::from_micros(self.config.rts_delay_us)).await; 203 | } 204 | } 205 | 206 | // Write request 207 | if self.trace_frames { 208 | trace!("Writing request"); 209 | } 210 | port.write_all(request).map_err(|e| TransportError::Io { 211 | operation: IoOperation::Write, 212 | details: "Failed to write request".to_string(), 213 | source: e, 214 | })?; 215 | 216 | port.flush().map_err(|e| TransportError::Io { 217 | operation: IoOperation::Flush, 218 | details: "Failed to flush write buffer".to_string(), 219 | source: e, 220 | })?; 221 | 222 | if self.config.rts_type != RtsType::None { 223 | if self.trace_frames { 224 | trace!("RTS -> RX mode"); 225 | } 226 | 227 | self.set_rts( 228 | self.config.rts_type.to_signal_level(false), 229 | self.trace_frames, 230 | )?; 231 | } 232 | 233 | if self.config.flush_after_write { 234 | if self.trace_frames { 235 | trace!("RTS 
-> TX mode [flushing]"); 236 | } 237 | self.tc_flush()?; 238 | } 239 | 240 | if self.config.rts_type != RtsType::None && self.config.rts_delay_us > 0 { 241 | if self.trace_frames { 242 | trace!("RTS -> RX mode [waiting]"); 243 | } 244 | tokio::time::sleep(Duration::from_micros(self.config.rts_delay_us)).await; 245 | } 246 | 247 | // Read response 248 | if self.trace_frames { 249 | trace!("Reading response (expecting {} bytes)", expected_size); 250 | } 251 | 252 | const MAX_TIMEOUTS: u8 = 3; 253 | let mut total_bytes = 0; 254 | let mut consecutive_timeouts = 0; 255 | let inter_byte_timeout = Duration::from_millis(100); 256 | let mut last_read_time = tokio::time::Instant::now(); 257 | 258 | while total_bytes < expected_size { 259 | match port.read(&mut response[total_bytes..]) { 260 | Ok(0) => { 261 | if total_bytes > 0 { 262 | let elapsed = last_read_time.elapsed(); 263 | if elapsed >= inter_byte_timeout { 264 | trace!("Inter-byte timeout reached with {} bytes", total_bytes); 265 | break; 266 | } 267 | } 268 | tokio::task::yield_now().await; 269 | } 270 | Ok(n) => { 271 | if self.trace_frames { 272 | trace!( 273 | "Read {} bytes: {:02X?}", 274 | n, 275 | &response[total_bytes..total_bytes + n] 276 | ); 277 | } 278 | total_bytes += n; 279 | last_read_time = tokio::time::Instant::now(); 280 | consecutive_timeouts = 0; 281 | 282 | if total_bytes >= expected_size { 283 | if self.trace_frames { 284 | trace!("Received complete response"); 285 | } 286 | break; 287 | } 288 | } 289 | Err(e) if e.kind() == std::io::ErrorKind::TimedOut => { 290 | if total_bytes > 0 { 291 | let elapsed = last_read_time.elapsed(); 292 | if elapsed >= inter_byte_timeout { 293 | trace!("Inter-byte timeout reached after timeout"); 294 | break; 295 | } 296 | } 297 | consecutive_timeouts += 1; 298 | if consecutive_timeouts >= MAX_TIMEOUTS { 299 | if total_bytes == 0 { 300 | return Err(TransportError::NoResponse { 301 | attempts: consecutive_timeouts, 302 | elapsed: transaction_start.elapsed(), 303 | 
}); 304 | } 305 | trace!("Max timeouts reached with {} bytes", total_bytes); 306 | break; 307 | } 308 | tokio::task::yield_now().await; 309 | } 310 | Err(e) => { 311 | return Err(TransportError::Io { 312 | operation: IoOperation::Read, 313 | details: "Failed to read response".to_string(), 314 | source: e, 315 | }); 316 | } 317 | } 318 | } 319 | 320 | if total_bytes == 0 { 321 | return Err(TransportError::NoResponse { 322 | attempts: consecutive_timeouts, 323 | elapsed: transaction_start.elapsed(), 324 | }); 325 | } 326 | 327 | // Verify minimum response size 328 | if total_bytes < 3 { 329 | return Err(TransportError::Io { 330 | operation: IoOperation::Read, 331 | details: format!("Response too short: {} bytes", total_bytes), 332 | source: std::io::Error::new( 333 | std::io::ErrorKind::InvalidData, 334 | "Response too short", 335 | ), 336 | }); 337 | } 338 | 339 | if self.trace_frames { 340 | trace!( 341 | "RX: {} bytes: {:02X?}", 342 | total_bytes, 343 | &response[..total_bytes], 344 | ); 345 | } 346 | 347 | Ok(total_bytes) 348 | }) 349 | .await 350 | .map_err(|elapsed| TransportError::Timeout { 351 | elapsed: transaction_start.elapsed(), 352 | limit: self.config.transaction_timeout, 353 | source: elapsed, 354 | })?; 355 | 356 | Ok(result?) 357 | } 358 | } 359 | -------------------------------------------------------------------------------- /src/modbus.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use tracing::{debug, trace}; 4 | 5 | use crate::{FrameErrorKind, RelayError, RtuTransport, errors::FrameError}; 6 | 7 | /// Calculates the CRC16 checksum for Modbus RTU communication using a lookup table for high performance. 8 | /// 9 | /// This function computes the CRC16-Modbus checksum for the provided data frame. 10 | /// It uses a precomputed lookup table to optimize performance by eliminating 11 | /// bitwise calculations within the inner loop. 
12 | /// 13 | /// # Arguments 14 | /// 15 | /// * `data` - A slice of bytes representing the data frame for which the CRC is to be computed. 16 | /// 17 | /// # Returns 18 | /// 19 | /// The computed 16-bit CRC as a `u16` value. 20 | fn calc_crc16(data: &[u8]) -> u16 { 21 | // Precomputed CRC16 lookup table for polynomial 0xA001 (Modbus standard) 22 | const CRC16_TABLE: [u16; 256] = [ 23 | 0x0000, 0xC0C1, 0xC181, 0x0140, 0xC301, 0x03C0, 0x0280, 0xC241, 0xC601, 0x06C0, 0x0780, 24 | 0xC741, 0x0500, 0xC5C1, 0xC481, 0x0440, 0xCC01, 0x0CC0, 0x0D80, 0xCD41, 0x0F00, 0xCFC1, 25 | 0xCE81, 0x0E40, 0x0A00, 0xCAC1, 0xCB81, 0x0B40, 0xC901, 0x09C0, 0x0880, 0xC841, 0xD801, 26 | 0x18C0, 0x1980, 0xD941, 0x1B00, 0xDBC1, 0xDA81, 0x1A40, 0x1E00, 0xDEC1, 0xDF81, 0x1F40, 27 | 0xDD01, 0x1DC0, 0x1C80, 0xDC41, 0x1400, 0xD4C1, 0xD581, 0x1540, 0xD701, 0x17C0, 0x1680, 28 | 0xD641, 0xD201, 0x12C0, 0x1380, 0xD341, 0x1100, 0xD1C1, 0xD081, 0x1040, 0xF001, 0x30C0, 29 | 0x3180, 0xF141, 0x3300, 0xF3C1, 0xF281, 0x3240, 0x3600, 0xF6C1, 0xF781, 0x3740, 0xF501, 30 | 0x35C0, 0x3480, 0xF441, 0x3C00, 0xFCC1, 0xFD81, 0x3D40, 0xFF01, 0x3FC0, 0x3E80, 0xFE41, 31 | 0xFA01, 0x3AC0, 0x3B80, 0xFB41, 0x3900, 0xF9C1, 0xF881, 0x3840, 0x2800, 0xE8C1, 0xE981, 32 | 0x2940, 0xEB01, 0x2BC0, 0x2A80, 0xEA41, 0xEE01, 0x2EC0, 0x2F80, 0xEF41, 0x2D00, 0xEDC1, 33 | 0xEC81, 0x2C40, 0xE401, 0x24C0, 0x2580, 0xE541, 0x2700, 0xE7C1, 0xE681, 0x2640, 0x2200, 34 | 0xE2C1, 0xE381, 0x2340, 0xE101, 0x21C0, 0x2080, 0xE041, 0xA001, 0x60C0, 0x6180, 0xA141, 35 | 0x6300, 0xA3C1, 0xA281, 0x6240, 0x6600, 0xA6C1, 0xA781, 0x6740, 0xA501, 0x65C0, 0x6480, 36 | 0xA441, 0x6C00, 0xACC1, 0xAD81, 0x6D40, 0xAF01, 0x6FC0, 0x6E80, 0xAE41, 0xAA01, 0x6AC0, 37 | 0x6B80, 0xAB41, 0x6900, 0xA9C1, 0xA881, 0x6840, 0x7800, 0xB8C1, 0xB981, 0x7940, 0xBB01, 38 | 0x7BC0, 0x7A80, 0xBA41, 0xBE01, 0x7EC0, 0x7F80, 0xBF41, 0x7D00, 0xBDC1, 0xBC81, 0x7C40, 39 | 0xB401, 0x74C0, 0x7580, 0xB541, 0x7700, 0xB7C1, 0xB681, 0x7640, 0x7200, 0xB2C1, 0xB381, 40 | 0x7340, 0xB101, 0x71C0, 
0x7080, 0xB041, 0x5000, 0x90C1, 0x9181, 0x5140, 0x9301, 0x53C0, 41 | 0x5280, 0x9241, 0x9601, 0x56C0, 0x5780, 0x9741, 0x5500, 0x95C1, 0x9481, 0x5440, 0x9C01, 42 | 0x5CC0, 0x5D80, 0x9D41, 0x5F00, 0x9FC1, 0x9E81, 0x5E40, 0x5A00, 0x9AC1, 0x9B81, 0x5B40, 43 | 0x9901, 0x59C0, 0x5880, 0x9841, 0x8801, 0x48C0, 0x4980, 0x8941, 0x4B00, 0x8BC1, 0x8A81, 44 | 0x4A40, 0x4E00, 0x8EC1, 0x8F81, 0x4F40, 0x8D01, 0x4DC0, 0x4C80, 0x8C41, 0x4400, 0x84C1, 45 | 0x8581, 0x4540, 0x8701, 0x47C0, 0x4680, 0x8641, 0x8201, 0x42C0, 0x4380, 0x8341, 0x4100, 46 | 0x81C1, 0x8081, 0x4040, 47 | ]; 48 | 49 | let mut crc: u16 = 0xFFFF; // Initialize CRC to 0xFFFF as per Modbus standard 50 | 51 | for &byte in data { 52 | // XOR the lower byte of the CRC with the current byte and find the lookup table index 53 | let index = ((crc ^ byte as u16) & 0x00FF) as usize; 54 | // Update the CRC by shifting right and XORing with the table value 55 | crc = (crc >> 8) ^ CRC16_TABLE[index]; 56 | } 57 | 58 | crc 59 | } 60 | 61 | /// Estimates the expected size of a Modbus RTU response frame based on the function code and quantity. 62 | /// 63 | /// # Arguments 64 | /// 65 | /// * `function` - The Modbus function code. 66 | /// * `quantity` - The number of coils or registers involved. 67 | /// 68 | /// # Returns 69 | /// 70 | /// The estimated size of the response frame in bytes. 
71 | pub fn guess_response_size(function: u8, quantity: u16) -> usize { 72 | match function { 73 | 0x01 | 0x02 => { 74 | // Read Coils / Read Discrete Inputs 75 | // Each coil status is one bit; calculate the number of data bytes required 76 | let data_bytes = (quantity as usize).div_ceil(8); // Round up to the nearest whole byte 77 | // Response size: Address(1) + Function(1) + Byte Count(1) + Data + CRC(2) 78 | 1 + 1 + 1 + data_bytes + 2 79 | } 80 | 0x03 | 0x04 => { 81 | // Read Holding Registers / Read Input Registers 82 | // Each register is two bytes 83 | let data_bytes = (quantity as usize) * 2; 84 | // Response size: Address(1) + Function(1) + Byte Count(1) + Data + CRC(2) 85 | 1 + 1 + 1 + data_bytes + 2 86 | } 87 | 0x05 | 0x06 => { 88 | // Write Single Coil / Write Single Register 89 | // Response size: Address(1) + Function(1) + Address(2) + Value(2) + CRC(2) 90 | 1 + 1 + 2 + 2 + 2 91 | } 92 | 0x0F | 0x10 => { 93 | // Write Multiple Coils / Write Multiple Registers 94 | // Response size: Address(1) + Function(1) + Address(2) + Quantity(2) + CRC(2) 95 | 1 + 1 + 2 + 2 + 2 96 | } 97 | _ => { 98 | // Default maximum size for unknown function codes 99 | 256 100 | } 101 | } 102 | } 103 | 104 | /// Extracts a 16-bit unsigned integer from a Modbus RTU request frame starting at the specified index. 105 | /// 106 | /// This function attempts to retrieve two consecutive bytes from the provided request slice, 107 | /// starting at the given index, and converts them into a `u16` value using big-endian byte order. 108 | /// If the request slice is too short to contain the required bytes, it returns a `RelayError` 109 | /// indicating an invalid frame format. 110 | /// 111 | /// # Arguments 112 | /// 113 | /// * `request` - A slice of bytes representing the Modbus RTU request frame. 114 | /// * `start` - The starting index within the request slice from which to extract the `u16` value. 
115 | /// 116 | /// # Returns 117 | /// 118 | /// A `Result` containing the extracted `u16` value if successful, or a `RelayError` if the request 119 | /// slice is too short. 120 | /// 121 | /// # Errors 122 | /// 123 | /// Returns a `RelayError` with `FrameErrorKind::InvalidFormat` if the request slice does not contain 124 | /// enough bytes to extract a `u16` value starting at the specified index. 125 | fn get_u16_from_request(request: &[u8], start: usize) -> Result { 126 | request 127 | .get(start..start + 2) 128 | .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) 129 | .ok_or_else(|| { 130 | RelayError::frame( 131 | FrameErrorKind::InvalidFormat, 132 | "Request too short for register quantity".to_string(), 133 | Some(request.to_vec()), 134 | ) 135 | }) 136 | } 137 | 138 | /// Extracts the quantity of coils or registers from a Modbus RTU request frame based on the function code. 139 | /// 140 | /// This function determines the quantity of coils or registers involved in a Modbus RTU request 141 | /// by examining the function code and extracting the appropriate bytes from the request frame. 142 | /// For read functions (0x01 to 0x04) and write multiple functions (0x0F, 0x10), it extracts a 16-bit 143 | /// unsigned integer from bytes 4 and 5 of the request frame. For write single functions (0x05, 0x06), 144 | /// it returns a fixed quantity of 1. For other function codes, it defaults to a quantity of 1. 145 | /// 146 | /// # Arguments 147 | /// 148 | /// * `function_code` - The Modbus function code. 149 | /// * `request` - A slice of bytes representing the Modbus RTU request frame. 150 | /// 151 | /// # Returns 152 | /// 153 | /// A `Result` containing the extracted quantity as a `u16` value if successful, or a `RelayError` if the request 154 | /// slice is too short or the function code is invalid. 
155 | /// 156 | /// # Errors 157 | /// 158 | /// Returns a `RelayError` with `FrameErrorKind::InvalidFormat` if the request slice does not contain 159 | /// enough bytes to extract the quantity for the specified function code. 160 | pub fn get_quantity(function_code: u8, request: &[u8]) -> Result { 161 | match function_code { 162 | // For read functions (0x01 to 0x04) and write multiple functions (0x0F, 0x10), 163 | // extract the quantity from bytes 4 and 5 of the request frame. 164 | 0x01..=0x04 | 0x0F | 0x10 => get_u16_from_request(request, 4), 165 | 166 | // For write single functions (0x05, 0x06), the quantity is always 1. 167 | 0x05 | 0x06 => Ok(1), 168 | 169 | // For other function codes, default the quantity to 1. 170 | _ => Ok(1), 171 | } 172 | } 173 | 174 | pub struct ModbusProcessor { 175 | transport: Arc, 176 | } 177 | 178 | impl ModbusProcessor { 179 | pub fn new(transport: Arc) -> Self { 180 | Self { transport } 181 | } 182 | 183 | /// Processes a Modbus TCP request by converting it to Modbus RTU, sending it over the transport, 184 | /// and then converting the RTU response back to Modbus TCP format. 185 | /// 186 | /// # Arguments 187 | /// 188 | /// * `transaction_id` - The Modbus TCP transaction ID. 189 | /// * `unit_id` - The Modbus unit ID (slave address). 190 | /// * `pdu` - The Protocol Data Unit from the Modbus TCP request. 191 | /// 192 | /// # Returns 193 | /// 194 | /// A `Result` containing the Modbus TCP response as a vector of bytes, or a `RelayError`. 
195 | pub async fn process_request( 196 | &self, 197 | transaction_id: [u8; 2], 198 | unit_id: u8, 199 | pdu: &[u8], 200 | trace_frames: bool, 201 | ) -> Result, RelayError> { 202 | // Build RTU request frame: [Unit ID][PDU][CRC16] 203 | let mut rtu_request = Vec::with_capacity(1 + pdu.len() + 2); // Unit ID + PDU + CRC16 204 | rtu_request.push(unit_id); 205 | rtu_request.extend_from_slice(pdu); 206 | 207 | // Calculate CRC16 checksum and append to the request 208 | let crc = calc_crc16(&rtu_request); 209 | rtu_request.extend_from_slice(&crc.to_le_bytes()); // Append CRC16 in little-endian 210 | 211 | if trace_frames { 212 | trace!( 213 | "Sending RTU request: unit_id=0x{:02X}, function=0x{:02X}, data={:02X?}, crc=0x{:04X}", 214 | unit_id, 215 | pdu.first().copied().unwrap_or(0), 216 | &pdu[1..], 217 | crc 218 | ); 219 | } 220 | 221 | // Estimate the expected RTU response size 222 | let function_code = pdu.first().copied().unwrap_or(0); 223 | let quantity = get_quantity(function_code, &rtu_request)?; 224 | 225 | let expected_response_size = guess_response_size(function_code, quantity); 226 | 227 | // Allocate buffer for RTU response 228 | let mut rtu_response = vec![0u8; expected_response_size]; 229 | 230 | // Execute RTU transaction 231 | let rtu_len = match self 232 | .transport 233 | .transaction(&rtu_request, &mut rtu_response) 234 | .await 235 | { 236 | Ok(len) => { 237 | if len < 5 { 238 | // Minimum RTU response size: Unit ID(1) + Function(1) + Data(1) + CRC(2) 239 | return Err(RelayError::frame( 240 | FrameErrorKind::TooShort, 241 | format!("RTU response too short: {} bytes", len), 242 | Some(rtu_response[..len].to_vec()), 243 | )); 244 | } 245 | len 246 | } 247 | Err(e) => { 248 | debug!("Transport transaction error: {:?}", e); 249 | 250 | // Prepare Modbus exception response with exception code 0x0B (Gateway Path Unavailable) 251 | let exception_code = 0x0B; 252 | let mut exception_response = Vec::with_capacity(9); 253 | 
exception_response.extend_from_slice(&transaction_id); 254 | exception_response.extend_from_slice(&[0x00, 0x00]); // Protocol ID 255 | exception_response.extend_from_slice(&[0x00, 0x03]); // Length (Unit ID + Function + Exception Code) 256 | exception_response.push(unit_id); 257 | exception_response.push(function_code | 0x80); // Exception function code 258 | exception_response.push(exception_code); 259 | 260 | return Ok(exception_response); 261 | } 262 | }; 263 | 264 | // Truncate the buffer to the actual response length 265 | rtu_response.truncate(rtu_len); 266 | 267 | // Verify the CRC16 checksum of the RTU response 268 | let expected_crc = calc_crc16(&rtu_response[..rtu_len - 2]); 269 | let received_crc = 270 | u16::from_le_bytes([rtu_response[rtu_len - 2], rtu_response[rtu_len - 1]]); 271 | if expected_crc != received_crc { 272 | return Err(RelayError::Frame(FrameError::Crc { 273 | calculated: expected_crc, 274 | received: received_crc, 275 | frame_hex: hex::encode(&rtu_response[..rtu_len - 2]), 276 | })); 277 | } 278 | 279 | // Remove CRC from RTU response 280 | rtu_response.truncate(rtu_len - 2); 281 | 282 | // Verify that the unit ID in the response matches 283 | if rtu_response[0] != unit_id { 284 | return Err(RelayError::frame( 285 | FrameErrorKind::InvalidUnitId, 286 | format!( 287 | "Unexpected unit ID in RTU response: expected=0x{:02X}, received=0x{:02X}", 288 | unit_id, rtu_response[0] 289 | ), 290 | Some(rtu_response.clone()), 291 | )); 292 | } 293 | 294 | // Convert RTU response to Modbus TCP response 295 | let tcp_length = rtu_response.len() as u16; // Length of Unit ID + PDU 296 | let mut tcp_response = Vec::with_capacity(7 + rtu_response.len()); // MBAP Header(7) + PDU 297 | tcp_response.extend_from_slice(&transaction_id); // Transaction ID 298 | tcp_response.extend_from_slice(&[0x00, 0x00]); // Protocol ID 299 | tcp_response.extend_from_slice(&tcp_length.to_be_bytes()); // Length field 300 | tcp_response.extend_from_slice(&rtu_response); // Unit 
ID + PDU 301 | 302 | Ok(tcp_response) 303 | } 304 | } 305 | -------------------------------------------------------------------------------- /src/config/relay.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use config::{Config as ConfigBuilder, ConfigError, Environment, File, FileFormat}; 6 | 7 | use super::{ConnectionConfig, HttpConfig, LoggingConfig, RtuConfig, TcpConfig}; 8 | 9 | /// Main application configuration 10 | #[derive(Default, Debug, Clone, Serialize, Deserialize)] 11 | #[serde(deny_unknown_fields)] 12 | pub struct Config { 13 | /// TCP server configuration 14 | pub tcp: TcpConfig, 15 | 16 | /// RTU client configuration 17 | pub rtu: RtuConfig, 18 | 19 | /// HTTP API configuration 20 | pub http: HttpConfig, 21 | 22 | /// Logging configuration 23 | pub logging: LoggingConfig, 24 | 25 | /// Connection management configuration 26 | pub connection: ConnectionConfig, 27 | } 28 | 29 | impl Config { 30 | /// Default configuration directory 31 | pub const CONFIG_DIR: &'static str = "config"; 32 | 33 | /// Environment variable prefix 34 | const ENV_PREFIX: &'static str = "MODBUS_RELAY"; 35 | 36 | /// Build configuration using the following priority (highest to lowest): 37 | /// 1. Environment variables (MODBUS_RELAY_*) 38 | /// 2. Local configuration file (config/local.yaml) 39 | /// 3. Environment specific file (config/{env}.yaml) 40 | /// 4. Default configuration (config/default.yaml) 41 | /// 5. Built-in defaults 42 | pub fn new() -> Result { 43 | let environment = std::env::var("RUN_MODE").unwrap_or_else(|_| "development".into()); 44 | 45 | // Start with built-in defaults 46 | let defaults = Config::default(); 47 | 48 | let mut builder = ConfigBuilder::builder(); 49 | 50 | // Set defaults for each field manually 51 | builder = builder 52 | // TCP configuration 53 | .set_default("tcp.bind_addr", defaults.tcp.bind_addr)? 
54 | .set_default("tcp.bind_port", defaults.tcp.bind_port)? 55 | // RTU configuration 56 | .set_default("rtu.device", defaults.rtu.device)? 57 | .set_default("rtu.baud_rate", defaults.rtu.baud_rate)? 58 | .set_default("rtu.data_bits", defaults.rtu.data_bits.to_string())? 59 | .set_default("rtu.parity", defaults.rtu.parity.to_string())? 60 | .set_default("rtu.stop_bits", defaults.rtu.stop_bits.to_string())? 61 | .set_default("rtu.flush_after_write", defaults.rtu.flush_after_write)? 62 | .set_default("rtu.rts_type", defaults.rtu.rts_type.to_string())? 63 | .set_default("rtu.rts_delay_us", defaults.rtu.rts_delay_us)? 64 | .set_default( 65 | "rtu.transaction_timeout", 66 | format!("{}s", defaults.rtu.transaction_timeout.as_secs()), 67 | )? 68 | .set_default( 69 | "rtu.serial_timeout", 70 | format!("{}s", defaults.rtu.serial_timeout.as_secs()), 71 | )? 72 | .set_default("rtu.max_frame_size", defaults.rtu.max_frame_size)? 73 | // HTTP configuration 74 | .set_default("http.enabled", defaults.http.enabled)? 75 | .set_default("http.bind_addr", defaults.http.bind_addr)? 76 | .set_default("http.bind_port", defaults.http.bind_port)? 77 | .set_default("http.metrics_enabled", defaults.http.metrics_enabled)? 78 | // Logging configuration 79 | .set_default("logging.log_dir", defaults.logging.log_dir)? 80 | .set_default("logging.trace_frames", defaults.logging.trace_frames)? 81 | .set_default("logging.level", defaults.logging.level)? 82 | .set_default("logging.format", defaults.logging.format)? 83 | .set_default( 84 | "logging.include_location", 85 | defaults.logging.include_location, 86 | )? 87 | .set_default("logging.thread_ids", defaults.logging.thread_ids)? 88 | .set_default("logging.thread_names", defaults.logging.thread_names)? 89 | // Connection configuration 90 | .set_default( 91 | "connection.max_connections", 92 | defaults.connection.max_connections, 93 | )? 
94 | .set_default( 95 | "connection.idle_timeout", 96 | format!("{}s", defaults.connection.idle_timeout.as_secs()), 97 | )? 98 | .set_default( 99 | "connection.connect_timeout", 100 | format!("{}s", defaults.connection.connect_timeout.as_secs()), 101 | )? 102 | .set_default( 103 | "connection.per_ip_limits", 104 | defaults.connection.per_ip_limits, 105 | )? 106 | // Connection backoff configuration 107 | .set_default( 108 | "connection.backoff.initial_interval", 109 | format!( 110 | "{}s", 111 | defaults.connection.backoff.initial_interval.as_secs() 112 | ), 113 | )? 114 | .set_default( 115 | "connection.backoff.max_interval", 116 | format!("{}s", defaults.connection.backoff.max_interval.as_secs()), 117 | )? 118 | .set_default( 119 | "connection.backoff.multiplier", 120 | defaults.connection.backoff.multiplier, 121 | )? 122 | .set_default( 123 | "connection.backoff.max_retries", 124 | defaults.connection.backoff.max_retries, 125 | )?; 126 | 127 | let config = builder 128 | // Load default config file 129 | .add_source(File::new( 130 | &format!("{}/default", Self::CONFIG_DIR), 131 | FileFormat::Yaml, 132 | )) 133 | // Load environment specific config 134 | .add_source( 135 | File::new( 136 | &format!("{}/{}", Self::CONFIG_DIR, environment), 137 | FileFormat::Yaml, 138 | ) 139 | .required(false), 140 | ) 141 | // Load local overrides 142 | .add_source( 143 | File::new(&format!("{}/local", Self::CONFIG_DIR), FileFormat::Yaml).required(false), 144 | ) 145 | // Add environment variables 146 | .add_source( 147 | Environment::with_prefix(Self::ENV_PREFIX) 148 | .prefix_separator("_") 149 | .separator("__") 150 | .try_parsing(true), 151 | ) 152 | .build()?; 153 | 154 | // Deserialize and validate 155 | let config = config.try_deserialize()?; 156 | Self::validate(&config)?; 157 | 158 | Ok(config) 159 | } 160 | 161 | /// Load configuration from a specific file 162 | pub fn from_file(path: PathBuf) -> Result { 163 | let config = ConfigBuilder::builder() 164 | // Load the 
specified config file 165 | .add_source(File::from(path)) 166 | // Add env vars as overrides 167 | .add_source( 168 | Environment::with_prefix(Self::ENV_PREFIX) 169 | .separator("_") 170 | .try_parsing(true), 171 | ) 172 | .build()?; 173 | 174 | let config = config.try_deserialize()?; 175 | Self::validate(&config)?; 176 | 177 | Ok(config) 178 | } 179 | 180 | /// Validate configuration 181 | pub fn validate(config: &Self) -> Result<(), ConfigError> { 182 | // Helper to convert validation errors 183 | fn validation_error(msg: &str) -> ConfigError { 184 | ConfigError::Message(msg.to_string()) 185 | } 186 | 187 | // Validate TCP configuration 188 | if config.tcp.bind_addr.is_empty() { 189 | return Err(validation_error("TCP bind address must not be empty")); 190 | } 191 | if config.tcp.bind_port == 0 { 192 | return Err(validation_error("TCP port must be non-zero")); 193 | } 194 | 195 | // Validate RTU configuration 196 | if config.rtu.device.is_empty() { 197 | return Err(validation_error("RTU device must not be empty")); 198 | } 199 | if config.rtu.baud_rate == 0 { 200 | return Err(validation_error("RTU baud rate must be non-zero")); 201 | } 202 | 203 | // Validate connection configuration 204 | if config.rtu.transaction_timeout.is_zero() { 205 | return Err(validation_error("Transaction timeout must be non-zero")); 206 | } 207 | if config.rtu.serial_timeout.is_zero() { 208 | return Err(validation_error("Serial timeout must be non-zero")); 209 | } 210 | if config.rtu.max_frame_size == 0 { 211 | return Err(validation_error("Max frame size must be non-zero")); 212 | } 213 | 214 | // Validate log level 215 | match config.logging.level.to_lowercase().as_str() { 216 | "error" | "warn" | "info" | "debug" | "trace" => {} 217 | _ => return Err(validation_error("Invalid log level")), 218 | } 219 | 220 | // Validate log format 221 | match config.logging.format.to_lowercase().as_str() { 222 | "pretty" | "json" => {} 223 | _ => return Err(validation_error("Invalid log format")), 224 
| } 225 | 226 | // Validate connection configuration 227 | if config.connection.max_connections == 0 { 228 | return Err(validation_error("Maximum connections must be non-zero")); 229 | } 230 | if config.connection.idle_timeout.is_zero() { 231 | return Err(validation_error("Idle timeout must be non-zero")); 232 | } 233 | if config.connection.connect_timeout.is_zero() { 234 | return Err(validation_error("Connect timeout must be non-zero")); 235 | } 236 | if let Some(limit) = config.connection.per_ip_limits { 237 | if limit == 0 { 238 | return Err(validation_error("Per IP connection limit must be non-zero")); 239 | } 240 | if limit > config.connection.max_connections { 241 | return Err(validation_error( 242 | "Per IP connection limit cannot exceed maximum connections", 243 | )); 244 | } 245 | } 246 | // Validate backoff configuration 247 | if config.connection.backoff.initial_interval.is_zero() { 248 | return Err(validation_error( 249 | "Backoff initial interval must be non-zero", 250 | )); 251 | } 252 | if config.connection.backoff.max_interval.is_zero() { 253 | return Err(validation_error("Backoff max interval must be non-zero")); 254 | } 255 | if config.connection.backoff.multiplier <= 0.0 { 256 | return Err(validation_error("Backoff multiplier must be positive")); 257 | } 258 | if config.connection.backoff.max_retries == 0 { 259 | return Err(validation_error("Backoff max retries must be non-zero")); 260 | } 261 | 262 | Ok(()) 263 | } 264 | } 265 | 266 | #[cfg(test)] 267 | mod tests { 268 | use crate::{DataBits, Parity, RtsType, StopBits}; 269 | 270 | use super::*; 271 | use std::{fs, time::Duration}; 272 | use tempfile::tempdir; 273 | 274 | #[test] 275 | #[serial_test::serial] 276 | fn test_default_config() { 277 | let config = Config::new().unwrap(); 278 | assert_eq!(config.tcp.bind_port, 502); 279 | assert_eq!(config.tcp.bind_addr, "127.0.0.1"); 280 | } 281 | 282 | #[test] 283 | #[serial_test::serial] 284 | fn test_env_override() { 285 | unsafe { 
std::env::set_var("MODBUS_RELAY_TCP__BIND_PORT", "5000") }; 286 | let config = Config::new().unwrap(); 287 | assert_eq!(config.tcp.bind_port, 5000); 288 | unsafe { std::env::remove_var("MODBUS_RELAY_TCP__BIND_PORT") }; 289 | } 290 | 291 | #[test] 292 | #[serial_test::serial] 293 | fn test_file_config() { 294 | let dir = tempdir().unwrap(); 295 | let config_path = dir.path().join("config.yaml"); 296 | 297 | fs::write( 298 | &config_path, 299 | r#" 300 | tcp: 301 | bind_port: 9000 302 | bind_addr: "192.168.1.100" 303 | keep_alive: "60s" 304 | rtu: 305 | device: "/dev/ttyAMA0" 306 | baud_rate: 9600 307 | data_bits: 8 308 | parity: "none" 309 | stop_bits: "one" 310 | flush_after_write: true 311 | rts_type: "down" 312 | rts_delay_us: 3500 313 | transaction_timeout: "5s" 314 | serial_timeout: "1s" 315 | max_frame_size: 256 316 | http: 317 | enabled: false 318 | bind_addr: "192.168.1.100" 319 | bind_port: 9080 320 | metrics_enabled: false 321 | logging: 322 | log_dir: "logs" 323 | trace_frames: false 324 | level: "trace" 325 | format: "pretty" 326 | include_location: false 327 | thread_ids: false 328 | thread_names: true 329 | connection: 330 | max_connections: 100 331 | idle_timeout: "60s" 332 | error_timeout: "300s" 333 | connect_timeout: "5s" 334 | per_ip_limits: 10 335 | backoff: 336 | # Initial wait time 337 | initial_interval: "100ms" 338 | # Maximum wait time 339 | max_interval: "30s" 340 | # Multiplier for each subsequent attempt 341 | multiplier: 2.0 342 | # Maximum number of attempts 343 | max_retries: 5 344 | "#, 345 | ) 346 | .unwrap(); 347 | 348 | let config = Config::from_file(config_path).unwrap(); 349 | assert_eq!(config.tcp.bind_port, 9000); 350 | assert_eq!(config.tcp.bind_addr, "192.168.1.100"); 351 | assert_eq!(config.tcp.keep_alive, Duration::from_secs(60)); 352 | assert_eq!(config.rtu.device, "/dev/ttyAMA0"); 353 | assert_eq!(config.rtu.baud_rate, 9600); 354 | assert_eq!(config.rtu.data_bits, DataBits::new(8).unwrap()); 355 | 
assert_eq!(config.rtu.parity, Parity::None); 356 | assert_eq!(config.rtu.stop_bits, StopBits::One); 357 | assert!(config.rtu.flush_after_write); 358 | assert_eq!(config.rtu.rts_type, RtsType::Down); 359 | assert_eq!(config.rtu.rts_delay_us, 3500); 360 | assert_eq!(config.rtu.transaction_timeout, Duration::from_secs(5)); 361 | assert_eq!(config.rtu.serial_timeout, Duration::from_secs(1)); 362 | assert_eq!(config.rtu.max_frame_size, 256); 363 | assert!(!config.http.enabled); 364 | assert_eq!(config.http.bind_addr, "192.168.1.100"); 365 | assert_eq!(config.http.bind_port, 9080); 366 | assert!(!config.http.metrics_enabled); 367 | assert_eq!(config.logging.log_dir, "logs"); 368 | assert!(!config.logging.trace_frames); 369 | assert_eq!(config.logging.level, "trace"); 370 | assert_eq!(config.logging.format, "pretty"); 371 | assert!(!config.logging.include_location); 372 | assert!(!config.logging.thread_ids); 373 | assert!(config.logging.thread_names); 374 | assert_eq!(config.connection.max_connections, 100); 375 | assert_eq!(config.connection.idle_timeout, Duration::from_secs(60)); 376 | assert_eq!(config.connection.error_timeout, Duration::from_secs(300)); 377 | assert_eq!(config.connection.connect_timeout, Duration::from_secs(5)); 378 | assert_eq!(config.connection.per_ip_limits, Some(10)); 379 | assert_eq!( 380 | config.connection.backoff.initial_interval, 381 | Duration::from_millis(100) 382 | ); 383 | assert_eq!( 384 | config.connection.backoff.max_interval, 385 | Duration::from_secs(30) 386 | ); 387 | assert_eq!(config.connection.backoff.multiplier, 2.0); 388 | assert_eq!(config.connection.backoff.max_retries, 5); 389 | } 390 | 391 | #[test] 392 | #[serial_test::serial] 393 | fn test_validation() { 394 | unsafe { std::env::set_var("MODBUS_RELAY_TCP__BIND_PORT", "0") }; 395 | assert!(Config::new().is_err()); 396 | unsafe { std::env::remove_var("MODBUS_RELAY_TCP__BIND_PORT") }; 397 | } 398 | } 399 | 
-------------------------------------------------------------------------------- /src/modbus_relay.rs: -------------------------------------------------------------------------------- 1 | use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration, time::Instant}; 2 | 3 | use tokio::{ 4 | io::{AsyncReadExt, AsyncWriteExt}, 5 | net::{TcpListener, TcpStream}, 6 | sync::{Mutex, broadcast, mpsc}, 7 | task::{JoinError, JoinHandle}, 8 | time::{sleep, timeout}, 9 | }; 10 | use tracing::{debug, error, info, trace, warn}; 11 | 12 | use crate::{ 13 | ConnectionManager, IoOperation, ModbusProcessor, RelayConfig, StatsConfig, StatsManager, 14 | connection::StatEvent, 15 | errors::{ 16 | ClientErrorKind, ConnectionError, FrameErrorKind, ProtocolErrorKind, RelayError, 17 | TransportError, 18 | }, 19 | http_api::start_http_server, 20 | rtu_transport::RtuTransport, 21 | utils::generate_request_id, 22 | }; 23 | 24 | use socket2::{SockRef, TcpKeepalive}; 25 | 26 | pub struct ModbusRelay { 27 | config: RelayConfig, 28 | transport: Arc, 29 | connection_manager: Arc, 30 | stats_tx: mpsc::Sender, 31 | shutdown: broadcast::Sender<()>, 32 | main_shutdown: tokio::sync::watch::Sender, 33 | stats_manager_shutdown: tokio::sync::watch::Sender, 34 | tasks: Arc>>>, 35 | stats_manager_handle: Mutex>>>, 36 | } 37 | 38 | impl ModbusRelay { 39 | pub fn new(config: RelayConfig) -> Result { 40 | // Validate the config first 41 | RelayConfig::validate(&config)?; 42 | 43 | let transport = RtuTransport::new(&config.rtu, config.logging.trace_frames)?; 44 | 45 | // Create stats manager first 46 | let stats_config = StatsConfig { 47 | cleanup_interval: config.connection.idle_timeout, 48 | idle_timeout: config.connection.idle_timeout, 49 | error_timeout: config.connection.error_timeout, 50 | max_events_per_second: 10000, // TODO(aljen): Make configurable 51 | }; 52 | let (stats_manager, stats_tx) = StatsManager::new(stats_config); 53 | let stats_manager = Arc::new(Mutex::new(stats_manager)); 54 | 55 
| // Initialize connection manager with stats sender 56 | let connection_manager = Arc::new(ConnectionManager::new( 57 | config.connection.clone(), 58 | stats_tx.clone(), 59 | )); 60 | 61 | let (shutdown_tx, _) = broadcast::channel(1); 62 | let (main_shutdown_tx, _) = tokio::sync::watch::channel(false); 63 | let (stats_manager_shutdown_tx, _) = tokio::sync::watch::channel(false); 64 | 65 | // Start stats manager but keep its handle separate from tasks vector 66 | let stats_manager_handle = tokio::spawn({ 67 | let stats_manager = Arc::clone(&stats_manager); 68 | let stats_manager_shutdown_tx = stats_manager_shutdown_tx.subscribe(); 69 | 70 | tokio::spawn(async move { 71 | let mut stats_manager = stats_manager.lock().await; 72 | 73 | stats_manager.run(stats_manager_shutdown_tx).await; 74 | }) 75 | }); 76 | 77 | Ok(Self { 78 | config, 79 | transport: Arc::new(transport), 80 | connection_manager, 81 | stats_tx, 82 | shutdown: shutdown_tx, 83 | main_shutdown: main_shutdown_tx, 84 | stats_manager_shutdown: stats_manager_shutdown_tx, 85 | tasks: Arc::new(Mutex::new(Vec::new())), 86 | stats_manager_handle: Mutex::new(Some(stats_manager_handle)), 87 | }) 88 | } 89 | 90 | fn spawn_task(&self, name: &str, future: F) 91 | where 92 | F: Future + Send + 'static, 93 | { 94 | let task = tokio::spawn(future); 95 | debug!("Spawned {} task: {:?}", name, task.id()); 96 | 97 | let _ = self.tasks.try_lock().map(|mut guard| guard.push(task)); 98 | } 99 | 100 | async fn configure_tcp_stream( 101 | socket: &TcpStream, 102 | keep_alive_duration: Duration, 103 | ) -> Result<(), RelayError> { 104 | // Configure TCP socket using SockRef 105 | let sock_ref = SockRef::from(&socket); 106 | 107 | // Enable TCP keepalive 108 | sock_ref.set_keepalive(true).map_err(|e| { 109 | RelayError::Transport(TransportError::Io { 110 | operation: IoOperation::Configure, 111 | details: "Failed to enable TCP keepalive".to_string(), 112 | source: e, 113 | }) 114 | })?; 115 | 116 | // Set TCP_NODELAY 117 | 
sock_ref.set_tcp_nodelay(true).map_err(|e| { 118 | RelayError::Transport(TransportError::Io { 119 | operation: IoOperation::Configure, 120 | details: "Failed to set TCP_NODELAY".to_string(), 121 | source: e, 122 | }) 123 | })?; 124 | 125 | #[cfg(any(target_os = "linux", target_os = "macos"))] 126 | { 127 | let mut ka = TcpKeepalive::new(); 128 | ka = ka.with_time(keep_alive_duration); 129 | ka = ka.with_interval(keep_alive_duration); 130 | 131 | sock_ref.set_tcp_keepalive(&ka).map_err(|e| { 132 | RelayError::Transport(TransportError::Io { 133 | operation: IoOperation::Configure, 134 | details: "Failed to set TCP keepalive parameters".to_string(), 135 | source: e, 136 | }) 137 | })?; 138 | } 139 | 140 | Ok(()) 141 | } 142 | 143 | pub async fn run(self: Arc) -> Result<(), RelayError> { 144 | // Start TCP server 145 | let tcp_server = { 146 | let transport = Arc::clone(&self.transport); 147 | let manager = Arc::clone(&self.connection_manager); 148 | let stats_tx = self.stats_tx.clone(); 149 | let mut rx = self.shutdown.subscribe(); 150 | let config = self.config.clone(); 151 | let keep_alive_duration = self.config.tcp.keep_alive; 152 | let trace_frames = self.config.logging.trace_frames; 153 | 154 | let shutdown_rx = self.shutdown.subscribe(); 155 | 156 | tokio::spawn(async move { 157 | let addr = format!("{}:{}", config.tcp.bind_addr, config.tcp.bind_port); 158 | let listener = TcpListener::bind(&addr).await.map_err(|e| { 159 | RelayError::Transport(TransportError::Io { 160 | operation: IoOperation::Listen, 161 | details: format!("Failed to bind TCP listener to {}", addr), 162 | source: e, 163 | }) 164 | })?; 165 | 166 | info!("MODBUS TCP server listening on {}", addr); 167 | 168 | loop { 169 | tokio::select! 
{ 170 | accept_result = listener.accept() => { 171 | match accept_result { 172 | Ok((socket, peer)) => { 173 | let transport = Arc::clone(&transport); 174 | let manager = Arc::clone(&manager); 175 | let stats_tx = stats_tx.clone(); 176 | let shutdown_rx = shutdown_rx.resubscribe(); 177 | 178 | Self::configure_tcp_stream(&socket, keep_alive_duration) 179 | .await 180 | .map_err(|e| { 181 | error!("Failed to configure TCP stream: {}", e); 182 | }) 183 | .map(|_| { 184 | debug!( 185 | "TCP stream configured with keepalive: {:?}", 186 | keep_alive_duration 187 | ) 188 | }) 189 | .ok(); 190 | 191 | tokio::spawn(async move { 192 | if let Err(e) = handle_client( 193 | socket, 194 | peer, 195 | transport, 196 | manager, 197 | stats_tx, 198 | shutdown_rx, 199 | trace_frames, 200 | ) 201 | .await 202 | { 203 | error!("Client error: {}", e); 204 | } 205 | }); 206 | } 207 | Err(e) => { 208 | error!("Failed to accept connection: {}", e); 209 | } 210 | } 211 | } 212 | _ = rx.recv() => { 213 | info!("MODBUS TCP server shutting down"); 214 | break; 215 | } 216 | } 217 | } 218 | 219 | info!("MODBUS TCP server shutdown complete"); 220 | 221 | Ok::<_, RelayError>(()) 222 | }) 223 | }; 224 | 225 | self.spawn_task("tcp_server", async move { 226 | if let Err(e) = tcp_server.await { 227 | error!("TCP server task failed: {}", e); 228 | } 229 | }); 230 | 231 | // Start HTTP server if enabled 232 | if self.config.http.enabled { 233 | let http_server = start_http_server( 234 | self.config.http.bind_addr.clone(), 235 | self.config.http.bind_port, 236 | self.connection_manager.clone(), 237 | self.shutdown.subscribe(), 238 | ); 239 | 240 | self.spawn_task("http", async move { 241 | if let Err(e) = http_server.await { 242 | error!("HTTP server error: {}", e) 243 | } 244 | }); 245 | } 246 | 247 | // Start a task to clean up idle connections 248 | let manager = Arc::clone(&self.connection_manager); 249 | let mut shutdown_rx = self.shutdown.subscribe(); 250 | 251 | self.spawn_task("cleanup", async 
move { 252 | let mut interval = tokio::time::interval(Duration::from_secs(60)); 253 | 254 | loop { 255 | tokio::select! { 256 | _ = interval.tick() => { 257 | if let Err(e) = manager.cleanup_idle_connections().await { 258 | error!("Error during connection cleanup: {}", e); 259 | } 260 | } 261 | _ = shutdown_rx.recv() => { 262 | trace!("Cleanup task received shutdown signal"); 263 | break; 264 | } 265 | } 266 | } 267 | 268 | trace!("Cleanup task exited"); 269 | }); 270 | 271 | // Wait for shutdown signal 272 | let mut shutdown_rx = self.main_shutdown.subscribe(); 273 | 274 | tokio::select! { 275 | _ = shutdown_rx.changed() => { 276 | trace!("Main loop received shutdown signal"); 277 | } 278 | } 279 | 280 | trace!("Main loop exited"); 281 | 282 | Ok(()) 283 | } 284 | 285 | /// Graceful shutdown 286 | pub async fn shutdown(&self) -> Result<(), RelayError> { 287 | info!("Initiating graceful shutdown"); 288 | let timeout_duration = Duration::from_secs(5); 289 | 290 | // Send main shutdown signal 291 | let _ = self.main_shutdown.send(true); 292 | 293 | // 1. Log initial state 294 | let stats = self.connection_manager.get_stats().await?; 295 | trace!( 296 | "Current state: {} active connections, {} total requests", 297 | stats.active_connections, stats.total_requests 298 | ); 299 | 300 | // 2. Send shutdown signal to all tasks 301 | trace!("Sending shutdown signal to tasks"); 302 | self.shutdown.send(()).map_err(|e| { 303 | RelayError::Connection(ConnectionError::invalid_state(format!( 304 | "Failed to send shutdown signal: {}", 305 | e 306 | ))) 307 | })?; 308 | 309 | // 3. 
Wait for connections to close with timeout 310 | info!( 311 | "Waiting {}s for connections to close", 312 | timeout_duration.as_secs() 313 | ); 314 | let start = Instant::now(); 315 | while start.elapsed() < timeout_duration { 316 | if let Ok(stats) = self.connection_manager.get_stats().await { 317 | if stats.active_connections == 0 { 318 | info!("All connections closed"); 319 | break; 320 | } 321 | info!( 322 | "Waiting for {} connections to close", 323 | stats.active_connections 324 | ); 325 | } 326 | 327 | trace!("Sleeping for 100ms"); 328 | sleep(Duration::from_millis(100)).await; 329 | } 330 | 331 | // Check if we timed out 332 | if start.elapsed() >= timeout_duration { 333 | warn!("Timeout waiting for connections to close, forcing shutdown"); 334 | } 335 | 336 | // 4. Now we can safely close the serial port 337 | info!("Closing serial port"); 338 | if let Err(e) = self.transport.close().await { 339 | error!("Error closing serial port: {}", e); 340 | } 341 | 342 | // 5. Waiting for all tasks to complete 343 | trace!("Waiting for tasks to complete"); 344 | let tasks = { 345 | let mut tasks_guard = self.tasks.lock().await; 346 | tasks_guard.drain(..).collect::>() 347 | }; 348 | 349 | match tokio::time::timeout(timeout_duration, futures::future::join_all(tasks)).await { 350 | Ok(results) => { 351 | let mut failed = 0; 352 | for (i, result) in results.into_iter().enumerate() { 353 | if result.is_err() { 354 | error!("Task {} failed during shutdown: {}", i, result.unwrap_err()); 355 | failed += 1; 356 | } 357 | } 358 | if failed > 0 { 359 | error!("{} tasks failed during shutdown", failed); 360 | } else { 361 | info!("All tasks completed successfully"); 362 | } 363 | } 364 | Err(_) => { 365 | error!( 366 | "Timeout waiting for tasks to complete after {:?}", 367 | timeout_duration 368 | ); 369 | } 370 | } 371 | 372 | let handle = { 373 | let mut guard = self.stats_manager_handle.lock().await; 374 | guard.take() 375 | }; 376 | 377 | // 6. 
Wait for stats manager to complete 378 | let _ = self.stats_manager_shutdown.send(true); 379 | 380 | if let Some(handle) = handle { 381 | match handle.await { 382 | Ok(Ok(())) => {} 383 | Ok(Err(e)) => { 384 | error!( 385 | "Stats manager failed to shutdown cleanly: inner error = {}", 386 | e 387 | ); 388 | } 389 | Err(e) => { 390 | error!( 391 | "Stats manager failed to shutdown cleanly: join error = {}", 392 | e 393 | ); 394 | } 395 | } 396 | } 397 | 398 | info!("Shutdown complete"); 399 | Ok(()) 400 | } 401 | } 402 | 403 | async fn read_frame( 404 | reader: &mut tokio::net::tcp::ReadHalf<'_>, 405 | peer_addr: &SocketAddr, 406 | trace_frames: bool, 407 | ) -> Result<(Vec, [u8; 2]), RelayError> { 408 | let mut tcp_buf = vec![0u8; 256]; 409 | 410 | // Read TCP request with timeout 411 | let n = match timeout(Duration::from_secs(60), reader.read(&mut tcp_buf)).await { 412 | Ok(Ok(0)) => { 413 | return Err(RelayError::Connection(ConnectionError::Disconnected)); 414 | } 415 | Ok(Ok(n)) => { 416 | if n < 7 { 417 | return Err(RelayError::frame( 418 | FrameErrorKind::TooShort, 419 | format!("Frame too short: {} bytes", n), 420 | Some(tcp_buf[..n].to_vec()), 421 | )); 422 | } 423 | n 424 | } 425 | Ok(Err(e)) => { 426 | return Err(RelayError::Connection(ConnectionError::InvalidState( 427 | format!("Connection lost: {}", e), 428 | ))); 429 | } 430 | Err(_) => { 431 | return Err(RelayError::Connection(ConnectionError::Timeout( 432 | "Read operation timed out".to_string(), 433 | ))); 434 | } 435 | }; 436 | 437 | if trace_frames { 438 | trace!( 439 | "Received TCP frame from {}: {:02X?}", 440 | peer_addr, 441 | &tcp_buf[..n] 442 | ); 443 | } 444 | 445 | // Validate MBAP header 446 | let transaction_id = [tcp_buf[0], tcp_buf[1]]; 447 | let protocol_id = u16::from_be_bytes([tcp_buf[2], tcp_buf[3]]); 448 | if protocol_id != 0 { 449 | return Err(RelayError::protocol( 450 | ProtocolErrorKind::InvalidProtocolId, 451 | format!("Invalid protocol ID: {}", protocol_id), 452 | )); 453 | 
} 454 | 455 | let length = u16::from_be_bytes([tcp_buf[4], tcp_buf[5]]) as usize; 456 | if length > 249 { 457 | return Err(RelayError::frame( 458 | FrameErrorKind::TooLong, 459 | format!("Frame too long: {} bytes", length), 460 | None, 461 | )); 462 | } 463 | 464 | if length + 6 != n { 465 | return Err(RelayError::frame( 466 | FrameErrorKind::InvalidFormat, 467 | format!("Invalid frame length, expected {}, got {}", length + 6, n), 468 | Some(tcp_buf[..n].to_vec()), 469 | )); 470 | } 471 | 472 | Ok((tcp_buf[..n].to_vec(), transaction_id)) 473 | } 474 | 475 | async fn process_frame( 476 | modbus: &ModbusProcessor, 477 | frame: &[u8], 478 | transaction_id: [u8; 2], 479 | trace_frames: bool, 480 | ) -> Result, RelayError> { 481 | modbus 482 | .process_request( 483 | transaction_id, 484 | frame[6], // Unit ID 485 | &frame[7..], // PDU 486 | trace_frames, 487 | ) 488 | .await 489 | } 490 | 491 | async fn send_response( 492 | writer: &mut tokio::net::tcp::WriteHalf<'_>, 493 | response: &[u8], 494 | peer_addr: &SocketAddr, 495 | trace_frames: bool, 496 | ) -> Result<(), RelayError> { 497 | if trace_frames { 498 | trace!("Sending TCP response to {}: {:02X?}", peer_addr, response); 499 | } 500 | 501 | // Send TCP response with timeout 502 | match timeout(Duration::from_secs(5), writer.write_all(response)).await { 503 | Ok(Ok(_)) => Ok(()), 504 | Ok(Err(e)) => Err(RelayError::client( 505 | ClientErrorKind::WriteError, 506 | *peer_addr, 507 | format!("Write error: {}", e), 508 | )), 509 | Err(_) => Err(RelayError::client( 510 | ClientErrorKind::Timeout, 511 | *peer_addr, 512 | "Write timeout".to_string(), 513 | )), 514 | } 515 | } 516 | 517 | async fn handle_frame( 518 | reader: &mut tokio::net::tcp::ReadHalf<'_>, 519 | writer: &mut tokio::net::tcp::WriteHalf<'_>, 520 | peer_addr: &SocketAddr, 521 | modbus: &ModbusProcessor, 522 | stats_tx: &mpsc::Sender, 523 | trace_frames: bool, 524 | ) -> Result { 525 | let frame_start = Instant::now(); 526 | 527 | // 1. 
Read frame 528 | let (frame, transaction_id) = match read_frame(reader, peer_addr, trace_frames).await { 529 | Ok((frame, id)) => (frame, id), 530 | Err(RelayError::Connection(ConnectionError::Disconnected)) => { 531 | info!("Client {} disconnected", peer_addr); 532 | return Ok(false); // Signal to break the loop 533 | } 534 | Err(e) => { 535 | stats_tx 536 | .send(StatEvent::RequestProcessed { 537 | addr: *peer_addr, 538 | success: false, 539 | duration_ms: frame_start.elapsed().as_millis() as u64, 540 | }) 541 | .await 542 | .map_err(|e| { 543 | warn!("Failed to send stats event: {}", e); 544 | }) 545 | .ok(); 546 | 547 | return Err(e); 548 | } 549 | }; 550 | 551 | // 2. Process frame 552 | let response = match process_frame(modbus, &frame, transaction_id, trace_frames).await { 553 | Ok(response) => { 554 | // Record successful Modbus request 555 | stats_tx 556 | .send(StatEvent::RequestProcessed { 557 | addr: *peer_addr, 558 | success: true, 559 | duration_ms: frame_start.elapsed().as_millis() as u64, 560 | }) 561 | .await 562 | .map_err(|e| { 563 | warn!("Failed to send stats event: {}", e); 564 | }) 565 | .ok(); 566 | 567 | response 568 | } 569 | Err(e) => { 570 | // Record failed Modbus request 571 | stats_tx 572 | .send(StatEvent::RequestProcessed { 573 | addr: *peer_addr, 574 | success: false, 575 | duration_ms: frame_start.elapsed().as_millis() as u64, 576 | }) 577 | .await 578 | .map_err(|e| { 579 | warn!("Failed to send stats event: {}", e); 580 | }) 581 | .ok(); 582 | 583 | return Err(e); 584 | } 585 | }; 586 | 587 | // 3. 
Send response 588 | if let Err(e) = send_response(writer, &response, peer_addr, trace_frames).await { 589 | stats_tx 590 | .send(StatEvent::RequestProcessed { 591 | addr: *peer_addr, 592 | success: false, 593 | duration_ms: frame_start.elapsed().as_millis() as u64, 594 | }) 595 | .await 596 | .map_err(|e| { 597 | warn!("Failed to send stats event: {}", e); 598 | }) 599 | .ok(); 600 | 601 | return Err(e); 602 | } 603 | 604 | Ok(true) // Continue the loop 605 | } 606 | 607 | async fn handle_client( 608 | mut stream: TcpStream, 609 | peer_addr: SocketAddr, 610 | transport: Arc, 611 | manager: Arc, 612 | stats_tx: mpsc::Sender, 613 | mut shutdown_rx: broadcast::Receiver<()>, 614 | trace_frames: bool, 615 | ) -> Result<(), RelayError> { 616 | // Create connection guard to track this connection 617 | let _guard = manager.accept_connection(peer_addr).await?; 618 | 619 | let request_id = generate_request_id(); 620 | 621 | let client_span = tracing::info_span!( 622 | "client_connection", 623 | %peer_addr, 624 | request_id = %request_id, 625 | protocol = "modbus_tcp" 626 | ); 627 | let _enter = client_span.enter(); 628 | 629 | let addr = stream.peer_addr().map_err(|e| { 630 | RelayError::Transport(TransportError::Io { 631 | operation: IoOperation::Control, 632 | details: "Failed to get peer address".to_string(), 633 | source: e, 634 | }) 635 | })?; 636 | 637 | debug!("New client connected from {}", addr); 638 | 639 | let (mut reader, mut writer) = stream.split(); 640 | let modbus = ModbusProcessor::new(transport); 641 | 642 | loop { 643 | tokio::select! 
{ 644 | result = handle_frame(&mut reader, &mut writer, &peer_addr, &modbus, &stats_tx, trace_frames) => { 645 | match result { 646 | Ok(true) => continue, 647 | Ok(false) => break, // Client disconnected 648 | Err(e) => return Err(e), 649 | } 650 | } 651 | _ = shutdown_rx.recv() => { 652 | info!("Client {} received shutdown signal", peer_addr); 653 | break; 654 | } 655 | } 656 | } 657 | 658 | debug!("Client {} disconnected", peer_addr); 659 | 660 | Ok(()) 661 | } 662 | 663 | #[cfg(test)] 664 | mod tests { 665 | // use super::*; 666 | 667 | // #[tokio::test] 668 | // Disabled for now, needs port mocking 669 | // async fn test_modbus_relay_shutdown() { 670 | // let mut config = RelayConfig::default(); 671 | // config.rtu.device = "/dev/null".to_string(); 672 | // let relay = ModbusRelay::new(config).unwrap(); 673 | 674 | // assert!(relay.shutdown().await.is_ok()); 675 | // } 676 | } 677 | --------------------------------------------------------------------------------