diff --git a/src/main.rs b/src/main.rs index 7ac8c44..be856cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,26 +34,6 @@ struct Args { #[arg(short = 'C', long)] config: Option, - /// Directory to serve files from - #[arg(short, long)] - root: Option, - - /// Path to certificate file - #[arg(short, long)] - cert: Option, - - /// Path to private key file - #[arg(short, long)] - key: Option, - - /// Port to listen on - #[arg(short, long)] - port: Option, - - /// Hostname for the server - #[arg(short = 'H', long)] - host: Option, - /// TESTING ONLY: Add delay before processing (seconds) [debug builds only] #[cfg(debug_assertions)] #[arg(long, value_name = "SECONDS")] @@ -82,12 +62,12 @@ async fn main() { let log_level = config.log_level.as_deref().unwrap_or("info"); init_logging(log_level); - // Merge config with args (args take precedence) - let root = args.root.or(config.root).expect("root is required"); - let cert_path = args.cert.or(config.cert).expect("cert is required"); - let key_path = args.key.or(config.key).expect("key is required"); - let host = args.host.or(config.host).unwrap_or_else(|| "0.0.0.0".to_string()); - let port = args.port.or(config.port).unwrap_or(1965); + // Load configuration from file only + let root = config.root.expect("root is required"); + let cert_path = config.cert.expect("cert is required"); + let key_path = config.key.expect("key is required"); + let host = config.host.unwrap_or_else(|| "0.0.0.0".to_string()); + let port = config.port.unwrap_or(1965); // Validate max concurrent requests let max_concurrent_requests = config.max_concurrent_requests.unwrap_or(1000); @@ -139,7 +119,7 @@ async fn main() { let test_delay = test_processing_delay; tokio::spawn(async move { if let Ok(stream) = acceptor.accept(stream).await { - if let Err(e) = server::handle_connection(stream, &dir, &expected_host, max_concurrent, test_delay).await { + if let Err(e) = server::handle_connection(stream, &dir, &expected_host, port, max_concurrent, test_delay).await { tracing::error!("Error handling connection: {}", e); } } diff --git a/src/request.rs b/src/request.rs index 586cc73..9a57584 100644 --- a/src/request.rs +++ b/src/request.rs @@ -6,14 +6,34 @@ pub enum PathResolutionError { NotFound, } -pub fn parse_gemini_url(request: &str, expected_host: &str) -> Result { +pub fn parse_gemini_url(request: &str, expected_host: &str, expected_port: u16) -> Result { if let Some(url) = request.strip_prefix("gemini://") { - let host_end = url.find('/').unwrap_or(url.len()); - let host = &url[..host_end]; + let host_port_end = url.find('/').unwrap_or(url.len()); + let host_port = &url[..host_port_end]; + + // Parse host and port + let (host, port_str) = if let Some(colon_pos) = host_port.find(':') { + let host = &host_port[..colon_pos]; + let port_str = &host_port[colon_pos + 1..]; + (host, Some(port_str)) + } else { + (host_port, None) + }; + + // Validate host if host != expected_host { return Err(()); // Hostname mismatch } - let path = if host_end < url.len() { &url[host_end..] } else { "/" }; + + // Validate port + let port = port_str + .and_then(|p| p.parse::().ok()) + .unwrap_or(1965); + if port != expected_port { + return Err(()); // Port mismatch + } + + let path = if host_port_end < url.len() { &url[host_port_end..] } else { "/" }; Ok(path.trim().to_string()) } else { Err(()) @@ -69,18 +89,18 @@ mod tests { #[test] fn test_parse_gemini_url_valid() { - assert_eq!(parse_gemini_url("gemini://gemini.jeena.net/", "gemini.jeena.net"), Ok("/".to_string())); - assert_eq!(parse_gemini_url("gemini://gemini.jeena.net/posts/test", "gemini.jeena.net"), Ok("/posts/test".to_string())); + assert_eq!(parse_gemini_url("gemini://gemini.jeena.net/", "gemini.jeena.net", 1965), Ok("/".to_string())); + assert_eq!(parse_gemini_url("gemini://gemini.jeena.net/posts/test", "gemini.jeena.net", 1965), Ok("/posts/test".to_string())); } #[test] fn test_parse_gemini_url_invalid_host() { - assert!(parse_gemini_url("gemini://foo.com/", "gemini.jeena.net").is_err()); + assert!(parse_gemini_url("gemini://foo.com/", "gemini.jeena.net", 1965).is_err()); } #[test] fn test_parse_gemini_url_no_prefix() { - assert!(parse_gemini_url("http://gemini.jeena.net/", "gemini.jeena.net").is_err()); + assert!(parse_gemini_url("http://gemini.jeena.net/", "gemini.jeena.net", 1965).is_err()); } #[test] diff --git a/src/server.rs b/src/server.rs index d98e311..a2937a8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -41,10 +41,11 @@ pub async fn handle_connection( mut stream: TlsStream, dir: &str, expected_host: &str, + expected_port: u16, max_concurrent_requests: usize, test_processing_delay: u64, ) -> io::Result<()> { - const MAX_REQUEST_SIZE: usize = 4096; + const MAX_REQUEST_SIZE: usize = 1026; const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); let mut request_buf = Vec::new(); @@ -96,7 +97,7 @@ pub async fn handle_connection( } // Parse Gemini URL - let path = match parse_gemini_url(&request, expected_host) { + let path = match parse_gemini_url(&request, expected_host, expected_port) { Ok(p) => p, Err(_) => { logger.log_error(59, "Invalid URL format");