Replace custom logging with tracing crate and RUST_LOG env var
- Remove custom logging module and init_logging function - Update main.rs to use tracing_subscriber with EnvFilter - Remove log_level from global config structure - Update documentation and tests to use RUST_LOG - Format long lines in config.rs and test files for better readability
This commit is contained in:
parent
50a4d9bc75
commit
55fe47b172
15 changed files with 787 additions and 459 deletions
|
|
@ -7,7 +7,6 @@ pub struct Config {
|
|||
// Global defaults (optional)
|
||||
pub bind_host: Option<String>,
|
||||
pub port: Option<u16>,
|
||||
pub log_level: Option<String>,
|
||||
pub max_concurrent_requests: Option<usize>,
|
||||
|
||||
// Per-hostname configurations
|
||||
|
|
@ -21,7 +20,7 @@ pub struct HostConfig {
|
|||
pub key: String,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub port: Option<u16>, // override global port
|
||||
pub port: Option<u16>, // override global port
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub log_level: Option<String>, // override global log level
|
||||
|
|
@ -34,7 +33,6 @@ pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
|
|||
// Extract global settings
|
||||
let bind_host = extract_string(&toml_value, "bind_host");
|
||||
let port = extract_u16(&toml_value, "port");
|
||||
let log_level = extract_string(&toml_value, "log_level");
|
||||
let max_concurrent_requests = extract_usize(&toml_value, "max_concurrent_requests");
|
||||
|
||||
// Extract host configurations
|
||||
|
|
@ -43,7 +41,10 @@ pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
|
|||
if let Some(table) = toml_value.as_table() {
|
||||
for (key, value) in table {
|
||||
// Skip global config keys
|
||||
if matches!(key.as_str(), "bind_host" | "port" | "log_level" | "max_concurrent_requests") {
|
||||
if matches!(
|
||||
key.as_str(),
|
||||
"bind_host" | "port" | "max_concurrent_requests"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -57,7 +58,11 @@ pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
|
|||
|
||||
// Validate hostname
|
||||
if !is_valid_hostname(key) {
|
||||
return Err(format!("Invalid hostname '{}'. Hostnames must be valid DNS names.", key).into());
|
||||
return Err(format!(
|
||||
"Invalid hostname '{}'. Hostnames must be valid DNS names.",
|
||||
key
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Validate that root directory exists
|
||||
|
|
@ -96,34 +101,53 @@ pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
|
|||
Ok(Config {
|
||||
bind_host,
|
||||
port,
|
||||
log_level,
|
||||
max_concurrent_requests,
|
||||
hosts,
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_string(value: &Value, key: &str) -> Option<String> {
|
||||
value.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
|
||||
value
|
||||
.get(key)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
fn extract_string_from_table(table: &toml::map::Map<String, Value>, key: &str) -> Option<String> {
|
||||
table.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
|
||||
table
|
||||
.get(key)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
fn extract_u16(value: &Value, key: &str) -> Option<u16> {
|
||||
value.get(key).and_then(|v| v.as_integer()).and_then(|i| u16::try_from(i).ok())
|
||||
value
|
||||
.get(key)
|
||||
.and_then(|v| v.as_integer())
|
||||
.and_then(|i| u16::try_from(i).ok())
|
||||
}
|
||||
|
||||
fn extract_u16_from_table(table: &toml::map::Map<String, Value>, key: &str) -> Option<u16> {
|
||||
table.get(key).and_then(|v| v.as_integer()).and_then(|i| u16::try_from(i).ok())
|
||||
table
|
||||
.get(key)
|
||||
.and_then(|v| v.as_integer())
|
||||
.and_then(|i| u16::try_from(i).ok())
|
||||
}
|
||||
|
||||
fn extract_usize(value: &Value, key: &str) -> Option<usize> {
|
||||
value.get(key).and_then(|v| v.as_integer()).and_then(|i| usize::try_from(i).ok())
|
||||
value
|
||||
.get(key)
|
||||
.and_then(|v| v.as_integer())
|
||||
.and_then(|i| usize::try_from(i).ok())
|
||||
}
|
||||
|
||||
fn extract_required_string(table: &toml::map::Map<String, Value>, key: &str, section: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
table.get(key)
|
||||
fn extract_required_string(
|
||||
table: &toml::map::Map<String, Value>,
|
||||
key: &str,
|
||||
section: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
table
|
||||
.get(key)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| format!("Missing required field '{}' in [{}] section", key, section).into())
|
||||
|
|
@ -204,7 +228,9 @@ mod tests {
|
|||
assert!(!is_valid_hostname("invalid..com"));
|
||||
assert!(!is_valid_hostname("invalid.com."));
|
||||
assert!(!is_valid_hostname("inval!d.com"));
|
||||
assert!(is_valid_hostname("too.long.label.that.exceeds.sixty.three.characters.example.com"));
|
||||
assert!(is_valid_hostname(
|
||||
"too.long.label.that.exceeds.sixty.three.characters.example.com"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -220,14 +246,19 @@ mod tests {
|
|||
fs::write(&cert_path, "dummy cert").unwrap();
|
||||
fs::write(&key_path, "dummy key").unwrap();
|
||||
|
||||
let content = format!(r#"
|
||||
let content = format!(
|
||||
r#"
|
||||
["example.com"]
|
||||
root = "{}"
|
||||
cert = "{}"
|
||||
key = "{}"
|
||||
port = 1965
|
||||
log_level = "info"
|
||||
"#, root_dir.display(), cert_path.display(), key_path.display());
|
||||
"#,
|
||||
root_dir.display(),
|
||||
cert_path.display(),
|
||||
key_path.display()
|
||||
);
|
||||
fs::write(&config_path, content).unwrap();
|
||||
|
||||
let config = load_config(config_path.to_str().unwrap()).unwrap();
|
||||
|
|
@ -262,7 +293,8 @@ mod tests {
|
|||
fs::write(&site2_cert, "dummy cert 2").unwrap();
|
||||
fs::write(&site2_key, "dummy key 2").unwrap();
|
||||
|
||||
let content = format!(r#"
|
||||
let content = format!(
|
||||
r#"
|
||||
["site1.com"]
|
||||
root = "{}"
|
||||
cert = "{}"
|
||||
|
|
@ -273,8 +305,14 @@ mod tests {
|
|||
cert = "{}"
|
||||
key = "{}"
|
||||
port = 1966
|
||||
"#, site1_root.display(), site1_cert.display(), site1_key.display(),
|
||||
site2_root.display(), site2_cert.display(), site2_key.display());
|
||||
"#,
|
||||
site1_root.display(),
|
||||
site1_cert.display(),
|
||||
site1_key.display(),
|
||||
site2_root.display(),
|
||||
site2_cert.display(),
|
||||
site2_key.display()
|
||||
);
|
||||
fs::write(&config_path, content).unwrap();
|
||||
|
||||
let config = load_config(config_path.to_str().unwrap()).unwrap();
|
||||
|
|
@ -303,7 +341,10 @@ mod tests {
|
|||
|
||||
let result = load_config(config_path.to_str().unwrap());
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("No host configurations found"));
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("No host configurations found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -347,4 +388,4 @@ mod tests {
|
|||
// Config parsing will fail if required fields are missing
|
||||
assert!(load_config(config_path.to_str().unwrap()).is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
107
src/logging.rs
107
src/logging.rs
|
|
@ -1,106 +1,5 @@
|
|||
use tokio::net::TcpStream;
|
||||
use std::time::Instant;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tracing_subscriber::fmt::format::Writer;
|
||||
use tracing_subscriber::fmt::FormatFields;
|
||||
|
||||
struct CleanLogFormatter;
|
||||
|
||||
impl<S, N> tracing_subscriber::fmt::FormatEvent<S, N> for CleanLogFormatter
|
||||
where
|
||||
S: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
|
||||
N: for<'a> tracing_subscriber::fmt::FormatFields<'a> + 'static,
|
||||
{
|
||||
fn format_event(
|
||||
&self,
|
||||
ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
|
||||
mut writer: Writer<'_>,
|
||||
event: &tracing::Event<'_>,
|
||||
) -> std::fmt::Result {
|
||||
// Write timestamp
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
write!(writer, "{}-{:02}-{:02}T{:02}:{:02}:{:02} ",
|
||||
now.year(), now.month() as u8, now.day(),
|
||||
now.hour(), now.minute(), now.second())?;
|
||||
|
||||
// Write level
|
||||
let level = event.metadata().level();
|
||||
write!(writer, "{} ", level)?;
|
||||
|
||||
// Write the message
|
||||
ctx.format_fields(writer.by_ref(), event)?;
|
||||
|
||||
writeln!(writer)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct RequestLogger {
|
||||
request_url: String,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
impl RequestLogger {
|
||||
#[allow(dead_code)]
|
||||
pub fn new(_stream: &TlsStream<TcpStream>, request_url: String) -> Self {
|
||||
Self {
|
||||
request_url,
|
||||
start_time: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn log_error(self, status_code: u8, error_message: &str) {
|
||||
let level = match status_code {
|
||||
41 | 51 => tracing::Level::WARN,
|
||||
59 => tracing::Level::ERROR,
|
||||
_ => tracing::Level::ERROR,
|
||||
};
|
||||
|
||||
let request_path = self.request_url.strip_prefix("gemini://localhost").unwrap_or(&self.request_url);
|
||||
|
||||
match level {
|
||||
tracing::Level::WARN => tracing::warn!("unknown \"{}\" {} \"{}\"", request_path, status_code, error_message),
|
||||
tracing::Level::ERROR => tracing::error!("unknown \"{}\" {} \"{}\"", request_path, status_code, error_message),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn extract_client_ip(stream: &TlsStream<TcpStream>) -> String {
|
||||
let (tcp_stream, _) = stream.get_ref();
|
||||
match tcp_stream.peer_addr() {
|
||||
Ok(addr) => addr.to_string(),
|
||||
Err(_) => "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_logging(level: &str) {
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
let level = match level.to_lowercase().as_str() {
|
||||
"error" => tracing::Level::ERROR,
|
||||
"warn" => tracing::Level::WARN,
|
||||
"info" => tracing::Level::INFO,
|
||||
"debug" => tracing::Level::DEBUG,
|
||||
"trace" => tracing::Level::TRACE,
|
||||
_ => {
|
||||
eprintln!("Warning: Invalid log level '{}', defaulting to 'info'", level);
|
||||
tracing::Level::INFO
|
||||
}
|
||||
};
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer()
|
||||
.event_format(CleanLogFormatter))
|
||||
.with(tracing_subscriber::filter::LevelFilter::from_level(level))
|
||||
.init();
|
||||
}
|
||||
// Logging module - now unused as logging is handled directly in main.rs
|
||||
// All logging functionality moved to main.rs with RUST_LOG environment variable support
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
@ -109,4 +8,4 @@ mod tests {
|
|||
// Basic test to ensure logging module compiles
|
||||
assert!(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
144
src/main.rs
144
src/main.rs
|
|
@ -1,8 +1,8 @@
|
|||
mod config;
|
||||
mod tls;
|
||||
mod logging;
|
||||
mod request;
|
||||
mod server;
|
||||
mod logging;
|
||||
mod tls;
|
||||
|
||||
use clap::Parser;
|
||||
use rustls::ServerConfig;
|
||||
|
|
@ -10,9 +10,11 @@ use std::path::Path;
|
|||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use logging::init_logging;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
fn create_tls_config(hosts: &std::collections::HashMap<String, config::HostConfig>) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error>> {
|
||||
fn create_tls_config(
|
||||
hosts: &std::collections::HashMap<String, config::HostConfig>,
|
||||
) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error>> {
|
||||
// For Phase 3, we'll use the first host's certificate for all connections
|
||||
// TODO: Phase 4 could implement proper SNI-based certificate selection
|
||||
let first_host = hosts.values().next().ok_or("No hosts configured")?;
|
||||
|
|
@ -28,7 +30,10 @@ fn create_tls_config(hosts: &std::collections::HashMap<String, config::HostConfi
|
|||
Ok(Arc::new(config))
|
||||
}
|
||||
|
||||
fn print_startup_info(config: &config::Config, hosts: &std::collections::HashMap<String, config::HostConfig>) {
|
||||
fn print_startup_info(
|
||||
config: &config::Config,
|
||||
hosts: &std::collections::HashMap<String, config::HostConfig>,
|
||||
) {
|
||||
println!("Pollux Gemini Server (Virtual Host Mode)");
|
||||
println!("Configured hosts:");
|
||||
for (hostname, host_config) in hosts {
|
||||
|
|
@ -41,17 +46,13 @@ fn print_startup_info(config: &config::Config, hosts: &std::collections::HashMap
|
|||
if let Some(port) = config.port {
|
||||
println!(" Default port: {}", port);
|
||||
}
|
||||
if let Some(ref level) = config.log_level {
|
||||
println!(" Log level: {}", level);
|
||||
}
|
||||
|
||||
if let Some(max_concurrent) = config.max_concurrent_requests {
|
||||
println!(" Max concurrent requests: {}", max_concurrent);
|
||||
}
|
||||
println!(); // Add spacing before connections start
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
|
|
@ -64,8 +65,6 @@ struct Args {
|
|||
test_processing_delay: Option<u64>,
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = Args::parse();
|
||||
|
|
@ -73,21 +72,41 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// Load config
|
||||
let config_path = args.config.as_deref().unwrap_or("/etc/pollux/config.toml");
|
||||
if !std::path::Path::new(&config_path).exists() {
|
||||
eprintln!("Error: Config file '{}' not found", config_path);
|
||||
eprintln!("Create the config file with virtual host sections like:");
|
||||
eprintln!("[example.com]");
|
||||
eprintln!("root = \"/srv/gemini/example.com/gemini/\"");
|
||||
eprintln!("cert = \"/srv/gemini/example.com/tls/fullchain.pem\"");
|
||||
eprintln!("key = \"/srv/gemini/example.com/tls/privkey.pem\"");
|
||||
// User guidance goes to stdout BEFORE initializing tracing
|
||||
// Use direct stderr for error, stdout for guidance
|
||||
use std::io::Write;
|
||||
let mut stderr = std::io::stderr();
|
||||
let mut stdout = std::io::stdout();
|
||||
|
||||
writeln!(stderr, "Config file '{}' not found", config_path).unwrap();
|
||||
writeln!(
|
||||
stdout,
|
||||
"Create the config file with virtual host sections like:"
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(stdout, "[example.com]").unwrap();
|
||||
writeln!(stdout, "root = \"/var/gemini\"").unwrap();
|
||||
writeln!(stdout, "cert = \"/etc/pollux/tls/cert.pem\"").unwrap();
|
||||
writeln!(stdout, "key = \"/etc/pollux/tls/key.pem\"").unwrap();
|
||||
|
||||
stdout.flush().unwrap();
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize logging with RUST_LOG support AFTER basic config checks
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
// Load and parse config
|
||||
let config = match config::load_config(config_path) {
|
||||
Ok(config) => config,
|
||||
Err(e) => {
|
||||
eprintln!("Error: Failed to parse config file '{}': {}", config_path, e);
|
||||
eprintln!("Check the TOML syntax and ensure host sections are properly formatted.");
|
||||
tracing::error!("Failed to parse config file '{}': {}", config_path, e);
|
||||
tracing::error!(
|
||||
"Check the TOML syntax and ensure host sections are properly formatted."
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
|
@ -97,61 +116,88 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// Validate root directory exists and is readable
|
||||
let root_path = Path::new(&host_config.root);
|
||||
if !root_path.exists() {
|
||||
eprintln!("Error: Root directory '{}' for host '{}' does not exist", host_config.root, hostname);
|
||||
eprintln!("Create the directory and add your Gemini files (.gmi, .txt, images)");
|
||||
tracing::error!(
|
||||
"Root directory '{}' for host '{}' does not exist",
|
||||
host_config.root,
|
||||
hostname
|
||||
);
|
||||
tracing::error!("Create the directory and add your Gemini files (.gmi, .txt, images)");
|
||||
std::process::exit(1);
|
||||
}
|
||||
if !root_path.is_dir() {
|
||||
eprintln!("Error: Root path '{}' for host '{}' is not a directory", host_config.root, hostname);
|
||||
eprintln!("The 'root' field must point to a directory containing your content");
|
||||
tracing::error!(
|
||||
"Root path '{}' for host '{}' is not a directory",
|
||||
host_config.root,
|
||||
hostname
|
||||
);
|
||||
tracing::error!("The 'root' field must point to a directory containing your content");
|
||||
std::process::exit(1);
|
||||
}
|
||||
if let Err(e) = std::fs::read_dir(root_path) {
|
||||
eprintln!("Error: Cannot read root directory '{}' for host '{}': {}", host_config.root, hostname, e);
|
||||
eprintln!("Ensure the directory exists and the server user has read permission");
|
||||
tracing::error!(
|
||||
"Cannot read root directory '{}' for host '{}': {}",
|
||||
host_config.root,
|
||||
hostname,
|
||||
e
|
||||
);
|
||||
tracing::error!("Ensure the directory exists and the server user has read permission");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Validate certificate files (always required for TLS)
|
||||
let cert_path = Path::new(&host_config.cert);
|
||||
if !cert_path.exists() {
|
||||
eprintln!("Error: Certificate file '{}' for host '{}' does not exist", host_config.cert, hostname);
|
||||
eprintln!(
|
||||
"Error: Certificate file '{}' for host '{}' does not exist",
|
||||
host_config.cert, hostname
|
||||
);
|
||||
eprintln!("Generate or obtain TLS certificates for your domain");
|
||||
std::process::exit(1);
|
||||
}
|
||||
if let Err(e) = std::fs::File::open(cert_path) {
|
||||
eprintln!("Error: Cannot read certificate file '{}' for host '{}': {}", host_config.cert, hostname, e);
|
||||
eprintln!("Ensure the file exists and the server user has read permission");
|
||||
tracing::error!(
|
||||
"Cannot read certificate file '{}' for host '{}': {}",
|
||||
host_config.cert,
|
||||
hostname,
|
||||
e
|
||||
);
|
||||
tracing::error!("Ensure the file exists and the server user has read permission");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let key_path = Path::new(&host_config.key);
|
||||
if !key_path.exists() {
|
||||
eprintln!("Error: Private key file '{}' for host '{}' does not exist", host_config.key, hostname);
|
||||
eprintln!("Generate or obtain TLS private key for your domain");
|
||||
tracing::error!(
|
||||
"Private key file '{}' for host '{}' does not exist",
|
||||
host_config.key,
|
||||
hostname
|
||||
);
|
||||
tracing::error!("Generate or obtain TLS private key for your domain");
|
||||
std::process::exit(1);
|
||||
}
|
||||
if let Err(e) = std::fs::File::open(key_path) {
|
||||
eprintln!("Error: Cannot read private key file '{}' for host '{}': {}", host_config.key, hostname, e);
|
||||
eprintln!("Ensure the file exists and the server user has read permission");
|
||||
tracing::error!(
|
||||
"Cannot read private key file '{}' for host '{}': {}",
|
||||
host_config.key,
|
||||
hostname,
|
||||
e
|
||||
);
|
||||
tracing::error!("Ensure the file exists and the server user has read permission");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize logging
|
||||
let log_level = config.log_level.as_deref().unwrap_or("info");
|
||||
init_logging(log_level);
|
||||
|
||||
// Validate max concurrent requests
|
||||
let max_concurrent_requests = config.max_concurrent_requests.unwrap_or(1000);
|
||||
if max_concurrent_requests == 0 || max_concurrent_requests > 1_000_000 {
|
||||
eprintln!("Error: max_concurrent_requests must be between 1 and 1,000,000");
|
||||
tracing::error!("max_concurrent_requests must be between 1 and 1,000,000");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// TESTING ONLY: Read delay argument (debug builds only)
|
||||
#[cfg(debug_assertions)]
|
||||
let test_processing_delay = args.test_processing_delay
|
||||
let test_processing_delay = args
|
||||
.test_processing_delay
|
||||
.filter(|&d| d > 0 && d <= 300)
|
||||
.unwrap_or(0);
|
||||
|
||||
|
|
@ -172,7 +218,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
let port = config.port.unwrap_or(1965);
|
||||
|
||||
let listener = TcpListener::bind(format!("{}:{}", bind_host, port)).await?;
|
||||
println!("Listening on {}:{} for all virtual hosts (TLS enabled)", bind_host, port);
|
||||
println!(
|
||||
"Listening on {}:{} for all virtual hosts (TLS enabled)",
|
||||
bind_host, port
|
||||
);
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
|
|
@ -186,14 +235,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// TLS connection with hostname routing
|
||||
match acceptor_clone.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
if let Err(e) = server::handle_connection(tls_stream, &hosts_clone, max_concurrent, test_delay).await {
|
||||
eprintln!("Error handling connection: {}", e);
|
||||
if let Err(e) = server::handle_connection(
|
||||
tls_stream,
|
||||
&hosts_clone,
|
||||
max_concurrent,
|
||||
test_delay,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Error handling connection: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("TLS handshake failed: {}", e);
|
||||
tracing::error!("TLS handshake failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ pub fn parse_gemini_url(request: &str, hostname: &str, expected_port: u16) -> Re
|
|||
if let Some(url) = request.strip_prefix("gemini://") {
|
||||
let host_port_end = url.find('/').unwrap_or(url.len());
|
||||
let host_port = &url[..host_port_end];
|
||||
|
||||
|
||||
// Parse host and port
|
||||
let (host, port_str) = if let Some(colon_pos) = host_port.find(':') {
|
||||
let host = &host_port[..colon_pos];
|
||||
|
|
@ -20,21 +20,23 @@ pub fn parse_gemini_url(request: &str, hostname: &str, expected_port: u16) -> Re
|
|||
} else {
|
||||
(host_port, None)
|
||||
};
|
||||
|
||||
|
||||
// Validate host
|
||||
if host != hostname {
|
||||
return Err(()); // Hostname mismatch
|
||||
return Err(()); // Hostname mismatch
|
||||
}
|
||||
|
||||
|
||||
// Validate port
|
||||
let port = port_str
|
||||
.and_then(|p| p.parse::<u16>().ok())
|
||||
.unwrap_or(1965);
|
||||
let port = port_str.and_then(|p| p.parse::<u16>().ok()).unwrap_or(1965);
|
||||
if port != expected_port {
|
||||
return Err(()); // Port mismatch
|
||||
return Err(()); // Port mismatch
|
||||
}
|
||||
|
||||
let path = if host_port_end < url.len() { &url[host_port_end..] } else { "/" };
|
||||
|
||||
let path = if host_port_end < url.len() {
|
||||
&url[host_port_end..]
|
||||
} else {
|
||||
"/"
|
||||
};
|
||||
Ok(path.trim().to_string())
|
||||
} else {
|
||||
Err(())
|
||||
|
|
@ -58,11 +60,11 @@ pub fn resolve_file_path(path: &str, dir: &str) -> Result<PathBuf, PathResolutio
|
|||
} else {
|
||||
Err(PathResolutionError::NotFound)
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(_) => {
|
||||
// Path validation failed - treat as not found
|
||||
Err(PathResolutionError::NotFound)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -90,8 +92,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_gemini_url_valid() {
|
||||
assert_eq!(parse_gemini_url("gemini://gemini.jeena.net/", "gemini.jeena.net", 1965), Ok("/".to_string()));
|
||||
assert_eq!(parse_gemini_url("gemini://gemini.jeena.net/posts/test", "gemini.jeena.net", 1965), Ok("/posts/test".to_string()));
|
||||
assert_eq!(
|
||||
parse_gemini_url("gemini://gemini.jeena.net/", "gemini.jeena.net", 1965),
|
||||
Ok("/".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
parse_gemini_url(
|
||||
"gemini://gemini.jeena.net/posts/test",
|
||||
"gemini.jeena.net",
|
||||
1965
|
||||
),
|
||||
Ok("/posts/test".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -130,14 +142,20 @@ mod tests {
|
|||
#[test]
|
||||
fn test_resolve_file_path_traversal() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
assert_eq!(resolve_file_path("/../etc/passwd", temp_dir.path().to_str().unwrap()), Err(PathResolutionError::NotFound));
|
||||
assert_eq!(
|
||||
resolve_file_path("/../etc/passwd", temp_dir.path().to_str().unwrap()),
|
||||
Err(PathResolutionError::NotFound)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_file_path_not_found() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
// Don't create the file, should return NotFound error
|
||||
assert_eq!(resolve_file_path("/nonexistent.gmi", temp_dir.path().to_str().unwrap()), Err(PathResolutionError::NotFound));
|
||||
assert_eq!(
|
||||
resolve_file_path("/nonexistent.gmi", temp_dir.path().to_str().unwrap()),
|
||||
Err(PathResolutionError::NotFound)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -163,4 +181,4 @@ mod tests {
|
|||
let path = Path::new("test");
|
||||
assert_eq!(get_mime_type(path), "application/octet-stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::request::{resolve_file_path, get_mime_type};
|
||||
use crate::request::{get_mime_type, resolve_file_path};
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
|
|
@ -8,8 +8,6 @@ use tokio::net::TcpStream;
|
|||
use tokio::time::{timeout, Duration};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
|
||||
|
||||
static ACTIVE_REQUESTS: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
/// Extract hostname and path from a Gemini URL
|
||||
|
|
@ -35,8 +33,7 @@ pub fn extract_hostname_and_path(request: &str) -> Result<(String, String), ()>
|
|||
}
|
||||
|
||||
// URL decode the path
|
||||
let decoded_path = urlencoding::decode(&path)
|
||||
.map_err(|_| ())?;
|
||||
let decoded_path = urlencoding::decode(&path).map_err(|_| ())?;
|
||||
|
||||
Ok((hostname.to_string(), decoded_path.to_string()))
|
||||
}
|
||||
|
|
@ -54,7 +51,10 @@ pub async fn handle_connection(
|
|||
let read_future = async {
|
||||
loop {
|
||||
if request_buf.len() >= MAX_REQUEST_SIZE {
|
||||
return Err(tokio::io::Error::new(tokio::io::ErrorKind::InvalidData, "Request too large"));
|
||||
return Err(tokio::io::Error::new(
|
||||
tokio::io::ErrorKind::InvalidData,
|
||||
"Request too large",
|
||||
));
|
||||
}
|
||||
let mut byte = [0; 1];
|
||||
stream.read_exact(&mut byte).await?;
|
||||
|
|
@ -154,11 +154,7 @@ pub async fn handle_connection(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn serve_file<S>(
|
||||
stream: &mut S,
|
||||
file_path: &Path,
|
||||
_request: &str,
|
||||
) -> io::Result<()>
|
||||
async fn serve_file<S>(stream: &mut S, file_path: &Path, _request: &str) -> io::Result<()>
|
||||
where
|
||||
S: AsyncWriteExt + Unpin,
|
||||
{
|
||||
|
|
@ -177,14 +173,11 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_response<S>(
|
||||
stream: &mut S,
|
||||
response: &str,
|
||||
) -> io::Result<()>
|
||||
async fn send_response<S>(stream: &mut S, response: &str) -> io::Result<()>
|
||||
where
|
||||
S: AsyncWriteExt + Unpin,
|
||||
{
|
||||
stream.write_all(response.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,5 +24,8 @@ pub fn load_private_key(filename: &str) -> io::Result<rustls::PrivateKey> {
|
|||
}
|
||||
}
|
||||
|
||||
Err(io::Error::new(io::ErrorKind::InvalidData, "No supported private key found"))
|
||||
}
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"No supported private key found",
|
||||
))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue