Add configurable global concurrent request limiting

- Add max_concurrent_requests config option (default: 1000)
- Implement global AtomicUsize counter for concurrent request tracking
- Return status 41 'Server unavailable' when limit exceeded
- Proper counter management with decrements on all exit paths
- Add comprehensive config validation (1-1,000,000 range)
- Update documentation with rate limiting details
- Add unit tests for config parsing
- Thread-safe implementation using Ordering::Relaxed

This provides effective DDoS protection by limiting concurrent
connections to prevent server overload while maintaining
configurability for different deployment scenarios.
This commit is contained in:
Jeena 2026-01-16 02:26:59 +00:00
parent 9d29321806
commit 0468781a69
5 changed files with 54 additions and 6 deletions

View file

@ -92,9 +92,10 @@ nothing else. It is meant to be generic so other people can use it.
- Default file: "index.gmi" for directory requests - Default file: "index.gmi" for directory requests
## Error Handling ## Error Handling
- **Concurrent request limit exceeded**: Return status 41 "Server unavailable"
- **Timeout**: Return status 41 "Server unavailable" (not 59) - **Timeout**: Return status 41 "Server unavailable" (not 59)
- **Request too large**: Return status 59 "Bad request" - **Request too large**: Return status 59 "Bad request"
- **Empty request**: Return status 59 "Bad request" - **Empty request**: Return status 59 "Bad request"
- **Invalid URL format**: Return status 59 "Bad request" - **Invalid URL format**: Return status 59 "Bad request"
- **Hostname mismatch**: Return status 59 "Bad request" - **Hostname mismatch**: Return status 59 "Bad request"
- **Path resolution failure**: Return status 51 "Not found" (including security violations) - **Path resolution failure**: Return status 51 "Not found" (including security violations)
@ -102,12 +103,13 @@ nothing else. It is meant to be generic so other people can use it.
- Reject requests > 1024 bytes (per Gemini spec) - Reject requests > 1024 bytes (per Gemini spec)
- Reject requests without proper `\r\n` termination - Reject requests without proper `\r\n` termination
- Use `tokio::time::timeout()` for request timeout handling - Use `tokio::time::timeout()` for request timeout handling
- Configurable concurrent request limit: `max_concurrent_requests` (default: 1000)
## Configuration ## Configuration
- TOML config files with `serde::Deserialize` - TOML config files with `serde::Deserialize`
- CLI args override config file values - CLI args override config file values
- Required fields: root, cert, key, host - Required fields: root, cert, key, host
- Optional: port, log_level - Optional: port, log_level, max_concurrent_requests
# Development Notes # Development Notes
- Generate self-signed certificates for local testing in `tmp/` directory - Generate self-signed certificates for local testing in `tmp/` directory

3
BACKLOG.md Normal file
View file

@ -0,0 +1,3 @@
- remove the CLI options, everything should be only configurable via config file
- seperate tests into unit/integration and system tests, at least by naming convention, but perhaps there is a rust way to do it
- add a system test which tests that the server really responds with 44 before 41

View file

@ -8,6 +8,7 @@ pub struct Config {
pub host: Option<String>, pub host: Option<String>,
pub port: Option<u16>, pub port: Option<u16>,
pub log_level: Option<String>, pub log_level: Option<String>,
pub max_concurrent_requests: Option<usize>,
} }
pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> { pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
@ -43,6 +44,21 @@ mod tests {
assert_eq!(config.host, Some("example.com".to_string())); assert_eq!(config.host, Some("example.com".to_string()));
assert_eq!(config.port, Some(1965)); assert_eq!(config.port, Some(1965));
assert_eq!(config.log_level, Some("info".to_string())); assert_eq!(config.log_level, Some("info".to_string()));
assert_eq!(config.max_concurrent_requests, None); // Default
}
#[test]
fn test_load_config_with_max_concurrent_requests() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let content = r#"
root = "/path/to/root"
max_concurrent_requests = 500
"#;
fs::write(&config_path, content).unwrap();
let config = load_config(config_path.to_str().unwrap()).unwrap();
assert_eq!(config.max_concurrent_requests, Some(500));
} }
#[test] #[test]

View file

