mod auth; mod db; mod endpoints; mod errors; mod file_storage; mod prelude; use std::{env, net::Ipv4Addr}; use auth::HashedBytes; use axum::{extract::FromRef, routing::post, Router}; use file_storage::FileStorage; use tokio::{net::TcpListener, signal}; type Pool = sqlx::postgres::PgPool; #[derive(Clone, FromRef)] struct AppState { pool: Pool, storage: FileStorage, } async fn create_test_users(pool: &Pool) -> anyhow::Result<()> { let count = sqlx::query!("SELECT count(user_id) FROM users") .fetch_one(pool) .await? .count .unwrap_or(0); if count > 0 { return Ok(()); } let hash1 = HashedBytes::hash_bytes(b"Password1").as_bytes(); let hash2 = HashedBytes::hash_bytes(b"Password2").as_bytes(); tokio::try_join!( db::users::create_user("Test1", "test1@example.com", &hash1, pool), db::users::create_user("Test2", "test2@example.com", &hash2, pool) )?; Ok(()) } #[tokio::main] async fn main() -> anyhow::Result<()> { let _ = dotenvy::dotenv(); tracing_subscriber::fmt::init(); auth::force_init_keys(); let pool = match env::var("DATABASE_URL") { Ok(url) => Pool::connect(&url).await?, Err(err) => anyhow::bail!("Error getting database url: {err}"), }; sqlx::migrate!().run(&pool).await?; if let Ok("1") = env::var("DEVELOPMENT").as_deref().map(str::trim_ascii) { create_test_users(&pool).await?; } let state = AppState { pool, storage: FileStorage::new()?, }; let router = app(state); let addr = (Ipv4Addr::UNSPECIFIED, 3000); let listener = TcpListener::bind(addr).await?; axum::serve(listener, router) .with_graceful_shutdown(shutdown_signal()) .await?; Ok(()) } async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { () = ctrl_c => {}, () = terminate => {}, } } fn app(state: AppState) -> Router { use axum::{http::header, routing::get}; use endpoints::{ file, folder, permissions::{self, get_top_level::get_top_level}, users, }; use tower_http::ServiceBuilderExt as _; let middleware = tower::ServiceBuilder::new() .sensitive_headers([header::AUTHORIZATION, header::COOKIE]) .trace_for_http() .compression(); // Build route service Router::new() .route( "/files", get(file::download::download) .post(file::upload::upload) .delete(file::delete::delete) .patch(file::modify::modify), ) .route( "/folders", get(folder::list::list) .post(folder::create::create) .delete(folder::delete::delete), ) .route("/folders/structure", get(folder::get_structure::structure)) .route( "/permissions", get(permissions::get::get) .post(permissions::set::set) .delete(permissions::delete::delete), ) .route( "/permissions/get_top_level_permitted_folders", get(get_top_level), ) .route( "/users", get(users::get::get) .delete(users::delete::delete) .put(users::put::put), ) .route("/users/current", get(users::get::current)) .route("/users/search", get(users::search::search)) .route("/users/register", post(users::register::register)) .route("/users/authorize", post(users::login::login)) .layer(middleware) .with_state(state) }