Applying the DefaultBodyLimit layer only to file uploads and modifications

This commit is contained in:
StNicolay 2024-08-15 22:43:00 +03:00
parent a3e4ac2b2e
commit 9e3b9527d3
Signed by: StNicolay
GPG Key ID: 9693D04DCD962B0D

View File

@ -7,18 +7,12 @@ mod prelude;
use std::{env, net::Ipv4Addr}; use std::{env, net::Ipv4Addr};
use auth::HashedBytes; use axum::Router;
use axum::{
extract::{DefaultBodyLimit, FromRef},
routing::post,
Router,
};
use file_storage::FileStorage; use file_storage::FileStorage;
use tokio::{net::TcpListener, signal};
type Pool = sqlx::postgres::PgPool; type Pool = sqlx::postgres::PgPool;
#[derive(Clone, FromRef)] #[derive(Clone, axum::extract::FromRef)]
struct AppState { struct AppState {
pool: Pool, pool: Pool,
storage: FileStorage, storage: FileStorage,
@ -33,8 +27,8 @@ async fn create_test_users(pool: &Pool) -> anyhow::Result<()> {
if count > 0 { if count > 0 {
return Ok(()); return Ok(());
} }
let hash1 = HashedBytes::hash_bytes(b"Password1").as_bytes(); let hash1 = auth::HashedBytes::hash_bytes(b"Password1").as_bytes();
let hash2 = HashedBytes::hash_bytes(b"Password2").as_bytes(); let hash2 = auth::HashedBytes::hash_bytes(b"Password2").as_bytes();
tokio::try_join!( tokio::try_join!(
db::users::create_user("Test1", "test1@example.com", &hash1, pool), db::users::create_user("Test1", "test1@example.com", &hash1, pool),
@ -94,7 +88,7 @@ async fn main() -> anyhow::Result<()> {
let router = app(state); let router = app(state);
let addr = (Ipv4Addr::UNSPECIFIED, 3000); let addr = (Ipv4Addr::UNSPECIFIED, 3000);
let listener = TcpListener::bind(addr).await?; let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, router) axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
@ -104,6 +98,8 @@ async fn main() -> anyhow::Result<()> {
} }
async fn shutdown_signal() { async fn shutdown_signal() {
use tokio::signal;
let ctrl_c = async { let ctrl_c = async {
signal::ctrl_c() signal::ctrl_c()
.await .await
@ -129,7 +125,12 @@ async fn shutdown_signal() {
} }
fn app(state: AppState) -> Router { fn app(state: AppState) -> Router {
use axum::{http::header, routing::get}; use axum::{
extract::DefaultBodyLimit,
handler::Handler as _,
http::header,
routing::{get, post},
};
use endpoints::{ use endpoints::{
file, folder, file, folder,
permissions::{self, get_top_level::get_top_level}, permissions::{self, get_top_level::get_top_level},
@ -157,20 +158,20 @@ fn app(state: AppState) -> Router {
} }
const TEN_GIBIBYTES: usize = 10 * 1024 * 1024 * 1024; const TEN_GIBIBYTES: usize = 10 * 1024 * 1024 * 1024;
let body_limit = DefaultBodyLimit::max(TEN_GIBIBYTES);
let middleware = tower::ServiceBuilder::new() let middleware = tower::ServiceBuilder::new()
.layer(DefaultBodyLimit::max(TEN_GIBIBYTES))
.sensitive_headers([header::AUTHORIZATION, header::COOKIE]) .sensitive_headers([header::AUTHORIZATION, header::COOKIE])
.layer(TraceLayer::new_for_http().make_span_with(SpanMaker)) .layer(TraceLayer::new_for_http().make_span_with(SpanMaker))
.compression(); .compression();
// Build route service
Router::new() Router::new()
.route( .route(
"/files", "/files",
get(file::download::download) get(file::download::download)
.post(file::upload::upload) .post(file::upload::upload.layer(body_limit.clone()))
.delete(file::delete::delete) .delete(file::delete::delete)
.patch(file::modify::modify), .patch(file::modify::modify.layer(body_limit)),
) )
.route( .route(
"/folders", "/folders",