This repository has been archived on 2024-08-23. You can view files and clone it, but cannot push or open issues or pull requests.
project/src/main.rs

153 lines
4.0 KiB
Rust

mod auth;
mod db;
mod endpoints;
mod errors;
mod file_storage;
mod prelude;
use std::{env, net::Ipv4Addr};
use auth::HashedBytes;
use axum::{extract::FromRef, routing::post, Router};
use file_storage::FileStorage;
use tokio::{net::TcpListener, signal};
/// Database connection pool type used in this file: an alias for sqlx's Postgres pool.
type Pool = sqlx::postgres::PgPool;
/// Shared application state attached to the router with `with_state` in `app`.
///
/// `Clone` and axum's `FromRef` are derived; per axum's `FromRef` derive, each
/// field can be obtained from a reference to the whole state.
#[derive(Clone, FromRef)]
struct AppState {
// Database connection pool (connected in `main`).
pool: Pool,
// File storage handle, built with `FileStorage::new()` in `main`.
storage: FileStorage,
}
/// Seeds two fixed test accounts (`Test1` and `Test2`) when the `users` table
/// holds no rows with a `user_id`.
///
/// Does nothing if at least one user is already counted. The two inserts are
/// awaited concurrently.
///
/// # Errors
/// Returns an error if the count query fails or if either `create_user` call
/// fails.
async fn create_test_users(pool: &Pool) -> anyhow::Result<()> {
    let row = sqlx::query!("SELECT count(user_id) FROM users")
        .fetch_one(pool)
        .await?;
    // A missing count is treated the same as zero.
    let existing_users = row.count.unwrap_or(0);
    let is_empty = existing_users <= 0;

    if is_empty {
        let first_hash = HashedBytes::hash_bytes(b"Password1").as_bytes();
        let second_hash = HashedBytes::hash_bytes(b"Password2").as_bytes();
        // Both inserts run concurrently; the first failure is propagated.
        tokio::try_join!(
            db::users::create_user("Test1", "test1@example.com", &first_hash, pool),
            db::users::create_user("Test2", "test2@example.com", &second_hash, pool)
        )?;
    }

    Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let _ = dotenvy::dotenv();
tracing_subscriber::fmt::init();
auth::force_init_keys();
let pool = match env::var("DATABASE_URL") {
Ok(url) => Pool::connect(&url).await?,
Err(err) => anyhow::bail!("Error getting database url: {err}"),
};
sqlx::migrate!().run(&pool).await?;
if let Ok("1") = env::var("DEVELOPMENT").as_deref().map(str::trim_ascii) {
create_test_users(&pool).await?;
}
let state = AppState {
pool,
storage: FileStorage::new()?,
};
let router = app(state);
let addr = (Ipv4Addr::UNSPECIFIED, 3000);
let listener = TcpListener::bind(addr).await?;
axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal())
.await?;
Ok(())
}
/// Completes once the process has been asked to stop.
///
/// Waits for Ctrl+C on every platform. On Unix it additionally waits for the
/// terminate signal; elsewhere that branch is a future that never resolves.
/// Whichever fires first ends the wait.
///
/// # Panics
/// Panics if either signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        stream.recv().await;
    };

    // Non-Unix targets have no terminate signal, so this branch never fires.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = interrupt => {},
        _ = terminate => {},
    }
}
/// Builds the application router with all endpoints, middleware, and state.
///
/// Route table:
/// - `/files`: GET download, POST upload, DELETE delete, PATCH modify
/// - `/folders`: GET list, POST create, DELETE delete
/// - `/folders/structure`: GET structure
/// - `/permissions`: GET get, POST set, DELETE delete
/// - `/permissions/get_top_level_permitted_folders`: GET top-level folders
/// - `/users`: GET get, DELETE delete, PUT put
/// - `/users/current`, `/users/search`: GET
/// - `/users/register`, `/users/authorize`: POST
///
/// The middleware stack marks the `Authorization` and `Cookie` headers as
/// sensitive, then adds HTTP tracing and response compression. It is layered
/// over every route before the state is attached.
fn app(state: AppState) -> Router {
    use axum::{http::header, routing::get};
    use endpoints::{file, folder, permissions, users};
    use tower_http::ServiceBuilderExt as _;

    // One method router per path, so the route table below stays flat.
    let files = get(file::download::download)
        .post(file::upload::upload)
        .delete(file::delete::delete)
        .patch(file::modify::modify);
    let folders = get(folder::list::list)
        .post(folder::create::create)
        .delete(folder::delete::delete);
    let folder_structure = get(folder::get_structure::structure);
    let permission_rules = get(permissions::get::get)
        .post(permissions::set::set)
        .delete(permissions::delete::delete);
    let top_level_folders = get(permissions::get_top_level::get_top_level);
    let user_accounts = get(users::get::get)
        .delete(users::delete::delete)
        .put(users::put::put);

    let http_layers = tower::ServiceBuilder::new()
        .sensitive_headers([header::AUTHORIZATION, header::COOKIE])
        .trace_for_http()
        .compression();

    Router::new()
        .route("/files", files)
        .route("/folders", folders)
        .route("/folders/structure", folder_structure)
        .route("/permissions", permission_rules)
        .route(
            "/permissions/get_top_level_permitted_folders",
            top_level_folders,
        )
        .route("/users", user_accounts)
        .route("/users/current", get(users::get::current))
        .route("/users/search", get(users::search::search))
        .route("/users/register", post(users::register::register))
        .route("/users/authorize", post(users::login::login))
        .layer(http_layers)
        .with_state(state)
}