From 3ed30f0d9b8747b7cf79b9f1cf6f16ae40246929 Mon Sep 17 00:00:00 2001
From: StNicolay
Date: Sat, 27 Jul 2024 19:20:07 +0300
Subject: [PATCH] Added file updates

---
 src/db/file.rs               | 12 +++++++++
 src/endpoints/file/mod.rs    |  1 +
 src/endpoints/file/modify.rs | 51 ++++++++++++++++++++++++++++++++++++
 src/endpoints/file/upload.rs | 35 +++++++------------------
 src/file_storage.rs          | 43 +++++++++++++++++++++++++++++-
 src/main.rs                  |  3 ++-
 6 files changed, 117 insertions(+), 28 deletions(-)
 create mode 100644 src/endpoints/file/modify.rs

diff --git a/src/db/file.rs b/src/db/file.rs
index b5b32e6..62058b3 100644
--- a/src/db/file.rs
+++ b/src/db/file.rs
@@ -18,6 +18,18 @@ pub async fn insert(
     .map(|_| ())
 }
 
+pub async fn update(file_id: Uuid, size: i64, hash: Vec<u8>, pool: &Pool) -> sqlx::Result<()> {
+    sqlx::query!(
+        "UPDATE files SET (sha512, file_size, updated_at) = ($2, $3, NOW()) WHERE file_id = $1",
+        file_id,
+        hash,
+        size
+    )
+    .execute(pool)
+    .await
+    .map(|_| ())
+}
+
 #[derive(Debug, serde::Serialize)]
 #[allow(clippy::struct_field_names, clippy::module_name_repetitions)]
 pub struct FileWithoutParentId {
diff --git a/src/endpoints/file/mod.rs b/src/endpoints/file/mod.rs
index 690793d..c6f9922 100644
--- a/src/endpoints/file/mod.rs
+++ b/src/endpoints/file/mod.rs
@@ -1,3 +1,4 @@
 pub mod delete;
 pub mod download;
+pub mod modify;
 pub mod upload;
diff --git a/src/endpoints/file/modify.rs b/src/endpoints/file/modify.rs
new file mode 100644
index 0000000..08852df
--- /dev/null
+++ b/src/endpoints/file/modify.rs
@@ -0,0 +1,51 @@
+use axum::extract::Multipart;
+
+use crate::prelude::*;
+
+#[derive(Deserialize, Debug)]
+pub struct Params {
+    file_id: Uuid,
+}
+
+pub async fn modify(
+    Query(params): Query<Params>,
+    State(state): State<AppState>,
+    claims: Claims,
+    mut multipart: Multipart,
+) -> Result<StatusCode, StatusCode> {
+    db::file::get_permissions(params.file_id, claims.user_id, &state.pool)
+        .await
+        .handle_internal()?
+        .can_write_guard()?;
+
+    // Very weird work around
+    let mut field = loop {
+        match multipart.next_field().await {
+            Ok(Some(field)) if field.file_name().is_some() => break field,
+            Ok(Some(_)) => continue,
+            _ => return Err(StatusCode::BAD_REQUEST),
+        }
+    };
+
+    let Some(mut file) = state
+        .storage
+        .write(params.file_id)
+        .await
+        .handle_internal()?
+    else {
+        return Err(StatusCode::NOT_FOUND);
+    };
+
+    let (hash, size) = match crate::FileStorage::write_to_file(&mut file, &mut field).await {
+        Ok(values) => values,
+        Err(err) => {
+            tracing::warn!(%err);
+            return Err(StatusCode::INTERNAL_SERVER_ERROR);
+        }
+    };
+    db::file::update(params.file_id, size, hash, &state.pool)
+        .await
+        .handle_internal()?;
+
+    Ok(StatusCode::NO_CONTENT)
+}
diff --git a/src/endpoints/file/upload.rs b/src/endpoints/file/upload.rs
index cade582..70de9f2 100644
--- a/src/endpoints/file/upload.rs
+++ b/src/endpoints/file/upload.rs
@@ -1,10 +1,7 @@
 use std::collections::HashMap;
-use std::io;
 
 use axum::extract::multipart::{self, Multipart};
-use sha2::Digest as _;
-use tokio::io::{AsyncWrite, BufWriter};
-use tokio_util::io::StreamReader;
+use tokio::io::AsyncWrite;
 
 use crate::prelude::*;
 
@@ -13,7 +10,7 @@ pub struct Params {
     parent_folder: Uuid,
 }
 
-async fn write_file(
+async fn create_file(
     file_id: Uuid,
     file: impl AsyncWrite + Unpin,
     file_name: &str,
@@ -21,27 +18,13 @@
     parent_folder: Uuid,
     pool: &Pool,
 ) -> bool {
-    const BUF_CAP: usize = 64 * 1024 * 1024; // 64 MiB
-    let mut hash = sha2::Sha512::new();
-    let mut size: i64 = 0;
-    let stream = field.map(|value| match value {
-        Ok(bytes) => {
-            hash.update(&bytes);
-            size = i64::try_from(bytes.len())
-                .ok()
-                .and_then(|part_size| size.checked_add(part_size))
-                .ok_or_else(|| io::Error::other(anyhow::anyhow!("Size calculation overflow")))?;
-            Ok(bytes)
+    let (hash, size) = match crate::FileStorage::write_to_file(file, field).await {
+        Ok(values) => values,
+        Err(err) => {
+            tracing::warn!(%err);
+            return false;
         }
-        Err(err) => Err(io::Error::other(err)),
-    });
-    let mut reader = StreamReader::new(stream);
-    let mut writer = BufWriter::with_capacity(BUF_CAP, file);
-    if let Err(err) = tokio::io::copy(&mut reader, &mut writer).await {
-        tracing::warn!(%err);
-        return false;
-    }
-    let hash = hash.finalize().to_vec();
+    };
     db::file::insert(file_id, parent_folder, file_name, size, hash, pool)
         .await
         .inspect_err(|err| tracing::warn!(%err))
@@ -74,7 +57,7 @@ pub async fn upload(
             tracing::warn!("Couldn't create uuid for new file");
             continue;
         };
-        let is_success = write_file(
+        let is_success = create_file(
             file_id,
             &mut file,
             &file_name,
diff --git a/src/file_storage.rs b/src/file_storage.rs
index 0e96f85..cdcd4e4 100644
--- a/src/file_storage.rs
+++ b/src/file_storage.rs
@@ -4,7 +4,14 @@ use std::{
     sync::Arc,
 };
 
-use tokio::fs;
+use axum::body::Bytes;
+use futures::{Stream, StreamExt};
+use sha2::Digest as _;
+use tokio::{
+    fs,
+    io::{AsyncWrite, AsyncWriteExt, BufWriter},
+};
+use tokio_util::io::StreamReader;
 use uuid::Uuid;
 
 #[derive(Clone)]
@@ -63,6 +70,14 @@
         .map_err(Into::into)
     }
 
+    pub async fn write(&self, file_id: Uuid) -> anyhow::Result<Option<fs::File>> {
+        match fs::File::create(self.path_for_file(file_id)).await {
+            Ok(file) => Ok(Some(file)),
+            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
+            Err(err) => Err(err.into()),
+        }
+    }
+
     pub async fn delete(&self, file_id: Uuid) -> anyhow::Result<bool> {
         match fs::remove_file(self.path_for_file(file_id)).await {
             Ok(()) => Ok(true),
@@ -70,4 +85,30 @@
             Err(err) => Err(err.into()),
         }
     }
+
+    pub async fn write_to_file<F, S, E>(file: F, stream: S) -> io::Result<(Vec<u8>, i64)>
+    where
+        F: AsyncWrite + Unpin,
+        S: Stream<Item = Result<Bytes, E>> + Unpin,
+        E: Into<Box<dyn std::error::Error + Send + Sync>>,
+    {
+        const BUF_CAP: usize = 64 * 1024 * 1024; // 64 MiB
+        let mut hash = sha2::Sha512::new();
+        let mut size: i64 = 0;
+        let stream = stream.map(|value| {
+            let bytes = value.map_err(io::Error::other)?;
+            hash.update(&bytes);
+            size = i64::try_from(bytes.len())
+                .ok()
+                .and_then(|part_size| size.checked_add(part_size))
+                .ok_or_else(|| io::Error::other(anyhow::anyhow!("Size calculation overflow")))?;
+            io::Result::Ok(bytes)
+        });
+        let mut reader = StreamReader::new(stream);
+        let mut writer = BufWriter::with_capacity(BUF_CAP, file);
+        tokio::io::copy_buf(&mut reader, &mut writer).await?;
+        writer.flush().await?;
+        let hash = hash.finalize().to_vec();
+        Ok((hash, size))
+    }
 }
diff --git a/src/main.rs b/src/main.rs
index d2a7560..035e824 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -101,7 +101,8 @@ fn app(state: AppState) -> Router {
         "/files",
         get(file::download::download)
             .post(file::upload::upload)
-            .delete(file::delete::delete),
+            .delete(file::delete::delete)
+            .patch(file::modify::modify),
     )
     .route(
         "/folders",