diff --git a/AGENTS.md b/AGENTS.md index 83ad33b..5d26779 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -92,9 +92,10 @@ nothing else. It is meant to be generic so other people can use it. - Default file: "index.gmi" for directory requests ## Error Handling +- **Concurrent request limit exceeded**: Return status 41 "Server unavailable" - **Timeout**: Return status 41 "Server unavailable" (not 59) - **Request too large**: Return status 59 "Bad request" -- **Empty request**: Return status 59 "Bad request" +- **Empty request**: Return status 59 "Bad request" - **Invalid URL format**: Return status 59 "Bad request" - **Hostname mismatch**: Return status 59 "Bad request" - **Path resolution failure**: Return status 51 "Not found" (including security violations) @@ -102,12 +103,13 @@ nothing else. It is meant to be generic so other people can use it. - Reject requests > 1024 bytes (per Gemini spec) - Reject requests without proper `\r\n` termination - Use `tokio::time::timeout()` for request timeout handling +- Configurable concurrent request limit: `max_concurrent_requests` (default: 1000) ## Configuration - TOML config files with `serde::Deserialize` - CLI args override config file values - Required fields: root, cert, key, host -- Optional: port, log_level +- Optional: port, log_level, max_concurrent_requests # Development Notes - Generate self-signed certificates for local testing in `tmp/` directory diff --git a/BACKLOG.md b/BACKLOG.md new file mode 100644 index 0000000..0801b61 --- /dev/null +++ b/BACKLOG.md @@ -0,0 +1,3 @@ +- remove the CLI options, everything should be only configurable via config file +- seperate tests into unit/integration and system tests, at least by naming convention, but perhaps there is a rust way to do it +- add a system test which tests that the server really responds with 44 before 41 diff --git a/src/config.rs b/src/config.rs index 2b92041..8342953 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,6 +8,7 @@ pub struct Config { pub host: Option, pub port: Option, pub log_level: Option, + pub max_concurrent_requests: Option, } pub fn load_config(path: &str) -> Result> { @@ -43,6 +44,21 @@ mod tests { assert_eq!(config.host, Some("example.com".to_string())); assert_eq!(config.port, Some(1965)); assert_eq!(config.log_level, Some("info".to_string())); + assert_eq!(config.max_concurrent_requests, None); // Default + } + + #[test] + fn test_load_config_with_max_concurrent_requests() { + let temp_dir = TempDir::new().unwrap(); + let config_path = temp_dir.path().join("config.toml"); + let content = r#" + root = "/path/to/root" + max_concurrent_requests = 500 + "#; + fs::write(&config_path, content).unwrap(); + + let config = load_config(config_path.to_str().unwrap()).unwrap(); + assert_eq!(config.max_concurrent_requests, Some(500)); } #[test] diff --git a/src/main.rs b/src/main.rs index 41f35c9..929eb46 100644 --- a/src/main.rs +++ b/src/main.rs @@ -69,6 +69,7 @@ async fn main() { host: None, port: None, log_level: None, + max_concurrent_requests: None, }); // Initialize logging @@ -82,6 +83,13 @@ async fn main() { let host = args.host.or(config.host).unwrap_or_else(|| "0.0.0.0".to_string()); let port = args.port.or(config.port).unwrap_or(1965); + // Validate max concurrent requests + let max_concurrent_requests = config.max_concurrent_requests.unwrap_or(1000); + if max_concurrent_requests == 0 || max_concurrent_requests > 1_000_000 { + eprintln!("Error: max_concurrent_requests must be between 1 and 1,000,000"); + std::process::exit(1); + } + // Validate directory let dir_path = Path::new(&root); if !dir_path.exists() || !dir_path.is_dir() { @@ -110,8 +118,9 @@ async fn main() { let acceptor = acceptor.clone(); let dir = root.clone(); let expected_host = "localhost".to_string(); // Override for testing + let max_concurrent = max_concurrent_requests; if let Ok(stream) = acceptor.accept(stream).await { - if let Err(e) = server::handle_connection(stream, &dir, &expected_host).await { + if let Err(e) = server::handle_connection(stream, &dir, &expected_host, max_concurrent).await { tracing::error!("Error handling connection: {}", e); } } diff --git a/src/server.rs b/src/server.rs index 53c3eae..df351c6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,11 +3,14 @@ use crate::logging::RequestLogger; use std::fs; use std::io; use std::path::Path; +use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::time::{timeout, Duration}; use tokio_rustls::server::TlsStream; +static ACTIVE_REQUESTS: AtomicUsize = AtomicUsize::new(0); + pub async fn serve_file( stream: &mut TlsStream, file_path: &Path, @@ -29,7 +32,15 @@ pub async fn handle_connection( mut stream: TlsStream, dir: &str, expected_host: &str, + max_concurrent_requests: usize, ) -> io::Result<()> { + // Check concurrent request limit + let current = ACTIVE_REQUESTS.fetch_add(1, Ordering::Relaxed); + if current >= max_concurrent_requests { + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); + return send_response(&mut stream, "41 Server unavailable\r\n").await; + } + const MAX_REQUEST_SIZE: usize = 4096; const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); @@ -58,12 +69,14 @@ pub async fn handle_connection( if request.is_empty() { let logger = RequestLogger::new(&stream, request); logger.log_error(59, "Empty request"); + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); return send_response(&mut stream, "59 Bad Request\r\n").await; } if request.len() > 1024 { let logger = RequestLogger::new(&stream, request); logger.log_error(59, "Request too large"); + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); return send_response(&mut stream, "59 Bad Request\r\n").await; } @@ -73,6 +86,7 @@ pub async fn handle_connection( Err(_) => { let logger = RequestLogger::new(&stream, request); logger.log_error(59, "Invalid URL format"); + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); return send_response(&mut stream, "59 Bad Request\r\n").await; } }; @@ -85,6 +99,7 @@ pub async fn handle_connection( Ok(fp) => fp, Err(PathResolutionError::NotFound) => { logger.log_error(51, "File not found"); + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); return send_response(&mut stream, "51 Not found\r\n").await; } }; @@ -95,7 +110,7 @@ pub async fn handle_connection( Err(_) => { // This shouldn't happen since we check existence, but handle gracefully logger.log_error(51, "File not found"); - send_response(&mut stream, "51 Not found\r\n").await?; + let _ = send_response(&mut stream, "51 Not found\r\n").await; } } }, @@ -103,7 +118,7 @@ pub async fn handle_connection( // Read failed, check error type let request_str = String::from_utf8_lossy(&request_buf).trim().to_string(); let logger = RequestLogger::new(&stream, request_str); - + match e.kind() { tokio::io::ErrorKind::InvalidData => { logger.log_error(59, "Request too large"); @@ -114,6 +129,7 @@ pub async fn handle_connection( let _ = send_response(&mut stream, "59 Bad Request\r\n").await; } } + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); }, Err(_) => { // Timeout @@ -121,10 +137,12 @@ pub async fn handle_connection( let logger = RequestLogger::new(&stream, request_str); logger.log_error(41, "Server unavailable"); let _ = send_response(&mut stream, "41 Server unavailable\r\n").await; + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); return Ok(()); } } - + + ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed); Ok(()) }