diff --git a/src/main.rs b/src/main.rs index d57e006..78639e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,18 +7,12 @@ mod prelude; use std::{env, net::Ipv4Addr}; -use auth::HashedBytes; -use axum::{ - extract::{DefaultBodyLimit, FromRef}, - routing::post, - Router, -}; +use axum::Router; use file_storage::FileStorage; -use tokio::{net::TcpListener, signal}; type Pool = sqlx::postgres::PgPool; -#[derive(Clone, FromRef)] +#[derive(Clone, axum::extract::FromRef)] struct AppState { pool: Pool, storage: FileStorage, @@ -33,8 +27,8 @@ async fn create_test_users(pool: &Pool) -> anyhow::Result<()> { if count > 0 { return Ok(()); } - let hash1 = HashedBytes::hash_bytes(b"Password1").as_bytes(); - let hash2 = HashedBytes::hash_bytes(b"Password2").as_bytes(); + let hash1 = auth::HashedBytes::hash_bytes(b"Password1").as_bytes(); + let hash2 = auth::HashedBytes::hash_bytes(b"Password2").as_bytes(); tokio::try_join!( db::users::create_user("Test1", "test1@example.com", &hash1, pool), @@ -94,7 +88,7 @@ async fn main() -> anyhow::Result<()> { let router = app(state); let addr = (Ipv4Addr::UNSPECIFIED, 3000); - let listener = TcpListener::bind(addr).await?; + let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, router) .with_graceful_shutdown(shutdown_signal()) @@ -104,6 +98,8 @@ async fn main() -> anyhow::Result<()> { } async fn shutdown_signal() { + use tokio::signal; + let ctrl_c = async { signal::ctrl_c() .await @@ -129,7 +125,12 @@ async fn shutdown_signal() { } fn app(state: AppState) -> Router { - use axum::{http::header, routing::get}; + use axum::{ + extract::DefaultBodyLimit, + handler::Handler as _, + http::header, + routing::{get, post}, + }; use endpoints::{ file, folder, permissions::{self, get_top_level::get_top_level}, @@ -157,20 +158,20 @@ fn app(state: AppState) -> Router { } const TEN_GIBIBYTES: usize = 10 * 1024 * 1024 * 1024; + let body_limit = DefaultBodyLimit::max(TEN_GIBIBYTES); + let middleware = tower::ServiceBuilder::new() - .layer(DefaultBodyLimit::max(TEN_GIBIBYTES)) .sensitive_headers([header::AUTHORIZATION, header::COOKIE]) .layer(TraceLayer::new_for_http().make_span_with(SpanMaker)) .compression(); - // Build route service Router::new() .route( "/files", get(file::download::download) - .post(file::upload::upload) + .post(file::upload::upload.layer(body_limit.clone())) .delete(file::delete::delete) - .patch(file::modify::modify), + .patch(file::modify::modify.layer(body_limit)), ) .route( "/folders",