@ -69,6 +69,7 @@ async fn main() {
host: None, host: None,
port: None, port: None,
log_level: None, log_level: None,
max_concurrent_requests: None,
}); });
// Initialize logging // Initialize logging
@ -82,6 +83,13 @@ async fn main() {
let host = args.host.or(config.host).unwrap_or_else(|| "0.0.0.0".to_string()); let host = args.host.or(config.host).unwrap_or_else(|| "0.0.0.0".to_string());
let port = args.port.or(config.port).unwrap_or(1965); let port = args.port.or(config.port).unwrap_or(1965);
// Validate max concurrent requests
let max_concurrent_requests = config.max_concurrent_requests.unwrap_or(1000);
if max_concurrent_requests == 0 || max_concurrent_requests > 1_000_000 {
eprintln!("Error: max_concurrent_requests must be between 1 and 1,000,000");
std::process::exit(1);
}
// Validate directory // Validate directory
let dir_path = Path::new(&root); let dir_path = Path::new(&root);
if !dir_path.exists() || !dir_path.is_dir() { if !dir_path.exists() || !dir_path.is_dir() {
@ -110,8 +118,9 @@ async fn main() {
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
let dir = root.clone(); let dir = root.clone();
let expected_host = "localhost".to_string(); // Override for testing let expected_host = "localhost".to_string(); // Override for testing
let max_concurrent = max_concurrent_requests;
if let Ok(stream) = acceptor.accept(stream).await { if let Ok(stream) = acceptor.accept(stream).await {
if let Err(e) = server::handle_connection(stream, &dir, &expected_host).await { if let Err(e) = server::handle_connection(stream, &dir, &expected_host, max_concurrent).await {
tracing::error!("Error handling connection: {}", e); tracing::error!("Error handling connection: {}", e);
} }
} }

View file

@ -3,11 +3,14 @@ use crate::logging::RequestLogger;
use std::fs; use std::fs;
use std::io; use std::io;
use std::path::Path; use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::{timeout, Duration}; use tokio::time::{timeout, Duration};
use tokio_rustls::server::TlsStream; use tokio_rustls::server::TlsStream;
static ACTIVE_REQUESTS: AtomicUsize = AtomicUsize::new(0);
pub async fn serve_file( pub async fn serve_file(
stream: &mut TlsStream<TcpStream>, stream: &mut TlsStream<TcpStream>,
file_path: &Path, file_path: &Path,
@ -29,7 +32,15 @@ pub async fn handle_connection(
mut stream: TlsStream<TcpStream>, mut stream: TlsStream<TcpStream>,
dir: &str, dir: &str,
expected_host: &str, expected_host: &str,
max_concurrent_requests: usize,
) -> io::Result<()> { ) -> io::Result<()> {
// Check concurrent request limit
let current = ACTIVE_REQUESTS.fetch_add(1, Ordering::Relaxed);
if current >= max_concurrent_requests {
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "41 Server unavailable\r\n").await;
}
const MAX_REQUEST_SIZE: usize = 4096; const MAX_REQUEST_SIZE: usize = 4096;
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
@ -58,12 +69,14 @@ pub async fn handle_connection(
if request.is_empty() { if request.is_empty() {
let logger = RequestLogger::new(&stream, request); let logger = RequestLogger::new(&stream, request);
logger.log_error(59, "Empty request"); logger.log_error(59, "Empty request");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await; return send_response(&mut stream, "59 Bad Request\r\n").await;
} }
if request.len() > 1024 { if request.len() > 1024 {
let logger = RequestLogger::new(&stream, request); let logger = RequestLogger::new(&stream, request);
logger.log_error(59, "Request too large"); logger.log_error(59, "Request too large");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await; return send_response(&mut stream, "59 Bad Request\r\n").await;
} }
@ -73,6 +86,7 @@ pub async fn handle_connection(
Err(_) => { Err(_) => {
let logger = RequestLogger::new(&stream, request); let logger = RequestLogger::new(&stream, request);
logger.log_error(59, "Invalid URL format"); logger.log_error(59, "Invalid URL format");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await; return send_response(&mut stream, "59 Bad Request\r\n").await;
} }
}; };
@ -85,6 +99,7 @@ pub async fn handle_connection(
Ok(fp) => fp, Ok(fp) => fp,
Err(PathResolutionError::NotFound) => { Err(PathResolutionError::NotFound) => {
logger.log_error(51, "File not found"); logger.log_error(51, "File not found");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "51 Not found\r\n").await; return send_response(&mut stream, "51 Not found\r\n").await;
} }
}; };
@ -95,7 +110,7 @@ pub async fn handle_connection(
Err(_) => { Err(_) => {
// This shouldn't happen since we check existence, but handle gracefully // This shouldn't happen since we check existence, but handle gracefully
logger.log_error(51, "File not found"); logger.log_error(51, "File not found");
send_response(&mut stream, "51 Not found\r\n").await?; let _ = send_response(&mut stream, "51 Not found\r\n").await;
} }
} }
}, },
@ -103,7 +118,7 @@ pub async fn handle_connection(
// Read failed, check error type // Read failed, check error type
let request_str = String::from_utf8_lossy(&request_buf).trim().to_string(); let request_str = String::from_utf8_lossy(&request_buf).trim().to_string();
let logger = RequestLogger::new(&stream, request_str); let logger = RequestLogger::new(&stream, request_str);
match e.kind() { match e.kind() {
tokio::io::ErrorKind::InvalidData => { tokio::io::ErrorKind::InvalidData => {
logger.log_error(59, "Request too large"); logger.log_error(59, "Request too large");
@ -114,6 +129,7 @@ pub async fn handle_connection(
let _ = send_response(&mut stream, "59 Bad Request\r\n").await; let _ = send_response(&mut stream, "59 Bad Request\r\n").await;
} }
} }
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
}, },
Err(_) => { Err(_) => {
// Timeout // Timeout
@ -121,10 +137,12 @@ pub async fn handle_connection(
let logger = RequestLogger::new(&stream, request_str); let logger = RequestLogger::new(&stream, request_str);
logger.log_error(41, "Server unavailable"); logger.log_error(41, "Server unavailable");
let _ = send_response(&mut stream, "41 Server unavailable\r\n").await; let _ = send_response(&mut stream, "41 Server unavailable\r\n").await;
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return Ok(()); return Ok(());
} }
} }
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
Ok(()) Ok(())
} }