diff --git a/rust/Cargo.lock b/rust/Cargo.lock index e3e8ee8..c3713de 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -170,7 +170,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.53.2", + "windows-targets 0.52.6", ] [[package]] diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3a1c4fc..7fc60b0 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,4 +1,3 @@ - [package] name = "envoy-proxy-dynamic-modules-rust-sdk-examples" version = "0.1.0" @@ -20,4 +19,4 @@ tempfile = "3.16.0" [lib] name = "rust_module" path = "src/lib.rs" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] diff --git a/rust/src/http_zero_copy_regex_waf.rs b/rust/src/http_zero_copy_regex_waf.rs index 7d0f29a..aaf7278 100644 --- a/rust/src/http_zero_copy_regex_waf.rs +++ b/rust/src/http_zero_copy_regex_waf.rs @@ -126,7 +126,7 @@ mod tests { #[test] /// This demonstrates how to write a test without Envoy using a mock provided by the SDK. fn test_filter() { - let mut filter_config = FilterConfig::new("Hello [Ww].+").unwrap(); + let filter_config = FilterConfig::new("Hello [Ww].+").unwrap(); let mut envoy_filter = MockEnvoyHttpFilter::new(); let mut filter: Box> = filter_config.new_http_filter(&mut envoy_filter); diff --git a/rust/src/lib.rs b/rust/src/lib.rs index b10c2e4..d87865d 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,5 +1,43 @@ +//! Envoy Dynamic Modules Rust SDK Examples +//! +//! This crate contains example implementations of Envoy dynamic modules using the Rust SDK. +//! +//! # HTTP Filters +//! +//! The main library exports HTTP filter examples that work with `declare_init_functions!`: +//! - `passthrough` - A minimal filter that passes all data through unchanged. +//! - `access_logger` - Logs request/response information. +//! - `random_auth` - Randomly rejects requests (for testing). +//! - `zero_copy_regex_waf` - Zero-copy regex-based WAF filter. +//! - `header_mutation` - Adds/removes/modifies headers. +//! - `metrics` - Collects request/response metrics. +//! +//! # Network Filters +//! +//! Network filter examples are provided as public modules. To use them, create a separate +//! crate that includes this library and uses `declare_network_filter_init_functions!` with +//! the module's `new_filter_config` function. +//! +//! Available network filters: +//! - [`network_echo`] - Echoes data back to the client. +//! - [`network_rate_limiter`] - Limits concurrent connections. +//! - [`network_protocol_logger`] - Logs protocol information. +//! - [`network_redis`] - Redis RESP protocol parser and command filter. +//! +//! # Listener Filters +//! +//! Listener filter examples are provided as public modules. To use them, create a separate +//! crate that includes this library and uses `declare_listener_filter_init_functions!` with +//! the module's `new_filter_config` function. +//! +//! Available listener filters: +//! - [`listener_ip_allowlist`] - IP allowlist/blocklist filter. +//! - [`listener_tls_detector`] - TLS protocol detection filter. +//! - [`listener_sni_router`] - SNI-based routing filter. + use envoy_proxy_dynamic_modules_rust_sdk::*; +// HTTP filter examples. mod http_access_logger; mod http_header_mutation; mod http_metrics; @@ -7,6 +45,21 @@ mod http_passthrough; mod http_random_auth; mod http_zero_copy_regex_waf; +// Network filter examples. +// These modules can be used to create standalone network filter cdylibs. +// See each module's documentation for usage instructions. +pub mod network_echo; +pub mod network_protocol_logger; +pub mod network_rate_limiter; +pub mod network_redis; + +// Listener filter examples. +// These modules can be used to create standalone listener filter cdylibs. +// See each module's documentation for usage instructions. +pub mod listener_ip_allowlist; +pub mod listener_sni_router; +pub mod listener_tls_detector; + declare_init_functions!(init, new_http_filter_config_fn); /// This implements the [`envoy_proxy_dynamic_modules_rust_sdk::ProgramInitFunction`]. diff --git a/rust/src/listener_ip_allowlist.rs b/rust/src/listener_ip_allowlist.rs new file mode 100644 index 0000000..4ca1736 --- /dev/null +++ b/rust/src/listener_ip_allowlist.rs @@ -0,0 +1,372 @@ +//! An IP allowlist/blocklist filter for connection-level access control. +//! +//! This filter demonstrates: +//! 1. Inspecting connection addresses before the connection is established. +//! 2. IP allowlist/blocklist rules with CIDR support. +//! 3. Working with IPv4 and IPv6 addresses. +//! +//! Configuration format (JSON): +//! ```json +//! { +//! "mode": "allowlist", +//! "addresses": ["192.168.1.0/24", "10.0.0.1"], +//! "log_blocked": true +//! } +//! ``` +//! +//! Modes: +//! - "allowlist": Only allow connections from listed addresses (block all others). +//! - "blocklist": Block connections from listed addresses (allow all others). +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_listener_filter_init_functions!(init, listener_ip_allowlist::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; +use serde::{Deserialize, Serialize}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +/// Filter mode. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum FilterMode { + Allowlist, + Blocklist, +} + +/// Configuration data parsed from the filter config JSON. +#[derive(Serialize, Deserialize, Debug, Clone)] +struct IpAllowlistConfigData { + /// Filter mode: allowlist or blocklist. + mode: FilterMode, + /// List of IP addresses or CIDR ranges. + addresses: Vec, + /// Whether to log blocked connections. + #[serde(default = "default_true")] + log_blocked: bool, +} + +fn default_true() -> bool { + true +} + +/// Parsed IP rule - either a single IP or a CIDR range. +#[derive(Debug, Clone)] +pub enum IpRule { + Single(IpAddr), + CidrV4 { network: u32, prefix_len: u8 }, + CidrV6 { network: u128, prefix_len: u8 }, +} + +impl IpRule { + /// Parse an IP rule from a string. + pub fn parse(s: &str) -> Option { + if let Some((ip_str, prefix_str)) = s.split_once('/') { + // CIDR notation. + let prefix_len: u8 = prefix_str.parse().ok()?; + + if let Ok(ip) = ip_str.parse::() { + if prefix_len > 32 { + return None; + } + let network = u32::from(ip); + return Some(IpRule::CidrV4 { + network, + prefix_len, + }); + } + + if let Ok(ip) = ip_str.parse::() { + if prefix_len > 128 { + return None; + } + let network = u128::from(ip); + return Some(IpRule::CidrV6 { + network, + prefix_len, + }); + } + + None + } else { + // Single IP address. + s.parse::().ok().map(IpRule::Single) + } + } + + /// Check if an IP address matches this rule. + pub fn matches(&self, ip: &IpAddr) -> bool { + match (self, ip) { + (IpRule::Single(rule_ip), addr) => rule_ip == addr, + ( + IpRule::CidrV4 { + network, + prefix_len, + }, + IpAddr::V4(addr), + ) => { + let addr_bits = u32::from(*addr); + let mask = if *prefix_len == 0 { + 0 + } else { + !0u32 << (32 - prefix_len) + }; + (addr_bits & mask) == (network & mask) + } + ( + IpRule::CidrV6 { + network, + prefix_len, + }, + IpAddr::V6(addr), + ) => { + let addr_bits = u128::from(*addr); + let mask = if *prefix_len == 0 { + 0 + } else { + !0u128 << (128 - prefix_len) + }; + (addr_bits & mask) == (network & mask) + } + _ => false, // IPv4 rule vs IPv6 address or vice versa. + } + } +} + +/// The filter configuration. +pub struct IpAllowlistFilterConfig { + mode: FilterMode, + rules: Vec, + log_blocked: bool, + allowed_connections: EnvoyCounterId, + blocked_connections: EnvoyCounterId, +} + +/// Creates a new IP allowlist filter configuration. +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + let config_data: IpAllowlistConfigData = match serde_json::from_slice(config) { + Ok(cfg) => cfg, + Err(err) => { + eprintln!("Error parsing IP allowlist config: {err}"); + return None; + } + }; + + // Parse IP rules. + let mut rules = Vec::new(); + for addr in &config_data.addresses { + match IpRule::parse(addr) { + Some(rule) => rules.push(rule), + None => { + eprintln!("Invalid IP address or CIDR: {addr}"); + return None; + } + } + } + + if rules.is_empty() { + eprintln!("At least one IP address is required"); + return None; + } + + let allowed_connections = envoy_filter_config + .define_counter("ip_filter_allowed_connections_total") + .expect("Failed to define allowed_connections counter"); + + let blocked_connections = envoy_filter_config + .define_counter("ip_filter_blocked_connections_total") + .expect("Failed to define blocked_connections counter"); + + Some(Box::new(IpAllowlistFilterConfig { + mode: config_data.mode, + rules, + log_blocked: config_data.log_blocked, + allowed_connections, + blocked_connections, + })) +} + +impl ListenerFilterConfig for IpAllowlistFilterConfig { + fn new_listener_filter(&self, _envoy: &mut ELF) -> Box> { + Box::new(IpAllowlistFilter { + mode: self.mode.clone(), + rules: self.rules.clone(), + log_blocked: self.log_blocked, + allowed_connections: self.allowed_connections, + blocked_connections: self.blocked_connections, + }) + } +} + +/// The IP allowlist filter. +struct IpAllowlistFilter { + mode: FilterMode, + rules: Vec, + log_blocked: bool, + allowed_connections: EnvoyCounterId, + blocked_connections: EnvoyCounterId, +} + +impl IpAllowlistFilter { + /// Check if an IP address matches any of the rules. + fn matches_any_rule(&self, ip: &IpAddr) -> bool { + self.rules.iter().any(|rule| rule.matches(ip)) + } + + /// Determine if a connection should be allowed based on mode and rules. + fn should_allow(&self, ip: &IpAddr) -> bool { + let matches = self.matches_any_rule(ip); + match self.mode { + FilterMode::Allowlist => matches, // Allow only if in list. + FilterMode::Blocklist => !matches, // Allow only if NOT in list. + } + } +} + +impl ListenerFilter for IpAllowlistFilter { + fn on_accept( + &mut self, + envoy_filter: &mut ELF, + ) -> abi::envoy_dynamic_module_type_on_listener_filter_status { + // Get the remote address. + let (addr_str, port) = match envoy_filter.get_remote_address() { + Some(addr) => addr, + None => { + // If we can't get the address, allow by default. + envoy_log_warn!("Could not get remote address. Allowing connection."); + return abi::envoy_dynamic_module_type_on_listener_filter_status::Continue; + } + }; + + // Parse the IP address. + let ip: IpAddr = match addr_str.parse() { + Ok(ip) => ip, + Err(_) => { + envoy_log_warn!( + "Could not parse IP address: {}. Allowing connection.", + addr_str + ); + return abi::envoy_dynamic_module_type_on_listener_filter_status::Continue; + } + }; + + if self.should_allow(&ip) { + let _ = envoy_filter.increment_counter(self.allowed_connections, 1); + envoy_log_debug!("Connection from {}:{} allowed", addr_str, port); + abi::envoy_dynamic_module_type_on_listener_filter_status::Continue + } else { + let _ = envoy_filter.increment_counter(self.blocked_connections, 1); + + if self.log_blocked { + let mode_str = match self.mode { + FilterMode::Allowlist => "not in allowlist", + FilterMode::Blocklist => "in blocklist", + }; + envoy_log_warn!( + "Connection from {}:{} blocked ({})", + addr_str, + port, + mode_str + ); + } + + // Close the socket to reject the connection. + envoy_filter.set_downstream_transport_failure_reason("IP address blocked by filter"); + + abi::envoy_dynamic_module_type_on_listener_filter_status::Continue + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ip_allowlist_config_parsing() { + let config = r#"{"mode": "allowlist", "addresses": ["192.168.1.0/24", "10.0.0.1"]}"#; + let config_data: IpAllowlistConfigData = serde_json::from_str(config).unwrap(); + assert_eq!(config_data.mode, FilterMode::Allowlist); + assert_eq!(config_data.addresses.len(), 2); + } + + #[test] + fn test_ip_blocklist_config_parsing() { + let config = r#"{"mode": "blocklist", "addresses": ["10.0.0.0/8"], "log_blocked": false}"#; + let config_data: IpAllowlistConfigData = serde_json::from_str(config).unwrap(); + assert_eq!(config_data.mode, FilterMode::Blocklist); + assert!(!config_data.log_blocked); + } + + #[test] + fn test_ip_rule_parse_single_ipv4() { + let rule = IpRule::parse("192.168.1.1").unwrap(); + assert!(matches!(rule, IpRule::Single(IpAddr::V4(_)))); + } + + #[test] + fn test_ip_rule_parse_cidr_ipv4() { + let rule = IpRule::parse("192.168.1.0/24").unwrap(); + assert!(matches!(rule, IpRule::CidrV4 { .. })); + } + + #[test] + fn test_ip_rule_parse_single_ipv6() { + let rule = IpRule::parse("::1").unwrap(); + assert!(matches!(rule, IpRule::Single(IpAddr::V6(_)))); + } + + #[test] + fn test_ip_rule_parse_cidr_ipv6() { + let rule = IpRule::parse("2001:db8::/32").unwrap(); + assert!(matches!(rule, IpRule::CidrV6 { .. })); + } + + #[test] + fn test_ip_rule_parse_invalid() { + assert!(IpRule::parse("invalid").is_none()); + assert!(IpRule::parse("192.168.1.0/33").is_none()); + assert!(IpRule::parse("::1/129").is_none()); + } + + #[test] + fn test_ip_rule_matches_single() { + let rule = IpRule::parse("192.168.1.1").unwrap(); + let ip: IpAddr = "192.168.1.1".parse().unwrap(); + let other_ip: IpAddr = "192.168.1.2".parse().unwrap(); + + assert!(rule.matches(&ip)); + assert!(!rule.matches(&other_ip)); + } + + #[test] + fn test_ip_rule_matches_cidr() { + let rule = IpRule::parse("192.168.1.0/24").unwrap(); + let ip_in_range: IpAddr = "192.168.1.100".parse().unwrap(); + let ip_out_of_range: IpAddr = "192.168.2.1".parse().unwrap(); + + assert!(rule.matches(&ip_in_range)); + assert!(!rule.matches(&ip_out_of_range)); + } + + #[test] + fn test_cidr_edge_cases() { + // /0 should match everything. + let rule_v4 = IpRule::parse("0.0.0.0/0").unwrap(); + let any_ip: IpAddr = "1.2.3.4".parse().unwrap(); + assert!(rule_v4.matches(&any_ip)); + + // /32 should match only exact IP. + let rule_exact = IpRule::parse("192.168.1.1/32").unwrap(); + let exact_ip: IpAddr = "192.168.1.1".parse().unwrap(); + let other_ip: IpAddr = "192.168.1.2".parse().unwrap(); + assert!(rule_exact.matches(&exact_ip)); + assert!(!rule_exact.matches(&other_ip)); + } +} diff --git a/rust/src/listener_sni_router.rs b/rust/src/listener_sni_router.rs new file mode 100644 index 0000000..76d96d5 --- /dev/null +++ b/rust/src/listener_sni_router.rs @@ -0,0 +1,383 @@ +//! An SNI-based routing filter that extracts SNI for filter chain matching. +//! +//! This filter demonstrates: +//! 1. Parsing TLS Client Hello to extract SNI. +//! 2. Domain to cluster mapping with wildcard support. +//! 3. Handling connections with and without SNI. +//! +//! Configuration format (JSON): +//! ```json +//! { +//! "default_server_name": "default.example.com", +//! "domain_mappings": { +//! "api.example.com": "api-cluster", +//! "*.example.com": "wildcard-cluster" +//! }, +//! "reject_unknown": false +//! } +//! ``` +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_listener_filter_init_functions!(init, listener_sni_router::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// TLS constants - used by SNI extraction logic and tests. +#[allow(dead_code)] +const TLS_CONTENT_TYPE_HANDSHAKE: u8 = 0x16; +#[allow(dead_code)] +const TLS_HANDSHAKE_CLIENT_HELLO: u8 = 0x01; +#[allow(dead_code)] +const TLS_EXT_SERVER_NAME: u16 = 0x0000; +#[allow(dead_code)] +const SNI_NAME_TYPE_HOSTNAME: u8 = 0x00; + +/// Configuration data parsed from the filter config JSON. +#[derive(Serialize, Deserialize, Debug, Clone)] +struct SniRouterConfigData { + #[serde(default)] + default_server_name: Option, + #[serde(default)] + domain_mappings: HashMap, + #[serde(default)] + reject_unknown: bool, + #[serde(default = "default_max_read_bytes")] + max_read_bytes: usize, +} + +fn default_max_read_bytes() -> usize { + 1024 +} + +impl Default for SniRouterConfigData { + fn default() -> Self { + SniRouterConfigData { + default_server_name: None, + domain_mappings: HashMap::new(), + reject_unknown: false, + max_read_bytes: default_max_read_bytes(), + } + } +} + +/// Result of SNI routing. +#[derive(Debug, Clone, PartialEq)] +pub enum SniRoutingResult { + Routed { sni: String, cluster: String }, + NoMapping { sni: String }, + NoSni { default: Option }, + NotTls, + Rejected { reason: String }, + NeedMoreData, +} + +/// The filter configuration. +pub struct SniRouterFilterConfig { + default_server_name: Option, + domain_mappings: HashMap, + reject_unknown: bool, + max_read_bytes: usize, + sni_matches: EnvoyCounterId, + sni_misses: EnvoyCounterId, +} + +/// Creates a new SNI router filter configuration. +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + let config_data: SniRouterConfigData = if config.is_empty() { + SniRouterConfigData::default() + } else { + match serde_json::from_slice(config) { + Ok(cfg) => cfg, + Err(err) => { + eprintln!("Error parsing SNI router config: {err}"); + return None; + } + } + }; + + let sni_matches = envoy_filter_config + .define_counter("sni_router_matches_total") + .expect("Failed to define sni_matches counter"); + + let sni_misses = envoy_filter_config + .define_counter("sni_router_misses_total") + .expect("Failed to define sni_misses counter"); + + Some(Box::new(SniRouterFilterConfig { + default_server_name: config_data.default_server_name, + domain_mappings: config_data.domain_mappings, + reject_unknown: config_data.reject_unknown, + max_read_bytes: config_data.max_read_bytes, + sni_matches, + sni_misses, + })) +} + +impl ListenerFilterConfig for SniRouterFilterConfig { + fn new_listener_filter(&self, _envoy: &mut ELF) -> Box> { + Box::new(SniRouterFilter { + default_server_name: self.default_server_name.clone(), + domain_mappings: self.domain_mappings.clone(), + reject_unknown: self.reject_unknown, + max_read_bytes: self.max_read_bytes, + sni_matches: self.sni_matches, + sni_misses: self.sni_misses, + }) + } +} + +/// The SNI router filter. +#[allow(dead_code)] +struct SniRouterFilter { + default_server_name: Option, + domain_mappings: HashMap, + reject_unknown: bool, + max_read_bytes: usize, + sni_matches: EnvoyCounterId, + sni_misses: EnvoyCounterId, +} + +#[allow(dead_code)] +impl SniRouterFilter { + /// Extract SNI from TLS Client Hello. + fn extract_sni(&self, data: &[u8]) -> Option { + if data.len() < 6 || data[0] != TLS_CONTENT_TYPE_HANDSHAKE { + return None; + } + + if data[1] != 0x03 { + return None; + } + + if data[5] != TLS_HANDSHAKE_CLIENT_HELLO { + return None; + } + + if data.len() < 43 { + return None; + } + + let mut offset = 9; + offset += 2; // Skip client version. + offset += 32; // Skip client random. + + if offset >= data.len() { + return None; + } + let session_id_len = data[offset] as usize; + offset += 1 + session_id_len; + + if offset + 2 > data.len() { + return None; + } + let cipher_suites_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2 + cipher_suites_len; + + if offset >= data.len() { + return None; + } + let compression_len = data[offset] as usize; + offset += 1 + compression_len; + + if offset + 2 > data.len() { + return None; + } + let extensions_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2; + + let extensions_end = offset + extensions_len; + while offset + 4 <= extensions_end && offset + 4 <= data.len() { + let ext_type = u16::from_be_bytes([data[offset], data[offset + 1]]); + let ext_len = u16::from_be_bytes([data[offset + 2], data[offset + 3]]) as usize; + offset += 4; + + if offset + ext_len > data.len() { + break; + } + + if ext_type == TLS_EXT_SERVER_NAME { + return self.parse_sni_extension(&data[offset..offset + ext_len]); + } + + offset += ext_len; + } + + None + } + + fn parse_sni_extension(&self, data: &[u8]) -> Option { + if data.len() < 5 { + return None; + } + + let mut offset = 2; + + if data[offset] != SNI_NAME_TYPE_HOSTNAME { + return None; + } + offset += 1; + + if offset + 2 > data.len() { + return None; + } + let name_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2; + + if offset + name_len > data.len() { + return None; + } + + String::from_utf8(data[offset..offset + name_len].to_vec()).ok() + } + + /// Look up the cluster for a given SNI. + pub fn lookup_cluster(&self, sni: &str) -> Option<&String> { + // Exact match first. + if let Some(cluster) = self.domain_mappings.get(sni) { + return Some(cluster); + } + + // Wildcard match. + for (domain, cluster) in &self.domain_mappings { + if domain.starts_with("*.") { + let suffix = &domain[1..]; + if sni.ends_with(suffix) && sni.len() > suffix.len() { + return Some(cluster); + } + } + } + + None + } + + /// Process data and determine routing. + pub fn process(&self, data: &[u8]) -> SniRoutingResult { + if data.len() < 6 { + return SniRoutingResult::NeedMoreData; + } + + if data[0] != TLS_CONTENT_TYPE_HANDSHAKE { + return SniRoutingResult::NotTls; + } + + let bytes_to_check = std::cmp::min(data.len(), self.max_read_bytes); + let sni = self.extract_sni(&data[..bytes_to_check]); + + match sni { + Some(server_name) => { + if let Some(cluster) = self.lookup_cluster(&server_name) { + SniRoutingResult::Routed { + sni: server_name, + cluster: cluster.clone(), + } + } else if self.reject_unknown { + SniRoutingResult::Rejected { + reason: format!("Unknown SNI: {server_name}"), + } + } else { + SniRoutingResult::NoMapping { sni: server_name } + } + } + None => { + if self.reject_unknown && self.default_server_name.is_none() { + SniRoutingResult::Rejected { + reason: "Missing SNI".to_string(), + } + } else { + SniRoutingResult::NoSni { + default: self.default_server_name.clone(), + } + } + } + } + } +} + +impl ListenerFilter for SniRouterFilter { + fn on_accept( + &mut self, + envoy_filter: &mut ELF, + ) -> abi::envoy_dynamic_module_type_on_listener_filter_status { + // SNI routing requires inspecting data, which is done in on_data. + envoy_log_debug!("SNI router filter activated"); + let _ = envoy_filter.increment_counter(self.sni_matches, 0); + abi::envoy_dynamic_module_type_on_listener_filter_status::Continue + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sni_router_config_parsing() { + let config = r#"{ + "default_server_name": "default.example.com", + "domain_mappings": { + "api.example.com": "api-cluster" + }, + "reject_unknown": true + }"#; + let config_data: SniRouterConfigData = serde_json::from_str(config).unwrap(); + assert_eq!( + config_data.default_server_name, + Some("default.example.com".to_string()) + ); + assert!(config_data.reject_unknown); + } + + #[test] + fn test_sni_router_default_config() { + let config_data = SniRouterConfigData::default(); + assert!(config_data.default_server_name.is_none()); + assert!(config_data.domain_mappings.is_empty()); + assert!(!config_data.reject_unknown); + } + + #[test] + fn test_sni_routing_result_variants() { + // Test that the SniRoutingResult enum variants are correct. + let routed = SniRoutingResult::Routed { + sni: "api.example.com".to_string(), + cluster: "api-cluster".to_string(), + }; + assert!(matches!(routed, SniRoutingResult::Routed { .. })); + + let no_mapping = SniRoutingResult::NoMapping { + sni: "unknown.com".to_string(), + }; + assert!(matches!(no_mapping, SniRoutingResult::NoMapping { .. })); + + let no_sni = SniRoutingResult::NoSni { default: None }; + assert!(matches!(no_sni, SniRoutingResult::NoSni { .. })); + + let not_tls = SniRoutingResult::NotTls; + assert_eq!(not_tls, SniRoutingResult::NotTls); + + let need_more = SniRoutingResult::NeedMoreData; + assert_eq!(need_more, SniRoutingResult::NeedMoreData); + } + + #[test] + fn test_wildcard_domain_matching_logic() { + // Test the wildcard matching pattern. + let domain = "*.example.com"; + let suffix = &domain[1..]; // ".example.com" + + let test_sni = "api.example.com"; + assert!(test_sni.ends_with(suffix)); + assert!(test_sni.len() > suffix.len()); + + let bad_sni = "example.com"; // Should not match - needs subdomain. + assert!(bad_sni.ends_with(suffix) == false || bad_sni.len() <= suffix.len()); + } +} diff --git a/rust/src/listener_tls_detector.rs b/rust/src/listener_tls_detector.rs new file mode 100644 index 0000000..51d2010 --- /dev/null +++ b/rust/src/listener_tls_detector.rs @@ -0,0 +1,365 @@ +//! A TLS protocol detection filter for listener-level protocol inspection. +//! +//! This filter demonstrates: +//! 1. Inspecting initial connection bytes to detect TLS. +//! 2. Extracting TLS Client Hello information (SNI, ALPN). +//! 3. Protocol detection for filter chain matching. +//! +//! Configuration format (JSON): +//! ```json +//! { +//! "min_bytes": 5, +//! "max_bytes": 1024, +//! "extract_sni": true, +//! "extract_alpn": true +//! } +//! ``` +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_listener_filter_init_functions!(init, listener_tls_detector::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; +use serde::{Deserialize, Serialize}; + +/// TLS constants - used by detection logic and tests. +#[allow(dead_code)] +const TLS_CONTENT_TYPE_HANDSHAKE: u8 = 0x16; +#[allow(dead_code)] +const TLS_HANDSHAKE_CLIENT_HELLO: u8 = 0x01; +#[allow(dead_code)] +const TLS_EXT_SERVER_NAME: u16 = 0x0000; +#[allow(dead_code)] +const TLS_EXT_ALPN: u16 = 0x0010; +#[allow(dead_code)] +const SNI_NAME_TYPE_HOSTNAME: u8 = 0x00; + +/// Configuration data parsed from the filter config JSON. +#[derive(Serialize, Deserialize, Debug, Clone)] +struct TlsDetectorConfigData { + #[serde(default = "default_min_bytes")] + min_bytes: usize, + #[serde(default = "default_max_bytes")] + max_bytes: usize, + #[serde(default = "default_true")] + extract_sni: bool, + #[serde(default = "default_true")] + extract_alpn: bool, +} + +fn default_min_bytes() -> usize { + 5 +} + +fn default_max_bytes() -> usize { + 1024 +} + +fn default_true() -> bool { + true +} + +impl Default for TlsDetectorConfigData { + fn default() -> Self { + TlsDetectorConfigData { + min_bytes: default_min_bytes(), + max_bytes: default_max_bytes(), + extract_sni: true, + extract_alpn: true, + } + } +} + +/// TLS detection result. +#[derive(Debug, Clone, PartialEq)] +pub enum TlsDetectionResult { + Tls { + sni: Option, + alpn: Vec, + }, + NotTls, + NeedMoreData, +} + +/// The filter configuration. +pub struct TlsDetectorFilterConfig { + min_bytes: usize, + max_bytes: usize, + extract_sni: bool, + extract_alpn: bool, + tls_connections: EnvoyCounterId, + non_tls_connections: EnvoyCounterId, +} + +/// Creates a new TLS detector filter configuration. +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + let config_data: TlsDetectorConfigData = if config.is_empty() { + TlsDetectorConfigData::default() + } else { + match serde_json::from_slice(config) { + Ok(cfg) => cfg, + Err(err) => { + eprintln!("Error parsing TLS detector config: {err}"); + return None; + } + } + }; + + let tls_connections = envoy_filter_config + .define_counter("tls_detector_tls_connections_total") + .expect("Failed to define tls_connections counter"); + + let non_tls_connections = envoy_filter_config + .define_counter("tls_detector_non_tls_connections_total") + .expect("Failed to define non_tls_connections counter"); + + Some(Box::new(TlsDetectorFilterConfig { + min_bytes: config_data.min_bytes, + max_bytes: config_data.max_bytes, + extract_sni: config_data.extract_sni, + extract_alpn: config_data.extract_alpn, + tls_connections, + non_tls_connections, + })) +} + +impl ListenerFilterConfig for TlsDetectorFilterConfig { + fn new_listener_filter(&self, _envoy: &mut ELF) -> Box> { + Box::new(TlsDetectorFilter { + min_bytes: self.min_bytes, + max_bytes: self.max_bytes, + extract_sni: self.extract_sni, + extract_alpn: self.extract_alpn, + tls_connections: self.tls_connections, + non_tls_connections: self.non_tls_connections, + }) + } +} + +/// The TLS detector filter. +#[allow(dead_code)] +struct TlsDetectorFilter { + min_bytes: usize, + max_bytes: usize, + extract_sni: bool, + extract_alpn: bool, + tls_connections: EnvoyCounterId, + non_tls_connections: EnvoyCounterId, +} + +#[allow(dead_code)] +impl TlsDetectorFilter { + /// Detect if data is TLS and extract SNI/ALPN. + fn detect(&self, data: &[u8]) -> TlsDetectionResult { + if data.len() < self.min_bytes { + return TlsDetectionResult::NeedMoreData; + } + + // Check for TLS record header. + if data.len() < 6 || data[0] != TLS_CONTENT_TYPE_HANDSHAKE { + return TlsDetectionResult::NotTls; + } + + // Check TLS version. + if data[1] != 0x03 || data[2] > 0x03 { + return TlsDetectionResult::NotTls; + } + + // Check if this is a Client Hello. + if data.len() <= 5 || data[5] != TLS_HANDSHAKE_CLIENT_HELLO { + return TlsDetectionResult::Tls { + sni: None, + alpn: Vec::new(), + }; + } + + // Parse Client Hello. + let bytes_to_read = std::cmp::min(data.len(), self.max_bytes); + let (sni, alpn) = self.parse_client_hello(&data[..bytes_to_read]); + + TlsDetectionResult::Tls { sni, alpn } + } + + /// Parse TLS Client Hello and extract SNI and ALPN. + fn parse_client_hello(&self, data: &[u8]) -> (Option, Vec) { + let mut sni = None; + let mut alpn_protocols = Vec::new(); + + if data.len() < 43 { + return (sni, alpn_protocols); + } + + let mut offset = 9; // Skip TLS record header + handshake header. + offset += 2; // Skip client version. + offset += 32; // Skip client random. + + if offset >= data.len() { + return (sni, alpn_protocols); + } + let session_id_len = data[offset] as usize; + offset += 1 + session_id_len; + + if offset + 2 > data.len() { + return (sni, alpn_protocols); + } + let cipher_suites_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2 + cipher_suites_len; + + if offset >= data.len() { + return (sni, alpn_protocols); + } + let compression_len = data[offset] as usize; + offset += 1 + compression_len; + + if offset + 2 > data.len() { + return (sni, alpn_protocols); + } + let extensions_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2; + + let extensions_end = offset + extensions_len; + while offset + 4 <= extensions_end && offset + 4 <= data.len() { + let ext_type = u16::from_be_bytes([data[offset], data[offset + 1]]); + let ext_len = u16::from_be_bytes([data[offset + 2], data[offset + 3]]) as usize; + offset += 4; + + if offset + ext_len > data.len() { + break; + } + + match ext_type { + TLS_EXT_SERVER_NAME if self.extract_sni => { + sni = self.parse_sni_extension(&data[offset..offset + ext_len]); + } + TLS_EXT_ALPN if self.extract_alpn => { + alpn_protocols = self.parse_alpn_extension(&data[offset..offset + ext_len]); + } + _ => {} + } + + offset += ext_len; + } + + (sni, alpn_protocols) + } + + fn parse_sni_extension(&self, data: &[u8]) -> Option { + if data.len() < 5 { + return None; + } + + let mut offset = 2; // Skip list length. + + if data[offset] != SNI_NAME_TYPE_HOSTNAME { + return None; + } + offset += 1; + + if offset + 2 > data.len() { + return None; + } + let name_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2; + + if offset + name_len > data.len() { + return None; + } + + String::from_utf8(data[offset..offset + name_len].to_vec()).ok() + } + + fn parse_alpn_extension(&self, data: &[u8]) -> Vec { + let mut protocols = Vec::new(); + + if data.len() < 2 { + return protocols; + } + + let mut offset = 2; // Skip list length. + + while offset < data.len() { + let proto_len = data[offset] as usize; + offset += 1; + + if offset + proto_len > data.len() { + break; + } + + if let Ok(proto) = String::from_utf8(data[offset..offset + proto_len].to_vec()) { + protocols.push(proto); + } + offset += proto_len; + } + + protocols + } +} + +impl ListenerFilter for TlsDetectorFilter { + fn on_accept( + &mut self, + envoy_filter: &mut ELF, + ) -> abi::envoy_dynamic_module_type_on_listener_filter_status { + // For TLS detection, we need to inspect the connection data. + // This requires peeking at the socket buffer. + // In a real implementation, this would use on_data callback. + // For now, we just log and continue. + envoy_log_debug!("TLS detector filter activated"); + + let _ = envoy_filter.increment_counter(self.tls_connections, 0); + + abi::envoy_dynamic_module_type_on_listener_filter_status::Continue + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tls_detector_config_parsing() { + let config = r#"{"min_bytes": 10, "max_bytes": 2048}"#; + let config_data: TlsDetectorConfigData = serde_json::from_str(config).unwrap(); + assert_eq!(config_data.min_bytes, 10); + assert_eq!(config_data.max_bytes, 2048); + } + + #[test] + fn test_tls_detector_default_config() { + let config_data = TlsDetectorConfigData::default(); + assert_eq!(config_data.min_bytes, 5); + assert_eq!(config_data.max_bytes, 1024); + assert!(config_data.extract_sni); + assert!(config_data.extract_alpn); + } + + #[test] + fn test_tls_detection_result_variants() { + // Test that the TlsDetectionResult enum variants are correct. + let tls_result = TlsDetectionResult::Tls { + sni: Some("example.com".to_string()), + alpn: vec!["h2".to_string()], + }; + assert!(matches!(tls_result, TlsDetectionResult::Tls { .. })); + + let not_tls = TlsDetectionResult::NotTls; + assert_eq!(not_tls, TlsDetectionResult::NotTls); + + let need_more = TlsDetectionResult::NeedMoreData; + assert_eq!(need_more, TlsDetectionResult::NeedMoreData); + } + + #[test] + fn test_tls_record_header_detection() { + // TLS records start with content type 0x16 (handshake). + let tls_header: &[u8] = &[0x16, 0x03, 0x01]; + assert_eq!(tls_header[0], 0x16); + assert!(tls_header[1] == 0x03); // TLS major version. + } +} diff --git a/rust/src/network_echo.rs b/rust/src/network_echo.rs new file mode 100644 index 0000000..0c19432 --- /dev/null +++ b/rust/src/network_echo.rs @@ -0,0 +1,198 @@ +//! A simple TCP echo filter that echoes back data received from clients. +//! +//! This filter demonstrates: +//! 1. Basic network filter structure with `NetworkFilterConfig` and `NetworkFilter` traits. +//! 2. Reading from and writing to the connection buffer. +//! 3. Tracking metrics (bytes echoed, active connections). +//! +//! Configuration: +//! The configuration is treated as raw bytes that will be used as a prefix for all echoed data. +//! If empty, data is echoed back without modification. +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_network_filter_init_functions!(init, network_echo::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; + +/// The filter configuration that implements +/// [`envoy_proxy_dynamic_modules_rust_sdk::NetworkFilterConfig`]. +/// +/// This configuration is shared across all connections handled by this filter chain. +pub struct EchoFilterConfig { + /// The prefix to prepend to echoed data. + prefix: Vec, + /// Counter for total bytes echoed. + bytes_echoed: EnvoyCounterId, + /// Gauge for current active connections. + active_connections: EnvoyGaugeId, +} + +/// Creates a new echo filter configuration. +/// Config is treated as the raw prefix bytes to use (UTF-8 string). +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + // Use config bytes directly as the prefix. + let prefix = config.to_vec(); + + let bytes_echoed = envoy_filter_config + .define_counter("echo_bytes_total") + .expect("Failed to define bytes_echoed counter"); + + let active_connections = envoy_filter_config + .define_gauge("echo_active_connections") + .expect("Failed to define active_connections gauge"); + + Some(Box::new(EchoFilterConfig { + prefix, + bytes_echoed, + active_connections, + })) +} + +impl NetworkFilterConfig for EchoFilterConfig { + fn new_network_filter(&self, envoy: &mut ENF) -> Box> { + // Increment active connections when a new filter is created. + let _ = envoy.increase_gauge(self.active_connections, 1); + + Box::new(EchoFilter { + prefix: self.prefix.clone(), + bytes_echoed: self.bytes_echoed, + active_connections: self.active_connections, + total_bytes: 0, + }) + } +} + +/// The echo filter that implements [`envoy_proxy_dynamic_modules_rust_sdk::NetworkFilter`]. +/// +/// This filter echoes back all received data to the client, optionally with a prefix. +struct EchoFilter { + /// The prefix to prepend to echoed data. + prefix: Vec, + /// Counter ID for tracking total bytes echoed. + bytes_echoed: EnvoyCounterId, + /// Gauge ID for tracking active connections. + active_connections: EnvoyGaugeId, + /// Total bytes echoed for this connection. + total_bytes: u64, +} + +impl NetworkFilter for EchoFilter { + fn on_new_connection( + &mut self, + envoy_filter: &mut ENF, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + let (addr, port) = envoy_filter.get_remote_address(); + envoy_log_info!("New echo connection from {}:{}", addr, port); + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_read( + &mut self, + envoy_filter: &mut ENF, + data_length: usize, + _end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + if data_length == 0 { + return abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue; + } + + // Get the read buffer chunks. + let (chunks, _total_size) = envoy_filter.get_read_buffer_chunks(); + + // Collect all data from chunks. + let mut data = Vec::with_capacity(data_length); + for chunk in &chunks { + data.extend_from_slice(chunk.as_slice()); + } + + // Drain the read buffer since we've consumed it. + envoy_filter.drain_read_buffer(data.len()); + + // Prepare the response with optional prefix. + let response = if self.prefix.is_empty() { + data + } else { + let mut response = self.prefix.clone(); + response.extend_from_slice(&data); + response + }; + + // Track bytes echoed. + self.total_bytes += response.len() as u64; + let _ = envoy_filter.increment_counter(self.bytes_echoed, response.len() as u64); + + // Write the response back to the client. + envoy_filter.write(&response, false); + + abi::envoy_dynamic_module_type_on_network_filter_data_status::StopIteration + } + + fn on_event( + &mut self, + envoy_filter: &mut ENF, + event: abi::envoy_dynamic_module_type_network_connection_event, + ) { + match event { + abi::envoy_dynamic_module_type_network_connection_event::RemoteClose + | abi::envoy_dynamic_module_type_network_connection_event::LocalClose => { + let _ = envoy_filter.decrease_gauge(self.active_connections, 1); + envoy_log_info!( + "Echo connection closed. Total bytes echoed: {}", + self.total_bytes + ); + } + abi::envoy_dynamic_module_type_network_connection_event::Connected => { + envoy_log_debug!("Echo connection established"); + } + _ => {} + } + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_echo_response_no_prefix() { + let data = b"hello world"; + let prefix: Vec = vec![]; + + let response = if prefix.is_empty() { + data.to_vec() + } else { + let mut response = prefix.clone(); + response.extend_from_slice(data); + response + }; + + assert_eq!(response, b"hello world".to_vec()); + } + + #[test] + fn test_echo_response_with_prefix() { + let data = b"hello"; + let prefix = b"ECHO: ".to_vec(); + + let response = if prefix.is_empty() { + data.to_vec() + } else { + let mut response = prefix.clone(); + response.extend_from_slice(data); + response + }; + + assert_eq!(response, b"ECHO: hello".to_vec()); + } + + #[test] + fn test_echo_empty_data() { + let data: &[u8] = b""; + assert!(data.is_empty()); + } +} diff --git a/rust/src/network_protocol_logger.rs b/rust/src/network_protocol_logger.rs new file mode 100644 index 0000000..a18845d --- /dev/null +++ b/rust/src/network_protocol_logger.rs @@ -0,0 +1,479 @@ +//! A protocol detection and logging filter for network connections. +//! +//! This filter demonstrates: +//! 1. Detecting protocols from initial connection bytes. +//! 2. Logging protocol information for debugging. +//! 3. Pattern matching on binary data. +//! +//! Configuration format (JSON): +//! ```json +//! { +//! "log_request": true, +//! "log_response": true, +//! "max_log_bytes": 1024, +//! "detect_protocol": true +//! } +//! ``` +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_network_filter_init_functions!(init, network_protocol_logger::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; +use serde::{Deserialize, Serialize}; + +/// Configuration data parsed from the filter config JSON. +#[derive(Serialize, Deserialize, Debug, Clone)] +struct ProtocolLoggerConfigData { + /// Whether to log request (downstream -> upstream) data. + #[serde(default = "default_true")] + log_request: bool, + /// Whether to log response (upstream -> downstream) data. + #[serde(default = "default_true")] + log_response: bool, + /// Maximum number of bytes to log per direction. + #[serde(default = "default_max_log_bytes")] + max_log_bytes: usize, + /// Whether to attempt protocol detection from first bytes. + #[serde(default = "default_true")] + detect_protocol: bool, +} + +fn default_true() -> bool { + true +} + +fn default_max_log_bytes() -> usize { + 1024 +} + +/// Known protocol signatures for detection. +const TLS_SIGNATURE: &[u8] = &[0x16, 0x03]; // TLS record header. +const HTTP_GET: &[u8] = b"GET "; +const HTTP_POST: &[u8] = b"POST "; +const HTTP_PUT: &[u8] = b"PUT "; +const HTTP_DELETE: &[u8] = b"DELETE "; +const HTTP_HEAD: &[u8] = b"HEAD "; +const MYSQL_HANDSHAKE: u8 = 0x0a; // MySQL initial handshake packet. + +/// Detected protocol type. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Protocol { + Unknown, + Tls, + Http, + Redis, + Mysql, +} + +impl Protocol { + fn as_str(&self) -> &'static str { + match self { + Protocol::Unknown => "unknown", + Protocol::Tls => "tls", + Protocol::Http => "http", + Protocol::Redis => "redis", + Protocol::Mysql => "mysql", + } + } +} + +/// The filter configuration. +struct ProtocolLoggerFilterConfig { + log_request: bool, + log_response: bool, + max_log_bytes: usize, + detect_protocol: bool, + connections_logged: EnvoyCounterId, + request_bytes_histogram: EnvoyHistogramId, + response_bytes_histogram: EnvoyHistogramId, +} + +/// Creates a new protocol logger filter configuration. +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + let config_data: ProtocolLoggerConfigData = if config.is_empty() { + ProtocolLoggerConfigData { + log_request: true, + log_response: true, + max_log_bytes: 1024, + detect_protocol: true, + } + } else { + match serde_json::from_slice(config) { + Ok(cfg) => cfg, + Err(err) => { + eprintln!("Error parsing protocol logger config: {err}"); + return None; + } + } + }; + + let connections_logged = envoy_filter_config + .define_counter("protocol_logger_connections_total") + .expect("Failed to define connections_logged counter"); + + let request_bytes_histogram = envoy_filter_config + .define_histogram("protocol_logger_request_bytes") + .expect("Failed to define request_bytes histogram"); + + let response_bytes_histogram = envoy_filter_config + .define_histogram("protocol_logger_response_bytes") + .expect("Failed to define response_bytes histogram"); + + Some(Box::new(ProtocolLoggerFilterConfig { + log_request: config_data.log_request, + log_response: config_data.log_response, + max_log_bytes: config_data.max_log_bytes, + detect_protocol: config_data.detect_protocol, + connections_logged, + request_bytes_histogram, + response_bytes_histogram, + })) +} + +impl NetworkFilterConfig for ProtocolLoggerFilterConfig { + fn new_network_filter(&self, envoy: &mut ENF) -> Box> { + let _ = envoy.increment_counter(self.connections_logged, 1); + + Box::new(ProtocolLoggerFilter { + log_request: self.log_request, + log_response: self.log_response, + max_log_bytes: self.max_log_bytes, + detect_protocol: self.detect_protocol, + request_bytes_histogram: self.request_bytes_histogram, + response_bytes_histogram: self.response_bytes_histogram, + connection_id: envoy.get_connection_id(), + remote_address: envoy.get_remote_address(), + local_address: envoy.get_local_address(), + detected_protocol: Protocol::Unknown, + total_request_bytes: 0, + total_response_bytes: 0, + first_request_logged: false, + first_response_logged: false, + }) + } +} + +/// The protocol logger filter. +struct ProtocolLoggerFilter { + log_request: bool, + log_response: bool, + max_log_bytes: usize, + detect_protocol: bool, + request_bytes_histogram: EnvoyHistogramId, + response_bytes_histogram: EnvoyHistogramId, + connection_id: u64, + remote_address: (String, u32), + local_address: (String, u32), + detected_protocol: Protocol, + total_request_bytes: u64, + total_response_bytes: u64, + first_request_logged: bool, + first_response_logged: bool, +} + +/// Standalone protocol detector - can be tested without SDK dependencies. +pub struct ProtocolDetector; + +impl ProtocolDetector { + /// Detect protocol from the first bytes of data. + pub fn detect(data: &[u8]) -> Protocol { + if data.len() < 2 { + return Protocol::Unknown; + } + + // Check for TLS. + if data.starts_with(TLS_SIGNATURE) { + return Protocol::Tls; + } + + // Check for HTTP methods. + if data.starts_with(HTTP_GET) + || data.starts_with(HTTP_POST) + || data.starts_with(HTTP_PUT) + || data.starts_with(HTTP_DELETE) + || data.starts_with(HTTP_HEAD) + { + return Protocol::Http; + } + + // Check for Redis. + if data[0] == b'*' && data.contains(&b'\r') { + return Protocol::Redis; + } + + // Check for MySQL handshake. + if data.len() >= 5 && data[4] == MYSQL_HANDSHAKE { + return Protocol::Mysql; + } + + Protocol::Unknown + } + + /// Format bytes for logging (hex dump with ASCII). + pub fn format_bytes_for_logging(data: &[u8], max_bytes: usize) -> String { + let truncated = if data.len() > max_bytes { + &data[..max_bytes] + } else { + data + }; + + let hex: String = truncated + .iter() + .map(|b| format!("{:02x}", b)) + .collect::>() + .join(" "); + + let ascii: String = truncated + .iter() + .map(|&b| { + if b.is_ascii_graphic() || b == b' ' { + b as char + } else { + '.' + } + }) + .collect(); + + if data.len() > max_bytes { + format!( + "[{} bytes, showing first {}] HEX: {} ASCII: {}", + data.len(), + max_bytes, + hex, + ascii + ) + } else { + format!("[{} bytes] HEX: {} ASCII: {}", data.len(), hex, ascii) + } + } +} + +impl ProtocolLoggerFilter { + fn detect_protocol_from_bytes(&self, data: &[u8]) -> Protocol { + ProtocolDetector::detect(data) + } + + fn format_bytes_for_logging(&self, data: &[u8]) -> String { + ProtocolDetector::format_bytes_for_logging(data, self.max_log_bytes) + } +} + +impl NetworkFilter for ProtocolLoggerFilter { + fn on_new_connection( + &mut self, + envoy_filter: &mut ENF, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + envoy_log_info!( + "[conn={}] New connection: remote={}:{} local={}:{} ssl={}", + self.connection_id, + self.remote_address.0, + self.remote_address.1, + self.local_address.0, + self.local_address.1, + envoy_filter.is_ssl() + ); + + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_read( + &mut self, + envoy_filter: &mut ENF, + data_length: usize, + end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + self.total_request_bytes += data_length as u64; + + if self.log_request && !self.first_request_logged && data_length > 0 { + let (chunks, _) = envoy_filter.get_read_buffer_chunks(); + + // Collect data for logging. + let mut data = Vec::new(); + for chunk in &chunks { + data.extend_from_slice(chunk.as_slice()); + if data.len() >= self.max_log_bytes { + break; + } + } + + // Detect protocol from first request data. + if self.detect_protocol && self.detected_protocol == Protocol::Unknown { + self.detected_protocol = self.detect_protocol_from_bytes(&data); + + // Store detected protocol in filter state. + let protocol_bytes = self.detected_protocol.as_str().as_bytes(); + envoy_filter.set_filter_state_bytes(b"detected_protocol", protocol_bytes); + + envoy_log_info!( + "[conn={}] Detected protocol: {}", + self.connection_id, + self.detected_protocol.as_str() + ); + } + + let formatted = self.format_bytes_for_logging(&data); + envoy_log_info!( + "[conn={}] REQUEST: {} end_stream={}", + self.connection_id, + formatted, + end_stream + ); + + self.first_request_logged = true; + } + + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_write( + &mut self, + envoy_filter: &mut ENF, + data_length: usize, + end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + self.total_response_bytes += data_length as u64; + + if self.log_response && !self.first_response_logged && data_length > 0 { + let (chunks, _) = envoy_filter.get_write_buffer_chunks(); + + // Collect data for logging. + let mut data = Vec::new(); + for chunk in &chunks { + data.extend_from_slice(chunk.as_slice()); + if data.len() >= self.max_log_bytes { + break; + } + } + + let formatted = self.format_bytes_for_logging(&data); + envoy_log_info!( + "[conn={}] RESPONSE: {} end_stream={}", + self.connection_id, + formatted, + end_stream + ); + + self.first_response_logged = true; + } + + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_event( + &mut self, + envoy_filter: &mut ENF, + event: abi::envoy_dynamic_module_type_network_connection_event, + ) { + match event { + abi::envoy_dynamic_module_type_network_connection_event::RemoteClose => { + envoy_log_info!( + "[conn={}] Connection closed by remote. protocol={} request_bytes={} response_bytes={}", + self.connection_id, + self.detected_protocol.as_str(), + self.total_request_bytes, + self.total_response_bytes + ); + } + abi::envoy_dynamic_module_type_network_connection_event::LocalClose => { + envoy_log_info!( + "[conn={}] Connection closed locally. protocol={} request_bytes={} response_bytes={}", + self.connection_id, + self.detected_protocol.as_str(), + self.total_request_bytes, + self.total_response_bytes + ); + } + abi::envoy_dynamic_module_type_network_connection_event::Connected => { + envoy_log_debug!( + "[conn={}] Upstream connection established", + self.connection_id + ); + } + _ => {} + } + + // Record histograms on connection close. + if matches!( + event, + abi::envoy_dynamic_module_type_network_connection_event::RemoteClose + | abi::envoy_dynamic_module_type_network_connection_event::LocalClose + ) { + let _ = envoy_filter + .record_histogram_value(self.request_bytes_histogram, self.total_request_bytes); + let _ = envoy_filter + .record_histogram_value(self.response_bytes_histogram, self.total_response_bytes); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_protocol_logger_config_parsing() { + let config = r#"{"log_request": true, "log_response": false, "max_log_bytes": 512}"#; + let config_data: ProtocolLoggerConfigData = serde_json::from_str(config).unwrap(); + assert!(config_data.log_request); + assert!(!config_data.log_response); + assert_eq!(config_data.max_log_bytes, 512); + } + + #[test] + fn test_protocol_logger_default_config() { + let config = r#"{}"#; + let config_data: ProtocolLoggerConfigData = serde_json::from_str(config).unwrap(); + assert!(config_data.log_request); + assert!(config_data.log_response); + assert_eq!(config_data.max_log_bytes, 1024); + assert!(config_data.detect_protocol); + } + + #[test] + fn test_protocol_detection_tls() { + let tls_data: &[u8] = &[0x16, 0x03, 0x01, 0x00, 0x05]; + assert_eq!(ProtocolDetector::detect(tls_data), Protocol::Tls); + } + + #[test] + fn test_protocol_detection_http() { + let http_data = b"GET /path HTTP/1.1\r\n"; + assert_eq!(ProtocolDetector::detect(http_data), Protocol::Http); + } + + #[test] + fn test_protocol_detection_redis() { + let redis_data = b"*1\r\n$4\r\nPING\r\n"; + assert_eq!(ProtocolDetector::detect(redis_data), Protocol::Redis); + } + + #[test] + fn test_protocol_detection_unknown() { + let data = b"random data"; + assert_eq!(ProtocolDetector::detect(data), Protocol::Unknown); + } + + #[test] + fn test_format_bytes_for_logging() { + let data = b"hello"; + let formatted = ProtocolDetector::format_bytes_for_logging(data, 1024); + assert!(formatted.contains("5 bytes")); + assert!(formatted.contains("68 65 6c 6c 6f")); // hex for "hello" + } + + #[test] + fn test_protocol_as_str() { + assert_eq!(Protocol::Unknown.as_str(), "unknown"); + assert_eq!(Protocol::Tls.as_str(), "tls"); + assert_eq!(Protocol::Http.as_str(), "http"); + assert_eq!(Protocol::Redis.as_str(), "redis"); + assert_eq!(Protocol::Mysql.as_str(), "mysql"); + } +} diff --git a/rust/src/network_rate_limiter.rs b/rust/src/network_rate_limiter.rs new file mode 100644 index 0000000..943746c --- /dev/null +++ b/rust/src/network_rate_limiter.rs @@ -0,0 +1,294 @@ +//! A simple connection rate limiter for network filters. +//! +//! This filter demonstrates: +//! 1. Connection counting and rate limiting. +//! 2. Shared state across filter instances using atomic counters. +//! 3. Rejecting connections when limits are exceeded. +//! +//! Configuration format (JSON): +//! ```json +//! { +//! "max_connections": 100, +//! "reject_message": "Too many connections" +//! } +//! ``` +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_network_filter_init_functions!(init, network_rate_limiter::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; +use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +/// Configuration data parsed from the filter config JSON. +#[derive(Serialize, Deserialize, Debug, Clone)] +struct RateLimiterConfigData { + /// Maximum number of concurrent connections allowed. + max_connections: u64, + /// Message to send when rejecting connections. + #[serde(default = "default_reject_message")] + reject_message: String, +} + +fn default_reject_message() -> String { + "Connection limit exceeded".to_string() +} + +/// Shared state for tracking active connections across all filter instances. +struct SharedConnectionState { + /// Current number of active connections. + active_connections: AtomicU64, +} + +/// The filter configuration that implements +/// [`envoy_proxy_dynamic_modules_rust_sdk::NetworkFilterConfig`]. +struct RateLimiterFilterConfig { + /// Maximum number of concurrent connections allowed. + max_connections: u64, + /// Message to send when rejecting connections. + reject_message: Vec, + /// Shared state for tracking connections. + shared_state: Arc, + /// Counter for total connections accepted. + connections_accepted: EnvoyCounterId, + /// Counter for total connections rejected. + connections_rejected: EnvoyCounterId, + /// Gauge for current active connections. + active_connections_gauge: EnvoyGaugeId, +} + +/// Creates a new rate limiter filter configuration. +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + let config_data: RateLimiterConfigData = match serde_json::from_slice(config) { + Ok(cfg) => cfg, + Err(err) => { + eprintln!("Error parsing rate limiter config: {err}"); + return None; + } + }; + + if config_data.max_connections == 0 { + eprintln!("max_connections must be greater than 0"); + return None; + } + + let connections_accepted = envoy_filter_config + .define_counter("rate_limiter_connections_accepted_total") + .expect("Failed to define connections_accepted counter"); + + let connections_rejected = envoy_filter_config + .define_counter("rate_limiter_connections_rejected_total") + .expect("Failed to define connections_rejected counter"); + + let active_connections_gauge = envoy_filter_config + .define_gauge("rate_limiter_active_connections") + .expect("Failed to define active_connections gauge"); + + Some(Box::new(RateLimiterFilterConfig { + max_connections: config_data.max_connections, + reject_message: config_data.reject_message.into_bytes(), + shared_state: Arc::new(SharedConnectionState { + active_connections: AtomicU64::new(0), + }), + connections_accepted, + connections_rejected, + active_connections_gauge, + })) +} + +impl NetworkFilterConfig for RateLimiterFilterConfig { + fn new_network_filter(&self, _envoy: &mut ENF) -> Box> { + Box::new(RateLimiterFilter { + max_connections: self.max_connections, + reject_message: self.reject_message.clone(), + shared_state: Arc::clone(&self.shared_state), + connections_accepted: self.connections_accepted, + connections_rejected: self.connections_rejected, + active_connections_gauge: self.active_connections_gauge, + connection_counted: false, + }) + } +} + +/// The rate limiter filter that implements [`envoy_proxy_dynamic_modules_rust_sdk::NetworkFilter`]. +struct RateLimiterFilter { + /// Maximum number of concurrent connections allowed. + max_connections: u64, + /// Message to send when rejecting connections. + reject_message: Vec, + /// Shared state for tracking connections. + shared_state: Arc, + /// Counter ID for connections accepted. + connections_accepted: EnvoyCounterId, + /// Counter ID for connections rejected. + connections_rejected: EnvoyCounterId, + /// Gauge ID for active connections. + active_connections_gauge: EnvoyGaugeId, + /// Whether this connection was counted in the active connections. + connection_counted: bool, +} + +impl NetworkFilter for RateLimiterFilter { + fn on_new_connection( + &mut self, + envoy_filter: &mut ENF, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + // Try to increment the connection count. + let current = self + .shared_state + .active_connections + .fetch_add(1, Ordering::SeqCst); + + if current >= self.max_connections { + // Over limit, decrement and reject. + self.shared_state + .active_connections + .fetch_sub(1, Ordering::SeqCst); + + let _ = envoy_filter.increment_counter(self.connections_rejected, 1); + + let (addr, port) = envoy_filter.get_remote_address(); + envoy_log_warn!( + "Connection from {}:{} rejected. Current connections: {}, limit: {}", + addr, + port, + current, + self.max_connections + ); + + // Send rejection message and close the connection. + envoy_filter.write(&self.reject_message, true); + envoy_filter + .close(abi::envoy_dynamic_module_type_network_connection_close_type::FlushWrite); + + return abi::envoy_dynamic_module_type_on_network_filter_data_status::StopIteration; + } + + // Connection accepted. + self.connection_counted = true; + let _ = envoy_filter.increment_counter(self.connections_accepted, 1); + let _ = envoy_filter.set_gauge( + self.active_connections_gauge, + self.shared_state.active_connections.load(Ordering::SeqCst), + ); + + let (addr, port) = envoy_filter.get_remote_address(); + envoy_log_info!( + "Connection from {}:{} accepted. Active connections: {}", + addr, + port, + current + 1 + ); + + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_read( + &mut self, + _envoy_filter: &mut ENF, + _data_length: usize, + _end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + // Pass through all data. + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_write( + &mut self, + _envoy_filter: &mut ENF, + _data_length: usize, + _end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + // Pass through all data. + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_event( + &mut self, + envoy_filter: &mut ENF, + event: abi::envoy_dynamic_module_type_network_connection_event, + ) { + match event { + abi::envoy_dynamic_module_type_network_connection_event::RemoteClose + | abi::envoy_dynamic_module_type_network_connection_event::LocalClose => { + if self.connection_counted { + let previous = self + .shared_state + .active_connections + .fetch_sub(1, Ordering::SeqCst); + let _ = envoy_filter + .set_gauge(self.active_connections_gauge, previous.saturating_sub(1)); + envoy_log_debug!("Connection closed. Active connections: {}", previous - 1); + } + } + _ => {} + } + } +} + +impl Drop for RateLimiterFilter { + fn drop(&mut self) { + // Ensure we decrement the counter if the filter is dropped without on_event being called. + if self.connection_counted { + self.shared_state + .active_connections + .fetch_sub(1, Ordering::SeqCst); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rate_limiter_config_parsing() { + let config = r#"{"max_connections": 50}"#; + let config_data: RateLimiterConfigData = serde_json::from_str(config).unwrap(); + assert_eq!(config_data.max_connections, 50); + assert_eq!(config_data.reject_message, default_reject_message()); + } + + #[test] + fn test_rate_limiter_config_with_message() { + let config = r#"{"max_connections": 10, "reject_message": "Go away!"}"#; + let config_data: RateLimiterConfigData = serde_json::from_str(config).unwrap(); + assert_eq!(config_data.max_connections, 10); + assert_eq!(config_data.reject_message, "Go away!"); + } + + #[test] + fn test_shared_state_atomic_operations() { + let state = SharedConnectionState { + active_connections: AtomicU64::new(0), + }; + + // Simulate accepting connections. + let v1 = state.active_connections.fetch_add(1, Ordering::SeqCst); + assert_eq!(v1, 0); + + let v2 = state.active_connections.fetch_add(1, Ordering::SeqCst); + assert_eq!(v2, 1); + + // Check current value. + assert_eq!(state.active_connections.load(Ordering::SeqCst), 2); + + // Simulate closing a connection. + let v3 = state.active_connections.fetch_sub(1, Ordering::SeqCst); + assert_eq!(v3, 2); + assert_eq!(state.active_connections.load(Ordering::SeqCst), 1); + } + + #[test] + fn test_default_reject_message() { + assert_eq!(default_reject_message(), "Connection limit exceeded"); + } +} diff --git a/rust/src/network_redis.rs b/rust/src/network_redis.rs new file mode 100644 index 0000000..046de16 --- /dev/null +++ b/rust/src/network_redis.rs @@ -0,0 +1,515 @@ +//! A Redis RESP protocol parser and command filter. +//! +//! This filter demonstrates: +//! 1. Parsing the Redis RESP (REdis Serialization Protocol). +//! 2. Command filtering and blocking. +//! 3. Protocol-aware connection handling. +//! +//! Configuration format (JSON): +//! ```json +//! { +//! "blocked_commands": ["FLUSHALL", "FLUSHDB", "DEBUG"], +//! "log_commands": true, +//! "max_command_length": 1024 +//! } +//! ``` +//! +//! To use this filter as a standalone module, create a separate crate with: +//! ```ignore +//! use envoy_proxy_dynamic_modules_rust_sdk::*; +//! declare_network_filter_init_functions!(init, network_redis::new_filter_config); +//! ``` + +use envoy_proxy_dynamic_modules_rust_sdk::*; +use serde::Deserialize; +use std::collections::HashSet; + +/// Configuration data parsed from JSON. +#[derive(Deserialize)] +struct RedisFilterConfigData { + #[serde(default)] + blocked_commands: Vec, + #[serde(default = "default_log_commands")] + log_commands: bool, + #[serde(default = "default_max_command_length")] + max_command_length: usize, +} + +fn default_log_commands() -> bool { + true +} + +fn default_max_command_length() -> usize { + 1024 +} + +impl Default for RedisFilterConfigData { + fn default() -> Self { + RedisFilterConfigData { + blocked_commands: Vec::new(), + log_commands: default_log_commands(), + max_command_length: default_max_command_length(), + } + } +} + +/// Factory function for creating Redis filter configurations. +pub fn new_filter_config( + envoy_filter_config: &mut EC, + _name: &str, + config: &[u8], +) -> Option>> { + let config_data: RedisFilterConfigData = if config.is_empty() { + RedisFilterConfigData::default() + } else { + match serde_json::from_slice(config) { + Ok(c) => c, + Err(e) => { + envoy_log_info!( + "Failed to parse Redis filter config: {}. Using defaults.", + e + ); + RedisFilterConfigData::default() + } + } + }; + + // Define metrics. + let commands_total = envoy_filter_config + .define_counter("redis_commands_total") + .ok()?; + let commands_blocked = envoy_filter_config + .define_counter("redis_commands_blocked") + .ok()?; + let bytes_received = envoy_filter_config + .define_counter("redis_bytes_received") + .ok()?; + let bytes_sent = envoy_filter_config + .define_counter("redis_bytes_sent") + .ok()?; + let active_connections = envoy_filter_config + .define_gauge("redis_active_connections") + .ok()?; + let parse_errors = envoy_filter_config + .define_counter("redis_parse_errors") + .ok()?; + + Some(Box::new(RedisFilterConfig { + blocked_commands: config_data.blocked_commands.into_iter().collect(), + log_commands: config_data.log_commands, + max_command_length: config_data.max_command_length, + commands_total, + commands_blocked, + bytes_received, + bytes_sent, + active_connections, + parse_errors, + })) +} + +/// Redis filter configuration. +struct RedisFilterConfig { + blocked_commands: HashSet, + log_commands: bool, + max_command_length: usize, + commands_total: EnvoyCounterId, + commands_blocked: EnvoyCounterId, + bytes_received: EnvoyCounterId, + bytes_sent: EnvoyCounterId, + active_connections: EnvoyGaugeId, + parse_errors: EnvoyCounterId, +} + +impl NetworkFilterConfig for RedisFilterConfig { + fn new_network_filter(&self, envoy_filter: &mut ENF) -> Box> { + let _ = envoy_filter.increase_gauge(self.active_connections, 1); + + Box::new(RedisFilter { + blocked_commands: self.blocked_commands.clone(), + log_commands: self.log_commands, + max_command_length: self.max_command_length, + commands_total: self.commands_total, + commands_blocked: self.commands_blocked, + bytes_received: self.bytes_received, + bytes_sent: self.bytes_sent, + active_connections: self.active_connections, + parse_errors: self.parse_errors, + }) + } +} + +/// Redis filter instance for a single connection. +struct RedisFilter { + blocked_commands: HashSet, + log_commands: bool, + max_command_length: usize, + commands_total: EnvoyCounterId, + commands_blocked: EnvoyCounterId, + bytes_received: EnvoyCounterId, + bytes_sent: EnvoyCounterId, + active_connections: EnvoyGaugeId, + parse_errors: EnvoyCounterId, +} + +/// Represents a parsed Redis command. +#[derive(Debug, Clone, PartialEq)] +pub struct RedisCommand { + pub name: String, + pub args: Vec>, +} + +/// RESP parsing result. +#[derive(Debug, PartialEq)] +pub enum RespValue { + SimpleString(String), + Error(String), + Integer(i64), + BulkString(Vec), + Array(Vec), + Null, +} + +/// Standalone RESP parser - can be used and tested without SDK dependencies. +pub struct RespParser; + +impl RespParser { + /// Parses RESP data and extracts commands. + pub fn parse_commands(data: &[u8]) -> Result, &'static str> { + let mut commands = Vec::new(); + let mut pos = 0; + + while pos < data.len() { + match Self::parse_resp_value(data, pos) { + Ok((value, new_pos)) => { + if let Some(cmd) = Self::resp_to_command(value) { + commands.push(cmd); + } + pos = new_pos; + } + Err(_) => { + break; + } + } + } + + Ok(commands) + } + + /// Parses a single RESP value starting at the given position. + pub fn parse_resp_value(data: &[u8], pos: usize) -> Result<(RespValue, usize), &'static str> { + if pos >= data.len() { + return Err("Incomplete data"); + } + + let type_byte = data[pos]; + let start = pos + 1; + + match type_byte { + b'+' => { + let (line, end) = Self::read_line(data, start)?; + Ok(( + RespValue::SimpleString(String::from_utf8_lossy(line).to_string()), + end, + )) + } + b'-' => { + let (line, end) = Self::read_line(data, start)?; + Ok(( + RespValue::Error(String::from_utf8_lossy(line).to_string()), + end, + )) + } + b':' => { + let (line, end) = Self::read_line(data, start)?; + let num: i64 = String::from_utf8_lossy(line) + .parse() + .map_err(|_| "Invalid integer")?; + Ok((RespValue::Integer(num), end)) + } + b'$' => { + let (line, end) = Self::read_line(data, start)?; + let len: i64 = String::from_utf8_lossy(line) + .parse() + .map_err(|_| "Invalid bulk string length")?; + + if len < 0 { + return Ok((RespValue::Null, end)); + } + + let len = len as usize; + if end + len + 2 > data.len() { + return Err("Incomplete bulk string"); + } + + let content = data[end..end + len].to_vec(); + Ok((RespValue::BulkString(content), end + len + 2)) + } + b'*' => { + let (line, mut end) = Self::read_line(data, start)?; + let count: i64 = String::from_utf8_lossy(line) + .parse() + .map_err(|_| "Invalid array count")?; + + if count < 0 { + return Ok((RespValue::Null, end)); + } + + let mut elements = Vec::with_capacity(count as usize); + for _ in 0..count { + let (value, new_end) = Self::parse_resp_value(data, end)?; + elements.push(value); + end = new_end; + } + + Ok((RespValue::Array(elements), end)) + } + _ => Err("Unknown RESP type"), + } + } + + /// Reads a line (until \r\n) from the data. + pub fn read_line(data: &[u8], start: usize) -> Result<(&[u8], usize), &'static str> { + if start >= data.len() { + return Err("Start position out of bounds"); + } + for i in start..data.len().saturating_sub(1) { + if data[i] == b'\r' && data[i + 1] == b'\n' { + return Ok((&data[start..i], i + 2)); + } + } + Err("Incomplete line") + } + + /// Converts a RESP value to a Redis command. + pub fn resp_to_command(value: RespValue) -> Option { + match value { + RespValue::Array(elements) if !elements.is_empty() => { + let mut args: Vec> = Vec::new(); + + for elem in elements { + match elem { + RespValue::BulkString(s) => args.push(s), + RespValue::SimpleString(s) => args.push(s.into_bytes()), + _ => {} + } + } + + if args.is_empty() { + return None; + } + + let name = String::from_utf8_lossy(&args[0]).to_uppercase(); + Some(RedisCommand { + name, + args: args.into_iter().skip(1).collect(), + }) + } + _ => None, + } + } + + /// Creates a Redis error response. + pub fn create_error_response(message: &str) -> Vec { + format!("-ERR {}\r\n", message).into_bytes() + } +} + +impl RedisFilter { + fn parse_commands(&self, data: &[u8]) -> Result, &'static str> { + RespParser::parse_commands(data) + } + + fn create_error_response(&self, message: &str) -> Vec { + RespParser::create_error_response(message) + } +} + +impl NetworkFilter for RedisFilter { + fn on_new_connection( + &mut self, + _envoy_filter: &mut ENF, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + envoy_log_debug!("New Redis connection established."); + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_read( + &mut self, + envoy_filter: &mut ENF, + read_buffer_length: usize, + _end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + let _ = envoy_filter.increment_counter(self.bytes_received, read_buffer_length as u64); + + // Get data from read buffer. + let (chunks, _) = envoy_filter.get_read_buffer_chunks(); + + // Collect all bytes. + let mut data = Vec::new(); + for chunk in chunks { + data.extend_from_slice(chunk.as_slice()); + } + + // Check for command length limit. + if data.len() > self.max_command_length { + envoy_log_info!( + "Redis command exceeds max length: {} > {}", + data.len(), + self.max_command_length + ); + let _ = envoy_filter.increment_counter(self.parse_errors, 1); + envoy_filter.drain_read_buffer(read_buffer_length); + envoy_filter.write(&self.create_error_response("command too long"), false); + return abi::envoy_dynamic_module_type_on_network_filter_data_status::StopIteration; + } + + // Parse Redis commands. + match self.parse_commands(&data) { + Ok(commands) => { + for cmd in &commands { + let _ = envoy_filter.increment_counter(self.commands_total, 1); + + if self.log_commands { + let args_str: Vec = cmd + .args + .iter() + .take(3) + .map(|a| String::from_utf8_lossy(a).to_string()) + .collect(); + envoy_log_debug!("Redis command: {} {:?}", cmd.name, args_str); + } + + // Check if command is blocked. + if self.blocked_commands.contains(&cmd.name) { + envoy_log_info!("Blocked Redis command: {}", cmd.name); + let _ = envoy_filter.increment_counter(self.commands_blocked, 1); + + // Drain the read buffer and send error response. + envoy_filter.drain_read_buffer(read_buffer_length); + let error_msg = format!("command '{}' is not allowed", cmd.name); + envoy_filter.write(&self.create_error_response(&error_msg), false); + return abi::envoy_dynamic_module_type_on_network_filter_data_status::StopIteration; + } + } + } + Err(e) => { + envoy_log_debug!("Redis parse error: {}", e); + let _ = envoy_filter.increment_counter(self.parse_errors, 1); + } + } + + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_write( + &mut self, + envoy_filter: &mut ENF, + write_buffer_length: usize, + _end_stream: bool, + ) -> abi::envoy_dynamic_module_type_on_network_filter_data_status { + let _ = envoy_filter.increment_counter(self.bytes_sent, write_buffer_length as u64); + abi::envoy_dynamic_module_type_on_network_filter_data_status::Continue + } + + fn on_event( + &mut self, + envoy_filter: &mut ENF, + event: abi::envoy_dynamic_module_type_network_connection_event, + ) { + match event { + abi::envoy_dynamic_module_type_network_connection_event::RemoteClose + | abi::envoy_dynamic_module_type_network_connection_event::LocalClose => { + let _ = envoy_filter.decrease_gauge(self.active_connections, 1); + envoy_log_debug!("Redis connection closed."); + } + _ => {} + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_redis_config_default() { + let config = RedisFilterConfigData::default(); + assert!(config.blocked_commands.is_empty()); + assert!(config.log_commands); + assert_eq!(config.max_command_length, 1024); + } + + #[test] + fn test_redis_config_parsing() { + let json = r#"{"blocked_commands": ["FLUSHALL"], "log_commands": false}"#; + let config: RedisFilterConfigData = serde_json::from_str(json).unwrap(); + assert_eq!(config.blocked_commands, vec!["FLUSHALL"]); + assert!(!config.log_commands); + } + + #[test] + fn test_create_error_response() { + let response = RespParser::create_error_response("test error"); + assert_eq!(response, b"-ERR test error\r\n"); + } + + #[test] + fn test_resp_simple_string_parsing() { + let data = b"+OK\r\n"; + let result = RespParser::parse_resp_value(data, 0); + assert!(result.is_ok()); + let (value, consumed) = result.unwrap(); + assert_eq!(value, RespValue::SimpleString("OK".to_string())); + assert_eq!(consumed, 5); + } + + #[test] + fn test_resp_bulk_string_parsing() { + let data = b"$5\r\nhello\r\n"; + let result = RespParser::parse_resp_value(data, 0); + assert!(result.is_ok()); + let (value, _) = result.unwrap(); + assert_eq!(value, RespValue::BulkString(b"hello".to_vec())); + } + + #[test] + fn test_resp_integer_parsing() { + let data = b":1000\r\n"; + let result = RespParser::parse_resp_value(data, 0); + assert!(result.is_ok()); + let (value, _) = result.unwrap(); + assert_eq!(value, RespValue::Integer(1000)); + } + + #[test] + fn test_command_extraction() { + // *2\r\n$3\r\nGET\r\n$3\r\nkey\r\n + let data = b"*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n"; + let commands = RespParser::parse_commands(data).unwrap(); + assert_eq!(commands.len(), 1); + assert_eq!(commands[0].name, "GET"); + assert_eq!(commands[0].args.len(), 1); + assert_eq!(commands[0].args[0], b"key"); + } + + #[test] + fn test_resp_error_parsing() { + let data = b"-ERR something went wrong\r\n"; + let result = RespParser::parse_resp_value(data, 0); + assert!(result.is_ok()); + let (value, _) = result.unwrap(); + assert_eq!( + value, + RespValue::Error("ERR something went wrong".to_string()) + ); + } + + #[test] + fn test_resp_null_bulk_string() { + let data = b"$-1\r\n"; + let result = RespParser::parse_resp_value(data, 0); + assert!(result.is_ok()); + let (value, _) = result.unwrap(); + assert_eq!(value, RespValue::Null); + } +}