Implement rate limiting with 41 responses and comprehensive logging

- Add concurrent connection handling with tokio::spawn for proper rate limiting
- Send '41 Server unavailable' responses instead of dropping connections
- Move request logger initialization earlier to enable rate limiting logs
- Add logging for rate limited requests: 'Concurrent request limit exceeded'
- Fix clippy warnings: needless borrows and match simplification
- Update test script analysis to expect 41 responses for rate limiting
This commit is contained in:
Jeena 2026-01-16 06:00:18 +00:00
parent da39f37559
commit 33ae576b25
5 changed files with 52 additions and 20 deletions

View file

@@ -28,13 +28,10 @@ impl RequestLogger {
}
fn extract_client_ip(stream: &TlsStream<TcpStream>) -> String {
match stream.get_ref() {
(tcp_stream, _) => {
match tcp_stream.peer_addr() {
Ok(addr) => addr.to_string(),
Err(_) => "unknown".to_string(),
}
}
let (tcp_stream, _) = stream.get_ref();
match tcp_stream.peer_addr() {
Ok(addr) => addr.to_string(),
Err(_) => "unknown".to_string(),
}
}

View file

@@ -53,6 +53,11 @@ struct Args {
/// Hostname for the server
#[arg(short = 'H', long)]
host: Option<String>,
/// TESTING ONLY: Add delay before processing (seconds) [debug builds only]
#[cfg(debug_assertions)]
#[arg(long, value_name = "SECONDS")]
test_processing_delay: Option<u64>,
}
@@ -91,6 +96,16 @@ async fn main() {
std::process::exit(1);
}
// TESTING ONLY: Read delay argument (debug builds only)
#[cfg(debug_assertions)]
let test_processing_delay = args.test_processing_delay
.filter(|&d| d > 0 && d <= 300)
.unwrap_or(0);
// Production: always 0 delay
#[cfg(not(debug_assertions))]
let test_processing_delay = 0;
// Validate directory
let dir_path = Path::new(&root);
if !dir_path.exists() || !dir_path.is_dir() {
@@ -120,10 +135,13 @@ async fn main() {
let dir = root.clone();
let expected_host = "localhost".to_string(); // Override for testing
let max_concurrent = max_concurrent_requests;
if let Ok(stream) = acceptor.accept(stream).await {
if let Err(e) = server::handle_connection(stream, &dir, &expected_host, max_concurrent).await {
tracing::error!("Error handling connection: {}", e);
let test_delay = test_processing_delay;
tokio::spawn(async move {
if let Ok(stream) = acceptor.accept(stream).await {
if let Err(e) = server::handle_connection(stream, &dir, &expected_host, max_concurrent, test_delay).await {
tracing::error!("Error handling connection: {}", e);
}
}
}
});
}
}

View file

@@ -16,8 +16,8 @@ pub async fn serve_file(
file_path: &Path,
) -> io::Result<()> {
if file_path.exists() && file_path.is_file() {
let mime_type = get_mime_type(&file_path);
let content = fs::read(&file_path)?;
let mime_type = get_mime_type(file_path);
let content = fs::read(file_path)?;
let mut response = format!("20 {}\r\n", mime_type).into_bytes();
response.extend(content);
stream.write_all(&response).await?;
@@ -33,6 +33,7 @@ pub async fn handle_connection(
dir: &str,
expected_host: &str,
max_concurrent_requests: usize,
test_processing_delay: u64,
) -> io::Result<()> {
const MAX_REQUEST_SIZE: usize = 4096;
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
@@ -58,9 +59,13 @@ pub async fn handle_connection(
// Read successful, continue processing
let request = String::from_utf8_lossy(&request_buf).trim().to_string();
// Initialize logger early for all request types
let logger = RequestLogger::new(&stream, request.clone());
// Check concurrent request limit after TLS handshake and request read
let current = ACTIVE_REQUESTS.fetch_add(1, Ordering::Relaxed);
if current >= max_concurrent_requests {
logger.log_error(41, "Concurrent request limit exceeded");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
// Rate limited - send proper 41 response
send_response(&mut stream, "41 Server unavailable\r\n").await?;
@@ -70,14 +75,12 @@ pub async fn handle_connection(
// Process the request
// Validate request
if request.is_empty() {
let logger = RequestLogger::new(&stream, request);
logger.log_error(59, "Empty request");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await;
}
if request.len() > 1024 {
let logger = RequestLogger::new(&stream, request);
logger.log_error(59, "Request too large");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await;
@@ -87,15 +90,17 @@ pub async fn handle_connection(
let path = match parse_gemini_url(&request, expected_host) {
Ok(p) => p,
Err(_) => {
let logger = RequestLogger::new(&stream, request);
logger.log_error(59, "Invalid URL format");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await;
}
};
// Initialize logger now that we have the full request URL
let logger = RequestLogger::new(&stream, request);
// TESTING ONLY: Add delay for rate limiting tests (debug builds only)
#[cfg(debug_assertions)]
if test_processing_delay > 0 {
tokio::time::sleep(tokio::time::Duration::from_secs(test_processing_delay)).await;
}
// Resolve file path with security
let file_path = match resolve_file_path(&path, dir) {
@@ -109,6 +114,8 @@ pub async fn handle_connection(
// No delay for normal operation
// Processing complete
// Serve the file
match serve_file(&mut stream, &file_path).await {
Ok(_) => logger.log_success(20),