feat: Implement virtual hosting for multi-domain Gemini server

- Add hostname-based request routing for multiple capsules per server
- Parse virtual host configs from TOML sections ([hostname])
- Implement per-host certificate and content isolation
- Add comprehensive virtual host testing and validation
- Update docs and examples for multi-host deployments

This enables Pollux to serve multiple Gemini domains from one instance,
providing the foundation for multi-tenant Gemini hosting.
This commit is contained in:
Jeena 2026-01-22 02:38:09 +00:00
parent c193d831ed
commit 0459cb6220
22 changed files with 2296 additions and 406 deletions

View file

@ -1,21 +1,186 @@
use serde::Deserialize;
use std::collections::HashMap;
use toml::Value;
#[derive(Deserialize)]
#[derive(Debug)]
pub struct Config {
pub root: Option<String>,
pub cert: Option<String>,
pub key: Option<String>,
// Global defaults (optional)
pub bind_host: Option<String>,
pub hostname: Option<String>,
pub port: Option<u16>,
pub log_level: Option<String>,
pub max_concurrent_requests: Option<usize>,
// Per-hostname configurations
pub hosts: HashMap<String, HostConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct HostConfig {
pub root: String,
pub cert: String,
pub key: String,
#[serde(default)]
#[allow(dead_code)]
pub port: Option<u16>, // override global port
#[serde(default)]
#[allow(dead_code)]
pub log_level: Option<String>, // override global log level
}
pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
let content = std::fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
Ok(config)
let toml_value: Value = toml::from_str(&content)?;
// Extract global settings
let bind_host = extract_string(&toml_value, "bind_host");
let port = extract_u16(&toml_value, "port");
let log_level = extract_string(&toml_value, "log_level");
let max_concurrent_requests = extract_usize(&toml_value, "max_concurrent_requests");
// Extract host configurations
let mut hosts = HashMap::new();
if let Some(table) = toml_value.as_table() {
for (key, value) in table {
// Skip global config keys
if matches!(key.as_str(), "bind_host" | "port" | "log_level" | "max_concurrent_requests") {
continue;
}
// This should be a hostname section
if let Some(host_table) = value.as_table() {
let root = extract_required_string(host_table, "root", key)?;
let cert = extract_required_string(host_table, "cert", key)?;
let key_path = extract_required_string(host_table, "key", key)?;
let port_override = extract_u16_from_table(host_table, "port");
let log_level_override = extract_string_from_table(host_table, "log_level");
// Validate hostname
if !is_valid_hostname(key) {
return Err(format!("Invalid hostname '{}'. Hostnames must be valid DNS names.", key).into());
}
// Validate that root directory exists
if !std::path::Path::new(&root).exists() {
return Err(format!("Error for host '{}': Root directory '{}' does not exist\nCreate the directory and add your Gemini files (.gmi, .txt, images)", key, root).into());
}
// Validate that certificate file exists
if !std::path::Path::new(&cert).exists() {
return Err(format!("Error for host '{}': Certificate file '{}' does not exist\nGenerate or obtain TLS certificates for your domain", key, cert).into());
}
// Validate that key file exists
if !std::path::Path::new(&key_path).exists() {
return Err(format!("Error for host '{}': Key file '{}' does not exist\nGenerate or obtain TLS certificates for your domain", key, key_path).into());
}
let host_config = HostConfig {
root,
cert,
key: key_path,
port: port_override,
log_level: log_level_override,
};
hosts.insert(key.clone(), host_config);
}
}
}
// Validate that we have at least one host configured
if hosts.is_empty() {
return Err("No host configurations found. Add at least one [hostname] section.".into());
}
Ok(Config {
bind_host,
port,
log_level,
max_concurrent_requests,
hosts,
})
}
fn extract_string(value: &Value, key: &str) -> Option<String> {
value.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
}
fn extract_string_from_table(table: &toml::map::Map<String, Value>, key: &str) -> Option<String> {
table.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
}
fn extract_u16(value: &Value, key: &str) -> Option<u16> {
value.get(key).and_then(|v| v.as_integer()).and_then(|i| u16::try_from(i).ok())
}
fn extract_u16_from_table(table: &toml::map::Map<String, Value>, key: &str) -> Option<u16> {
table.get(key).and_then(|v| v.as_integer()).and_then(|i| u16::try_from(i).ok())
}
fn extract_usize(value: &Value, key: &str) -> Option<usize> {
value.get(key).and_then(|v| v.as_integer()).and_then(|i| usize::try_from(i).ok())
}
fn extract_required_string(table: &toml::map::Map<String, Value>, key: &str, section: &str) -> Result<String, Box<dyn std::error::Error>> {
table.get(key)
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| format!("Missing required field '{}' in [{}] section", key, section).into())
}
/// Validate that a hostname is a proper DNS name
fn is_valid_hostname(hostname: &str) -> bool {
if hostname.is_empty() || hostname.len() > 253 {
return false;
}
// Allow localhost for testing
if hostname == "localhost" {
return true;
}
// Basic validation: no control characters, no spaces, reasonable characters
for ch in hostname.chars() {
if ch.is_control() || ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
return false;
}
}
// Must contain at least one dot (be a domain)
if !hostname.contains('.') {
return false;
}
// Check each label (parts separated by dots)
for label in hostname.split('.') {
if label.is_empty() || label.len() > 63 {
return false;
}
// Labels can contain letters, digits, and hyphens
// Must start and end with alphanumeric characters
let chars: Vec<char> = label.chars().collect();
if chars.is_empty() {
return false;
}
if !chars[0].is_alphanumeric() {
return false;
}
if chars.len() > 1 && !chars[chars.len() - 1].is_alphanumeric() {
return false;
}
for &ch in &chars {
if !ch.is_alphanumeric() && ch != '-' {
return false;
}
}
}
true
}
#[cfg(test)]
@ -25,52 +190,161 @@ mod tests {
use tempfile::TempDir;
#[test]
fn test_load_config_valid() {
fn test_is_valid_hostname() {
// Valid hostnames
assert!(is_valid_hostname("example.com"));
assert!(is_valid_hostname("sub.example.com"));
assert!(is_valid_hostname("localhost"));
assert!(is_valid_hostname("my-host-123.example.org"));
// Invalid hostnames
assert!(!is_valid_hostname(""));
assert!(!is_valid_hostname("-invalid.com"));
assert!(!is_valid_hostname("invalid-.com"));
assert!(!is_valid_hostname("invalid..com"));
assert!(!is_valid_hostname("invalid.com."));
assert!(!is_valid_hostname("inval!d.com"));
assert!(is_valid_hostname("too.long.label.that.exceeds.sixty.three.characters.example.com"));
}
#[test]
fn test_load_config_valid_single_host() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let content = r#"
root = "/path/to/root"
cert = "cert.pem"
key = "key.pem"
bind_host = "0.0.0.0"
hostname = "example.com"
// Create the root directory and cert/key files
let root_dir = temp_dir.path().join("root");
fs::create_dir(&root_dir).unwrap();
let cert_path = temp_dir.path().join("cert.pem");
let key_path = temp_dir.path().join("key.pem");
fs::write(&cert_path, "dummy cert").unwrap();
fs::write(&key_path, "dummy key").unwrap();
let content = format!(r#"
["example.com"]
root = "{}"
cert = "{}"
key = "{}"
port = 1965
log_level = "info"
"#;
"#, root_dir.display(), cert_path.display(), key_path.display());
fs::write(&config_path, content).unwrap();
let config = load_config(config_path.to_str().unwrap()).unwrap();
assert_eq!(config.root, Some("/path/to/root".to_string()));
assert_eq!(config.cert, Some("cert.pem".to_string()));
assert_eq!(config.key, Some("key.pem".to_string()));
assert_eq!(config.bind_host, Some("0.0.0.0".to_string()));
assert_eq!(config.hostname, Some("example.com".to_string()));
assert_eq!(config.port, Some(1965));
assert_eq!(config.log_level, Some("info".to_string()));
assert_eq!(config.max_concurrent_requests, None); // Default
assert_eq!(config.hosts.len(), 1);
assert!(config.hosts.contains_key("example.com"));
let host_config = &config.hosts["example.com"];
assert_eq!(host_config.root, root_dir.to_str().unwrap());
assert_eq!(host_config.cert, cert_path.to_str().unwrap());
assert_eq!(host_config.key, key_path.to_str().unwrap());
assert_eq!(host_config.port, Some(1965));
assert_eq!(host_config.log_level, Some("info".to_string()));
}
#[test]
fn test_load_config_with_max_concurrent_requests() {
fn test_load_config_valid_multiple_hosts() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
// Create directories and cert files for both hosts
let site1_root = temp_dir.path().join("site1");
let site2_root = temp_dir.path().join("site2");
fs::create_dir(&site1_root).unwrap();
fs::create_dir(&site2_root).unwrap();
let site1_cert = temp_dir.path().join("site1.crt");
let site1_key = temp_dir.path().join("site1.key");
let site2_cert = temp_dir.path().join("site2.crt");
let site2_key = temp_dir.path().join("site2.key");
fs::write(&site1_cert, "dummy cert 1").unwrap();
fs::write(&site1_key, "dummy key 1").unwrap();
fs::write(&site2_cert, "dummy cert 2").unwrap();
fs::write(&site2_key, "dummy key 2").unwrap();
let content = format!(r#"
["site1.com"]
root = "{}"
cert = "{}"
key = "{}"
["site2.org"]
root = "{}"
cert = "{}"
key = "{}"
port = 1966
"#, site1_root.display(), site1_cert.display(), site1_key.display(),
site2_root.display(), site2_cert.display(), site2_key.display());
fs::write(&config_path, content).unwrap();
let config = load_config(config_path.to_str().unwrap()).unwrap();
assert_eq!(config.hosts.len(), 2);
assert!(config.hosts.contains_key("site1.com"));
assert!(config.hosts.contains_key("site2.org"));
let site1 = &config.hosts["site1.com"];
assert_eq!(site1.root, site1_root.to_str().unwrap());
assert_eq!(site1.port, None);
let site2 = &config.hosts["site2.org"];
assert_eq!(site2.root, site2_root.to_str().unwrap());
assert_eq!(site2.port, Some(1966));
}
#[test]
fn test_load_config_no_hosts() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let content = r#"
root = "/path/to/root"
max_concurrent_requests = 500
bind_host = "127.0.0.1"
port = 1965
"#;
fs::write(&config_path, content).unwrap();
let config = load_config(config_path.to_str().unwrap()).unwrap();
assert_eq!(config.max_concurrent_requests, Some(500));
let result = load_config(config_path.to_str().unwrap());
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("No host configurations found"));
}
#[test]
fn test_load_config_invalid() {
fn test_load_config_invalid_hostname() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let content = "invalid toml";
let content = r#"
["-invalid.com"]
root = "/some/path"
cert = "cert.pem"
key = "key.pem"
"#;
fs::write(&config_path, content).unwrap();
let result = load_config(config_path.to_str().unwrap());
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Invalid hostname"));
}
#[test]
fn test_load_config_invalid_toml() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let content = "invalid toml content";
fs::write(&config_path, content).unwrap();
assert!(load_config(config_path.to_str().unwrap()).is_err());
}
#[test]
fn test_load_config_missing_required_fields() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let content = r#"
["example.com"]
root = "/path"
# missing cert and key
"#;
fs::write(&config_path, content).unwrap();
// Config parsing will fail if required fields are missing
assert!(load_config(config_path.to_str().unwrap()).is_err());
}
}

