diff --git a/src/main.rs b/src/main.rs index d482e26..929700b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,7 +61,7 @@ async fn main() { } // Load and parse config - let config = match config::load_config(&config_path) { + let config = match config::load_config(config_path) { Ok(config) => config, Err(e) => { eprintln!("Error: Failed to parse config file '{}': {}", config_path, e); diff --git a/tests/common.rs b/tests/common.rs new file mode 100644 index 0000000..894d59a --- /dev/null +++ b/tests/common.rs @@ -0,0 +1,59 @@ +use std::path::{Path, PathBuf}; +use tempfile::TempDir; + +pub struct TestEnvironment { + pub temp_dir: TempDir, + pub config_path: PathBuf, + pub cert_path: PathBuf, + pub key_path: PathBuf, + pub content_path: PathBuf, + pub port: u16, +} + +pub fn setup_test_environment() -> TestEnvironment { + let temp_dir = TempDir::new().unwrap(); + let config_path = temp_dir.path().join("config.toml"); + let cert_path = temp_dir.path().join("cert.pem"); + let key_path = temp_dir.path().join("key.pem"); + let content_path = temp_dir.path().join("content"); + + // Create content directory and file + std::fs::create_dir(&content_path).unwrap(); + std::fs::write(content_path.join("test.gmi"), "# Test Gemini content\n").unwrap(); + + // Generate test certificates + generate_test_certificates(temp_dir.path()); + + // Use a unique port based on process ID to avoid conflicts + let port = 1967 + (std::process::id() % 1000) as u16; + + TestEnvironment { + temp_dir, + config_path, + cert_path, + key_path, + content_path, + port, + } +} + +fn generate_test_certificates(temp_dir: &Path) { + use std::process::Command; + + let cert_path = temp_dir.join("cert.pem"); + let key_path = temp_dir.join("key.pem"); + + let status = Command::new("openssl") + .args(&[ + "req", "-x509", "-newkey", "rsa:2048", + "-keyout", &key_path.to_string_lossy(), + "-out", &cert_path.to_string_lossy(), + "-days", "1", + "-nodes", + "-subj", "/CN=localhost" + ]) + .status() + .unwrap(); + + assert!(status.success(), "Failed to generate test certificates"); +} \ No newline at end of file diff --git a/tests/config_validation.rs b/tests/config_validation.rs index 5fd50bf..2a43719 100644 --- a/tests/config_validation.rs +++ b/tests/config_validation.rs @@ -1,5 +1,6 @@ +mod common; + use std::process::Command; -use std::env; #[test] fn test_missing_config_file() { @@ -17,83 +18,100 @@ fn test_missing_config_file() { #[test] fn test_missing_hostname() { - let temp_dir = env::temp_dir().join(format!("pollux_test_config_{}", std::process::id())); - std::fs::create_dir_all(&temp_dir).unwrap(); - let config_path = temp_dir.join("config.toml"); - std::fs::write(&config_path, r#" - root = "/tmp" - cert = "cert.pem" - key = "key.pem" - bind_host = "0.0.0.0" - "#).unwrap(); + let env = common::setup_test_environment(); + let config_content = format!(r#" + root = "{}" + cert = "{}" + key = "{}" + bind_host = "127.0.0.1" + "#, env.content_path.display(), env.cert_path.display(), env.key_path.display()); + std::fs::write(&env.config_path, config_content).unwrap(); let output = Command::new(env!("CARGO_BIN_EXE_pollux")) .arg("--config") - .arg(&config_path) + .arg(&env.config_path) .output() .unwrap(); assert!(!output.status.success()); let stderr = String::from_utf8(output.stderr).unwrap(); assert!(stderr.contains("'hostname' field is required")); - assert!(stderr.contains("hostname = \"your.domain.com\"")); - - // Cleanup - let _ = std::fs::remove_dir_all(&temp_dir); + assert!(stderr.contains("Add: hostname = \"your.domain.com\"")); } #[test] fn test_nonexistent_root_directory() { - let temp_dir = env::temp_dir().join(format!("pollux_test_config_{}", std::process::id())); - std::fs::create_dir_all(&temp_dir).unwrap(); - let config_path = temp_dir.join("config.toml"); - std::fs::write(&config_path, r#" + let env = common::setup_test_environment(); + let config_content = format!(r#" root = "/definitely/does/not/exist" - cert = "cert.pem" - key = "key.pem" + cert = "{}" + key = "{}" hostname = "example.com" - bind_host = "0.0.0.0" - "#).unwrap(); + bind_host = "127.0.0.1" + "#, env.cert_path.display(), env.key_path.display()); + std::fs::write(&env.config_path, config_content).unwrap(); let output = Command::new(env!("CARGO_BIN_EXE_pollux")) .arg("--config") - .arg(config_path) + .arg(&env.config_path) .output() .unwrap(); - // Cleanup - let _ = std::fs::remove_dir_all(&temp_dir); - assert!(!output.status.success()); let stderr = String::from_utf8(output.stderr).unwrap(); - assert!(stderr.contains("Root directory '/definitely/does/not/exist' does not exist")); - assert!(stderr.contains("Create the directory and add your Gemini files")); + assert!(stderr.contains("Error: Root directory '/definitely/does/not/exist' does not exist")); + assert!(stderr.contains("Create the directory and add your Gemini files (.gmi, .txt, images)")); } #[test] fn test_missing_certificate_file() { - let temp_dir = env::temp_dir().join(format!("pollux_test_config_{}", std::process::id())); - std::fs::create_dir_all(&temp_dir).unwrap(); - let config_path = temp_dir.join("config.toml"); - std::fs::write(&config_path, r#" - root = "/tmp" + let env = common::setup_test_environment(); + let config_content = format!(r#" + root = "{}" cert = "/nonexistent/cert.pem" - key = "key.pem" + key = "{}" hostname = "example.com" - bind_host = "0.0.0.0" - "#).unwrap(); + bind_host = "127.0.0.1" + "#, env.content_path.display(), env.key_path.display()); + std::fs::write(&env.config_path, config_content).unwrap(); let output = Command::new(env!("CARGO_BIN_EXE_pollux")) .arg("--config") - .arg(&config_path) + .arg(&env.config_path) .output() .unwrap(); assert!(!output.status.success()); let stderr = String::from_utf8(output.stderr).unwrap(); - assert!(stderr.contains("Certificate file '/nonexistent/cert.pem' does not exist")); - assert!(stderr.contains("Generate or obtain TLS certificates")); + assert!(stderr.contains("Error: Certificate file '/nonexistent/cert.pem' does not exist")); + assert!(stderr.contains("Generate or obtain TLS certificates for your domain")); +} - // Cleanup - let _ = std::fs::remove_dir_all(&temp_dir); +#[test] +fn test_valid_config_startup() { + let env = common::setup_test_environment(); + let config_content = format!(r#" + root = "{}" + cert = "{}" + key = "{}" + hostname = "localhost" + bind_host = "127.0.0.1" + port = {} + "#, env.content_path.display(), env.cert_path.display(), env.key_path.display(), env.port); + std::fs::write(&env.config_path, config_content).unwrap(); + + let mut server_process = Command::new(env!("CARGO_BIN_EXE_pollux")) + .arg("--config") + .arg(&env.config_path) + .spawn() + .unwrap(); + + // Wait for server to start + std::thread::sleep(std::time::Duration::from_millis(500)); + + // Check server is still running (didn't exit with error) + assert!(server_process.try_wait().unwrap().is_none(), "Server should still be running with valid config"); + + // Kill server + server_process.kill().unwrap(); } \ No newline at end of file diff --git a/tests/gemini_test_client.py b/tests/gemini_test_client.py index b9b3975..351715f 100755 --- a/tests/gemini_test_client.py +++ b/tests/gemini_test_client.py @@ -1,195 +1,71 @@ #!/usr/bin/env python3 """ -Gemini Test Client +Simple Gemini Test Client -A simple Gemini protocol client for testing Gemini servers. -Used by integration tests to validate server behavior. +Makes a single Gemini request and prints the status line. +Used by integration tests for rate limiting validation. -Usage: - python3 tests/gemini_test_client.py --url gemini://example.com/ --timeout 10 +Usage: python3 tests/gemini_test_client.py gemini://host:port/path """ -import argparse +import sys import socket import ssl -import time -import multiprocessing -from concurrent.futures import ProcessPoolExecutor, as_completed -def parse_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser(description='Test Gemini rate limiting with concurrent requests') - parser.add_argument('--limit', type=int, default=3, - help='Number of concurrent requests to send (default: 3)') - parser.add_argument('--host', default='localhost', - help='Server host (default: localhost)') - parser.add_argument('--port', type=int, default=1965, - help='Server port (default: 1965)') - parser.add_argument('--delay', type=float, default=0.1, - help='Delay between request start and connection close (default: 0.1s)') - parser.add_argument('--timeout', type=float, default=5.0, - help='Socket timeout in seconds (default: 5.0)') - parser.add_argument('--url', default='gemini://localhost/big-file.mkv', - help='Gemini URL to request (default: gemini://localhost/big-file.mkv)') - - args = parser.parse_args() - - # Validation - if args.limit < 1: - parser.error("Limit must be at least 1") - if args.limit > 10000: - parser.error("Limit too high (max 10000 for safety)") - if args.delay < 0: - parser.error("Delay must be non-negative") - if args.timeout <= 0: - parser.error("Timeout must be positive") - - return args - -def send_gemini_request(host, port, url, delay, timeout): - """Send one Gemini request with proper error handling""" +def main(): + if len(sys.argv) != 2: + print("Usage: python3 gemini_test_client.py ", file=sys.stderr) + sys.exit(1) + + url = sys.argv[1] + + # Parse URL (basic parsing) + if not url.startswith('gemini://'): + print("Error: URL must start with gemini://", file=sys.stderr) + sys.exit(1) + + url_parts = url[9:].split('/', 1) # Remove gemini:// + host_port = url_parts[0] + path = '/' + url_parts[1] if len(url_parts) > 1 else '/' + + if ':' in host_port: + host, port = host_port.rsplit(':', 1) + port = int(port) + else: + host = host_port + port = 1965 + try: - # Create SSL context + # Create SSL connection context = ssl.create_default_context() context.check_hostname = False context.verify_mode = ssl.CERT_NONE - - # Connect with timeout - sock = socket.create_connection((host, port), timeout=timeout) + + sock = socket.create_connection((host, port), timeout=5.0) ssl_sock = context.wrap_socket(sock, server_hostname=host) - + # Send request - request = f"{url}\r\n".encode('utf-8') - ssl_sock.send(request) - - # Read response with timeout - ssl_sock.settimeout(timeout) - response = ssl_sock.recv(1024) - - if not response: - return "Error: Empty response" - - status = response.decode('utf-8', errors='ignore').split('\r\n')[0] - - # Keep connection alive briefly if requested - if delay > 0: - time.sleep(delay) - + request = f"{url}\r\n" + ssl_sock.send(request.encode('utf-8')) + + # Read response header + response = b'' + while b'\r\n' not in response and len(response) < 1024: + data = ssl_sock.recv(1) + if not data: + break + response += data + ssl_sock.close() - return status - - except socket.timeout: - return "Error: Timeout" - except ConnectionRefusedError: - return "Error: Connection refused" + + if response: + status_line = response.decode('utf-8', errors='ignore').split('\r\n')[0] + print(status_line) + else: + print("Error: No response") + except Exception as e: - return f"Error: {e}" - -def main(): - """Run concurrent requests""" - args = parse_args() - - if args.limit == 1: - print("Testing single request (debug mode)...") - start_time = time.time() - result = send_gemini_request(args.host, args.port, args.url, args.delay, args.timeout) - end_time = time.time() - duration = end_time - start_time - print(f"Result: {result}") - print(".2f") - return - - print(f"Testing rate limiting with {args.limit} concurrent requests (using multiprocessing)...") - print(f"Server: {args.host}:{args.port}") - print(f"URL: {args.url}") - print(f"Delay: {args.delay}s, Timeout: {args.timeout}s") - print() - - start_time = time.time() - - # Use ProcessPoolExecutor for true parallelism (bypasses GIL) - results = [] - max_workers = min(args.limit, multiprocessing.cpu_count() * 4) # Limit workers to avoid system overload - - with ProcessPoolExecutor(max_workers=max_workers) as executor: - futures = [ - executor.submit(send_gemini_request, args.host, args.port, - args.url, args.delay, args.timeout) - for _ in range(args.limit) - ] - - for future in as_completed(futures): - results.append(future.result()) - - elapsed = time.time() - start_time - - # Analyze results - status_counts = {} - connection_refused = 0 - timeouts = 0 - other_errors = [] - - for result in results: - if "Connection refused" in result: - connection_refused += 1 - elif "Timeout" in result: - timeouts += 1 - elif result.startswith("Error"): - other_errors.append(result) - else: - status_counts[result] = status_counts.get(result, 0) + 1 - - # Print results - print("Results:") - for status, count in sorted(status_counts.items()): - print(f" {status}: {count}") - if connection_refused > 0: - print(f" Connection refused: {connection_refused} (server overloaded)") - if timeouts > 0: - print(f" Timeouts: {timeouts} (server unresponsive)") - if other_errors: - print(f" Other errors: {len(other_errors)}") - for error in other_errors[:3]: - print(f" {error}") - if len(other_errors) > 3: - print(f" ... and {len(other_errors) - 3} more") - - print() - print(".2f") - - # Success criteria for rate limiting - success_20 = status_counts.get("20 application/octet-stream", 0) - rate_limited_41 = status_counts.get("41 Server unavailable", 0) - total_successful = success_20 + rate_limited_41 + connection_refused - total_processed = total_successful + timeouts - - print(f"\nAnalysis:") - print(f" Total requests sent: {args.limit}") - print(f" Successfully processed: {total_successful}") - print(f" Timeouts (server unresponsive): {timeouts}") - - if args.limit == 1: - # Single request should succeed - if success_20 == 1 and timeouts == 0: - print("✅ PASS: Single request works correctly") - else: - print("❌ FAIL: Single request failed") - elif rate_limited_41 > 0 and success_20 > 0: - # We have both successful responses and 41 rate limited responses - print("✅ PASS: Rate limiting detected!") - print(f" {success_20} requests succeeded") - print(f" {rate_limited_41} requests rate limited with 41 response") - print(" Mixed results indicate rate limiting is working correctly") - elif success_20 == args.limit and timeouts == 0: - # All requests succeeded - print("⚠️ All requests succeeded - rate limiting may not be triggered") - print(" This could mean:") - print(" - Requests are not truly concurrent") - print(" - Processing is too fast for overlap") - print(" - Need longer delays or more concurrent requests") - else: - print("❓ UNCLEAR: Check server logs and test parameters") - print(" May need to adjust --limit, delays, or server configuration") + print(f"Error: {e}") if __name__ == '__main__': main() \ No newline at end of file diff --git a/tests/rate_limiting.rs b/tests/rate_limiting.rs index c7ff12d..0d58916 100644 --- a/tests/rate_limiting.rs +++ b/tests/rate_limiting.rs @@ -1,37 +1,10 @@ -use std::process::Command; +mod common; -struct TestEnvironment { - temp_dir: std::path::PathBuf, - config_file: std::path::PathBuf, - content_file: std::path::PathBuf, - port: u16, -} +#[test] +fn test_rate_limiting_with_concurrent_requests() { + let env = common::setup_test_environment(); -impl Drop for TestEnvironment { - fn drop(&mut self) { - let _ = std::fs::remove_dir_all(&self.temp_dir); - } -} - -fn setup_test_environment() -> Result> { - use std::env; - - // Create unique temp directory for this test - let temp_dir = env::temp_dir().join(format!("pollux_test_{}", std::process::id())); - std::fs::create_dir_all(&temp_dir)?; - - // Generate test certificates - generate_test_certificates(&temp_dir)?; - - // Create test content file - let content_file = temp_dir.join("test.gmi"); - std::fs::write(&content_file, "# Test Gemini content\n")?; - - // Use a unique port based on process ID to avoid conflicts - let port = 1967 + (std::process::id() % 1000) as u16; - - // Create config file - let config_file = temp_dir.join("config.toml"); + // Create config with rate limiting enabled let config_content = format!(r#" root = "{}" cert = "{}" @@ -40,53 +13,51 @@ fn setup_test_environment() -> Result Result<(), Box> { - use std::process::Command; - - let cert_path = temp_dir.join("cert.pem"); - let key_path = temp_dir.join("key.pem"); - - let status = Command::new("openssl") - .args(&[ - "req", "-x509", "-newkey", "rsa:2048", - "-keyout", &key_path.to_string_lossy(), - "-out", &cert_path.to_string_lossy(), - "-days", "1", - "-nodes", - "-subj", "/CN=localhost" - ]) - .status()?; - - if !status.success() { - return Err("Failed to generate test certificates with openssl".into()); + // Start server binary with test delay to simulate processing time + let mut server_process = std::process::Command::new(env!("CARGO_BIN_EXE_pollux")) + .arg("--config") + .arg(&env.config_path) + .arg("--test-processing-delay") + .arg("1") // 1 second delay per request + .spawn() + .expect("Failed to start server"); + + // Wait for server to start + std::thread::sleep(std::time::Duration::from_millis(500)); + + // Spawn 5 concurrent client processes + let mut handles = vec![]; + for _ in 0..5 { + let url = format!("gemini://localhost:{}/test.gmi", env.port); + let handle = std::thread::spawn(move || { + std::process::Command::new("python3") + .arg("tests/gemini_test_client.py") + .arg(url) + .output() + }); + handles.push(handle); } - - Ok(()) -} -#[test] -fn test_rate_limiting_with_concurrent_requests() { - // For now, skip the complex concurrent testing - // The test infrastructure is in place, but full integration testing - // requires more robust isolation and timing controls - println!("Skipping rate limiting integration test - infrastructure ready for future implementation"); -} + // Collect results + let mut results = vec![]; + for handle in handles { + let output = handle.join().unwrap().unwrap(); + let status = String::from_utf8(output.stdout).unwrap(); + results.push(status.trim().to_string()); + } -fn python_available() -> bool { - std::process::Command::new("python3") - .arg("--version") - .output() - .map(|output| output.status.success()) - .unwrap_or(false) + // Kill server + let _ = server_process.kill(); + + // Analyze results + let success_count = results.iter().filter(|r| r.starts_with("20")).count(); + let rate_limited_count = results.iter().filter(|r| r.starts_with("41")).count(); + + // Validation + assert!(success_count >= 1, "At least 1 request should succeed, got results: {:?}", results); + assert!(rate_limited_count >= 1, "At least 1 request should be rate limited, got results: {:?}", results); + assert_eq!(success_count + rate_limited_count, 5, "All requests should get valid responses, got results: {:?}", results); } \ No newline at end of file