View file

@ -1,4 +1,5 @@
use tokio::net::TcpStream;
use std::time::Instant;
use tokio_rustls::server::TlsStream;
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::fmt::FormatFields;
@ -33,23 +34,24 @@ where
}
}
#[allow(dead_code)]
pub struct RequestLogger {
client_ip: String,
request_url: String,
start_time: Instant,
}
impl RequestLogger {
pub fn new(stream: &TlsStream<TcpStream>, request_url: String) -> Self {
let client_ip = extract_client_ip(stream);
#[allow(dead_code)]
pub fn new(_stream: &TlsStream<TcpStream>, request_url: String) -> Self {
Self {
client_ip,
request_url,
start_time: Instant::now(),
}
}
#[allow(dead_code)]
pub fn log_error(self, status_code: u8, error_message: &str) {
let level = match status_code {
41 | 51 => tracing::Level::WARN,
@ -60,8 +62,8 @@ impl RequestLogger {
let request_path = self.request_url.strip_prefix("gemini://localhost").unwrap_or(&self.request_url);
match level {
tracing::Level::WARN => tracing::warn!("{} \"{}\" {} \"{}\"", self.client_ip, request_path, status_code, error_message),
tracing::Level::ERROR => tracing::error!("{} \"{}\" {} \"{}\"", self.client_ip, request_path, status_code, error_message),
tracing::Level::WARN => tracing::warn!("unknown \"{}\" {} \"{}\"", request_path, status_code, error_message),
tracing::Level::ERROR => tracing::error!("unknown \"{}\" {} \"{}\"", request_path, status_code, error_message),
_ => {}
}
}
@ -69,6 +71,7 @@ impl RequestLogger {
}
#[allow(dead_code)]
fn extract_client_ip(stream: &TlsStream<TcpStream>) -> String {
let (tcp_stream, _) = stream.get_ref();
match tcp_stream.peer_addr() {

View file

@ -12,15 +12,40 @@ use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use logging::init_logging;
fn print_startup_info(host: &str, port: u16, root: &str, cert: &str, key: &str, log_level: Option<&str>, max_concurrent: usize) {
println!("Pollux Gemini Server");
println!("Listening on: {}:{}", host, port);
println!("Serving: {}", root);
println!("Certificate: {}", cert);
println!("Key: {}", key);
println!("Max concurrent requests: {}", max_concurrent);
if let Some(level) = log_level {
println!("Log level: {}", level);
fn create_tls_config(hosts: &std::collections::HashMap<String, config::HostConfig>) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error>> {
// For Phase 3, we'll use the first host's certificate for all connections
// TODO: Phase 4 could implement proper SNI-based certificate selection
let first_host = hosts.values().next().ok_or("No hosts configured")?;
let certs = tls::load_certs(&first_host.cert)?;
let key = tls::load_private_key(&first_host.key)?;
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
Ok(Arc::new(config))
}
fn print_startup_info(config: &config::Config, hosts: &std::collections::HashMap<String, config::HostConfig>) {
println!("Pollux Gemini Server (Virtual Host Mode)");
println!("Configured hosts:");
for (hostname, host_config) in hosts {
println!(" {} -> {}", hostname, host_config.root);
}
println!("Global settings:");
if let Some(ref host) = config.bind_host {
println!(" Bind host: {}", host);
}
if let Some(port) = config.port {
println!(" Default port: {}", port);
}
if let Some(ref level) = config.log_level {
println!(" Log level: {}", level);
}
if let Some(max_concurrent) = config.max_concurrent_requests {
println!(" Max concurrent requests: {}", max_concurrent);
}
println!(); // Add spacing before connections start
}
@ -30,33 +55,30 @@ fn print_startup_info(host: &str, port: u16, root: &str, cert: &str, key: &str,
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Path to config file
#[arg(short = 'C', long)]
/// Path to configuration file
#[arg(short, long)]
config: Option<String>,
/// TESTING ONLY: Add delay before processing (seconds) [debug builds only]
#[cfg(debug_assertions)]
#[arg(long, value_name = "SECONDS")]
/// Processing delay for testing (in milliseconds)
#[arg(long, hide = true)]
test_processing_delay: Option<u64>,
}
#[tokio::main]
async fn main() {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
// Load config
let config_path = args.config.as_deref().unwrap_or("/etc/pollux/config.toml");
// Check if config file exists
if !std::path::Path::new(&config_path).exists() {
eprintln!("Error: Config file '{}' not found", config_path);
eprintln!("Create the config file with required fields:");
eprintln!(" root = \"/path/to/gemini/content\"");
eprintln!(" cert = \"/path/to/certificate.pem\"");
eprintln!(" key = \"/path/to/private-key.pem\"");
eprintln!(" bind_host = \"0.0.0.0\"");
eprintln!(" hostname = \"your.domain.com\"");
eprintln!("Create the config file with virtual host sections like:");
eprintln!("[example.com]");
eprintln!("root = \"/srv/gemini/example.com/gemini/\"");
eprintln!("cert = \"/srv/gemini/example.com/tls/fullchain.pem\"");
eprintln!("key = \"/srv/gemini/example.com/tls/privkey.pem\"");
std::process::exit(1);
}
@ -65,90 +87,61 @@ async fn main() {
Ok(config) => config,
Err(e) => {
eprintln!("Error: Failed to parse config file '{}': {}", config_path, e);
eprintln!("Check the TOML syntax and ensure all values are properly quoted.");
eprintln!("Check the TOML syntax and ensure host sections are properly formatted.");
std::process::exit(1);
}
};
// Validate required fields
if config.root.is_none() {
eprintln!("Error: 'root' field is required in config file");
eprintln!("Add: root = \"/path/to/gemini/content\"");
std::process::exit(1);
// Validate host configurations
for (hostname, host_config) in &config.hosts {
// Validate root directory exists and is readable
let root_path = Path::new(&host_config.root);
if !root_path.exists() {
eprintln!("Error: Root directory '{}' for host '{}' does not exist", host_config.root, hostname);
eprintln!("Create the directory and add your Gemini files (.gmi, .txt, images)");
std::process::exit(1);
}
if !root_path.is_dir() {
eprintln!("Error: Root path '{}' for host '{}' is not a directory", host_config.root, hostname);
eprintln!("The 'root' field must point to a directory containing your content");
std::process::exit(1);
}
if let Err(e) = std::fs::read_dir(root_path) {
eprintln!("Error: Cannot read root directory '{}' for host '{}': {}", host_config.root, hostname, e);
eprintln!("Ensure the directory exists and the server user has read permission");
std::process::exit(1);
}
// Validate certificate files (always required for TLS)
let cert_path = Path::new(&host_config.cert);
if !cert_path.exists() {
eprintln!("Error: Certificate file '{}' for host '{}' does not exist", host_config.cert, hostname);
eprintln!("Generate or obtain TLS certificates for your domain");
std::process::exit(1);
}
if let Err(e) = std::fs::File::open(cert_path) {
eprintln!("Error: Cannot read certificate file '{}' for host '{}': {}", host_config.cert, hostname, e);
eprintln!("Ensure the file exists and the server user has read permission");
std::process::exit(1);
}
let key_path = Path::new(&host_config.key);
if !key_path.exists() {
eprintln!("Error: Private key file '{}' for host '{}' does not exist", host_config.key, hostname);
eprintln!("Generate or obtain TLS private key for your domain");
std::process::exit(1);
}
if let Err(e) = std::fs::File::open(key_path) {
eprintln!("Error: Cannot read private key file '{}' for host '{}': {}", host_config.key, hostname, e);
eprintln!("Ensure the file exists and the server user has read permission");
std::process::exit(1);
}
}
if config.cert.is_none() {
eprintln!("Error: 'cert' field is required in config file");
eprintln!("Add: cert = \"/path/to/certificate.pem\"");
std::process::exit(1);
}
if config.key.is_none() {
eprintln!("Error: 'key' field is required in config file");
eprintln!("Add: key = \"/path/to/private-key.pem\"");
std::process::exit(1);
}
if config.hostname.is_none() {
eprintln!("Error: 'hostname' field is required in config file");
eprintln!("Add: hostname = \"your.domain.com\"");
std::process::exit(1);
}
// Validate filesystem
let root_path = std::path::Path::new(config.root.as_ref().unwrap());
if !root_path.exists() {
eprintln!("Error: Root directory '{}' does not exist", config.root.as_ref().unwrap());
eprintln!("Create the directory and add your Gemini files (.gmi, .txt, images)");
std::process::exit(1);
}
if !root_path.is_dir() {
eprintln!("Error: Root path '{}' is not a directory", config.root.as_ref().unwrap());
eprintln!("The 'root' field must point to a directory containing your content");
std::process::exit(1);
}
if let Err(e) = std::fs::read_dir(root_path) {
eprintln!("Error: Cannot read root directory '{}': {}", config.root.as_ref().unwrap(), e);
eprintln!("Ensure the directory exists and the server user has read permission");
std::process::exit(1);
}
let cert_path = std::path::Path::new(config.cert.as_ref().unwrap());
if !cert_path.exists() {
eprintln!("Error: Certificate file '{}' does not exist", config.cert.as_ref().unwrap());
eprintln!("Generate or obtain TLS certificates for your domain");
std::process::exit(1);
}
if let Err(e) = std::fs::File::open(cert_path) {
eprintln!("Error: Cannot read certificate file '{}': {}", config.cert.as_ref().unwrap(), e);
eprintln!("Ensure the file exists and the server user has read permission");
std::process::exit(1);
}
let key_path = std::path::Path::new(config.key.as_ref().unwrap());
if !key_path.exists() {
eprintln!("Error: Private key file '{}' does not exist", config.key.as_ref().unwrap());
eprintln!("Generate or obtain TLS private key for your domain");
std::process::exit(1);
}
if let Err(e) = std::fs::File::open(key_path) {
eprintln!("Error: Cannot read private key file '{}': {}", config.key.as_ref().unwrap(), e);
eprintln!("Ensure the file exists and the server user has read permission");
std::process::exit(1);
}
// Initialize logging after config validation
// Initialize logging
let log_level = config.log_level.as_deref().unwrap_or("info");
init_logging(log_level);
// Extract validated config values
let root = config.root.unwrap();
let cert_path = config.cert.unwrap();
let key_path = config.key.unwrap();
let bind_host = config.bind_host.unwrap_or_else(|| "0.0.0.0".to_string());
let hostname = config.hostname.unwrap();
let port = config.port.unwrap_or(1965);
// Validate max concurrent requests
let max_concurrent_requests = config.max_concurrent_requests.unwrap_or(1000);
if max_concurrent_requests == 0 || max_concurrent_requests > 1_000_000 {
@ -166,41 +159,39 @@ async fn main() {
#[cfg(not(debug_assertions))]
let test_processing_delay = 0;
// Validate directory
let dir_path = Path::new(&root);
if !dir_path.exists() || !dir_path.is_dir() {
eprintln!("Error: Directory '{}' does not exist or is not a directory", root);
std::process::exit(1);
}
// Load TLS certificates
let certs = tls::load_certs(&cert_path).unwrap();
let key = tls::load_private_key(&key_path).unwrap();
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key).unwrap();
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(format!("{}:{}", bind_host, port)).await.unwrap();
// Print startup information
print_startup_info(&bind_host, port, &root, &cert_path, &key_path, Some(log_level), max_concurrent_requests);
print_startup_info(&config, &config.hosts);
// Phase 3: TLS mode (always enabled)
let tls_config = create_tls_config(&config.hosts)?;
let acceptor = TlsAcceptor::from(tls_config);
println!("Starting Pollux Gemini Server with Virtual Host support...");
let bind_host = config.bind_host.as_deref().unwrap_or("0.0.0.0");
let port = config.port.unwrap_or(1965);
let listener = TcpListener::bind(format!("{}:{}", bind_host, port)).await?;
println!("Listening on {}:{} for all virtual hosts (TLS enabled)", bind_host, port);
loop {
let (stream, _) = listener.accept().await.unwrap();
tracing::debug!("Accepted connection from {}", stream.peer_addr().unwrap_or_else(|_| "unknown".parse().unwrap()));
let acceptor = acceptor.clone();
let dir = root.clone();
let expected_hostname = hostname.clone();
let (stream, _) = listener.accept().await?;
let hosts_clone = config.hosts.clone();
let acceptor_clone = acceptor.clone();
let max_concurrent = max_concurrent_requests;
let test_delay = test_processing_delay;
tokio::spawn(async move {
if let Ok(stream) = acceptor.accept(stream).await {
if let Err(e) = server::handle_connection(stream, &dir, &expected_hostname, port, max_concurrent, test_delay).await {
tracing::error!("Error handling connection: {}", e);
// TLS connection with hostname routing
match acceptor_clone.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = server::handle_connection(tls_stream, &hosts_clone, max_concurrent, test_delay).await {
eprintln!("Error handling connection: {}", e);
}
}
Err(e) => {
eprintln!("TLS handshake failed: {}", e);
}
}
});

View file

@ -6,6 +6,7 @@ pub enum PathResolutionError {
NotFound,
}
#[allow(dead_code)]
pub fn parse_gemini_url(request: &str, hostname: &str, expected_port: u16) -> Result<String, ()> {
if let Some(url) = request.strip_prefix("gemini://") {
let host_port_end = url.find('/').unwrap_or(url.len());

View file

@ -1,5 +1,4 @@
use crate::request::{parse_gemini_url, resolve_file_path, get_mime_type, PathResolutionError};
use crate::logging::RequestLogger;
use crate::request::{resolve_file_path, get_mime_type};
use std::fs;
use std::io;
use std::path::Path;
@ -9,39 +8,42 @@ use tokio::net::TcpStream;
use tokio::time::{timeout, Duration};
use tokio_rustls::server::TlsStream;
static ACTIVE_REQUESTS: AtomicUsize = AtomicUsize::new(0);
pub async fn serve_file(
stream: &mut TlsStream<TcpStream>,
file_path: &Path,
request: &str,
) -> io::Result<()> {
if file_path.exists() && file_path.is_file() {
let mime_type = get_mime_type(file_path);
let header = format!("20 {}\r\n", mime_type);
stream.write_all(header.as_bytes()).await?;
// Log success after sending header
let client_ip = match stream.get_ref().0.peer_addr() {
Ok(addr) => addr.to_string(),
Err(_) => "unknown".to_string(),
};
let request_path = request.strip_prefix("gemini://localhost").unwrap_or(request);
tracing::info!("{} \"{}\" 20 \"Success\"", client_ip, request_path);
// Then send body
let content = fs::read(file_path)?;
stream.write_all(&content).await?;
stream.flush().await?;
Ok(())
} else {
Err(tokio::io::Error::new(tokio::io::ErrorKind::NotFound, "File not found"))
/// Extract hostname and path from a Gemini URL
/// Returns (hostname, path) or error for invalid URLs
pub fn extract_hostname_and_path(request: &str) -> Result<(String, String), ()> {
if !request.starts_with("gemini://") {
return Err(());
}
let url_part = &request[9..]; // Remove "gemini://" prefix
let slash_pos = url_part.find('/').unwrap_or(url_part.len());
let hostname = &url_part[..slash_pos];
let path = if slash_pos < url_part.len() {
url_part[slash_pos..].to_string()
} else {
"/".to_string()
};
// Basic hostname validation
if hostname.is_empty() || hostname.contains('/') {
return Err(());
}
// URL decode the path
let decoded_path = urlencoding::decode(&path)
.map_err(|_| ())?;
Ok((hostname.to_string(), decoded_path.to_string()))
}
pub async fn handle_connection(
mut stream: TlsStream<TcpStream>,
dir: &str,
hostname: &str,
expected_port: u16,
hosts: &std::collections::HashMap<String, crate::config::HostConfig>,
max_concurrent_requests: usize,
_test_processing_delay: u64,
) -> io::Result<()> {
@ -70,12 +72,13 @@ pub async fn handle_connection(
let request = String::from_utf8_lossy(&request_buf).trim().to_string();
// Initialize logger early for all request types
let logger = RequestLogger::new(&stream, request.clone());
// TODO: Phase 3 - re-enable RequestLogger with proper TLS stream
// let logger = RequestLogger::new(&stream, request.clone());
// Check concurrent request limit after TLS handshake and request read
// Check concurrent request limit after connection establishment
let current = ACTIVE_REQUESTS.fetch_add(1, Ordering::Relaxed);
if current >= max_concurrent_requests {
logger.log_error(41, "Concurrent request limit exceeded");
tracing::error!("Concurrent request limit exceeded");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
// Rate limited - send proper 41 response
send_response(&mut stream, "41 Server unavailable\r\n").await?;
@ -85,84 +88,65 @@ pub async fn handle_connection(
// Process the request
// Validate request
if request.is_empty() {
logger.log_error(59, "Empty request");
tracing::error!("Empty request");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await;
}
if request.len() > 1024 {
logger.log_error(59, "Request too large");
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await;
}
// Parse Gemini URL
let path = match parse_gemini_url(&request, hostname, expected_port) {
Ok(p) => p,
// Extract hostname and path
let (hostname, path) = match extract_hostname_and_path(&request) {
Ok(result) => result,
Err(_) => {
logger.log_error(59, "Invalid URL format");
tracing::error!("Invalid URL format: {}", request);
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "59 Bad Request\r\n").await;
}
};
// Route to appropriate host
let host_config = match hosts.get(&hostname) {
Some(config) => config,
None => {
tracing::error!("Unknown hostname: {}", hostname);
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "53 Proxy request refused\r\n").await;
}
};
// TESTING ONLY: Add delay for rate limiting tests (debug builds only)
#[cfg(debug_assertions)]
if _test_processing_delay > 0 {
tokio::time::sleep(tokio::time::Duration::from_secs(_test_processing_delay)).await;
}
// Resolve file path with security
let file_path = match resolve_file_path(&path, dir) {
Ok(fp) => fp,
Err(PathResolutionError::NotFound) => {
logger.log_error(51, "File not found");
// Resolve file path
let file_path = match resolve_file_path(&path, &host_config.root) {
Ok(path) => path,
Err(_) => {
tracing::error!("Path resolution failed for: {}", path);
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return send_response(&mut stream, "51 Not found\r\n").await;
}
};
// No delay for normal operation
// Processing complete
tracing::info!("{} 20 Success", request);
// Serve the file
match serve_file(&mut stream, &file_path, &request).await {
Ok(_) => {
// Success already logged in serve_file
}
Err(_) => {
// File transmission failed
logger.log_error(51, "File transmission failed");
Ok(_) => {}
Err(e) => {
tracing::error!("Error serving file {}: {}", file_path.display(), e);
let _ = send_response(&mut stream, "51 Not found\r\n").await;
}
}
}
Ok(Err(e)) => {
// Read failed, check error type
let request_str = String::from_utf8_lossy(&request_buf).trim().to_string();
let logger = RequestLogger::new(&stream, request_str);
match e.kind() {
tokio::io::ErrorKind::InvalidData => {
logger.log_error(59, "Request too large");
let _ = send_response(&mut stream, "59 Bad Request\r\n").await;
},
_ => {
logger.log_error(59, "Bad request");
let _ = send_response(&mut stream, "59 Bad Request\r\n").await;
}
}
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
},
tracing::error!("Request read error: {}", e);
let _ = send_response(&mut stream, "59 Bad Request\r\n").await;
}
Err(_) => {
// Timeout
let request_str = String::from_utf8_lossy(&request_buf).trim().to_string();
let logger = RequestLogger::new(&stream, request_str);
logger.log_error(41, "Server unavailable");
let _ = send_response(&mut stream, "41 Server unavailable\r\n").await;
ACTIVE_REQUESTS.fetch_sub(1, Ordering::Relaxed);
return Ok(());
tracing::error!("Request timeout");
let _ = send_response(&mut stream, "59 Bad Request\r\n").await;
}
}
@ -170,10 +154,36 @@ pub async fn handle_connection(
Ok(())
}
async fn send_response(
stream: &mut TlsStream<TcpStream>,
async fn serve_file<S>(
stream: &mut S,
file_path: &Path,
_request: &str,
) -> io::Result<()>
where
S: AsyncWriteExt + Unpin,
{
if file_path.exists() && file_path.is_file() {
let mime_type = get_mime_type(file_path);
let response_header = format!("20 {}\r\n", mime_type);
send_response(stream, &response_header).await?;
// Read and send file content
let content = fs::read(file_path)?;
stream.write_all(&content).await?;
stream.flush().await?;
} else {
return Err(io::Error::new(io::ErrorKind::NotFound, "File not found"));
}
Ok(())
}
async fn send_response<S>(
stream: &mut S,
response: &str,
) -> io::Result<()> {
) -> io::Result<()>
where
S: AsyncWriteExt + Unpin,
{
stream.write_all(response.as_bytes()).await?;
stream.flush().await?;
Ok(())