Added file updates

This commit is contained in:
StNicolay 2024-07-27 19:20:07 +03:00
parent e8114c515d
commit 3ed30f0d9b
Signed by: StNicolay
GPG Key ID: 9693D04DCD962B0D
6 changed files with 117 additions and 28 deletions

View File

@ -18,6 +18,18 @@ pub async fn insert(
.map(|_| ()) .map(|_| ())
} }
pub async fn update(file_id: Uuid, size: i64, hash: Vec<u8>, pool: &Pool) -> sqlx::Result<()> {
sqlx::query!(
"UPDATE files SET (sha512, file_size, updated_at) = ($2, $3, NOW()) WHERE file_id = $1",
file_id,
hash,
size
)
.execute(pool)
.await
.map(|_| ())
}
#[derive(Debug, serde::Serialize)] #[derive(Debug, serde::Serialize)]
#[allow(clippy::struct_field_names, clippy::module_name_repetitions)] #[allow(clippy::struct_field_names, clippy::module_name_repetitions)]
pub struct FileWithoutParentId { pub struct FileWithoutParentId {

View File

@ -1,3 +1,4 @@
pub mod delete; pub mod delete;
pub mod download; pub mod download;
pub mod modify;
pub mod upload; pub mod upload;

View File

@ -0,0 +1,51 @@
use axum::extract::Multipart;
use crate::prelude::*;
#[derive(Deserialize, Debug)]
pub struct Params {
file_id: Uuid,
}
pub async fn modify(
Query(params): Query<Params>,
State(state): State<AppState>,
claims: Claims,
mut multipart: Multipart,
) -> Result<StatusCode, StatusCode> {
db::file::get_permissions(params.file_id, claims.user_id, &state.pool)
.await
.handle_internal()?
.can_write_guard()?;
// Very weird work around
let mut field = loop {
match multipart.next_field().await {
Ok(Some(field)) if field.file_name().is_some() => break field,
Ok(Some(_)) => continue,
_ => return Err(StatusCode::BAD_REQUEST),
}
};
let Some(mut file) = state
.storage
.write(params.file_id)
.await
.handle_internal()?
else {
return Err(StatusCode::NOT_FOUND);
};
let (hash, size) = match crate::FileStorage::write_to_file(&mut file, &mut field).await {
Ok(values) => values,
Err(err) => {
tracing::warn!(%err);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
db::file::update(params.file_id, size, hash, &state.pool)
.await
.handle_internal()?;
Ok(StatusCode::NO_CONTENT)
}

View File

@ -1,10 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::io;
use axum::extract::multipart::{self, Multipart}; use axum::extract::multipart::{self, Multipart};
use sha2::Digest as _; use tokio::io::AsyncWrite;
use tokio::io::{AsyncWrite, BufWriter};
use tokio_util::io::StreamReader;
use crate::prelude::*; use crate::prelude::*;
@ -13,7 +10,7 @@ pub struct Params {
parent_folder: Uuid, parent_folder: Uuid,
} }
async fn write_file( async fn create_file(
file_id: Uuid, file_id: Uuid,
file: impl AsyncWrite + Unpin, file: impl AsyncWrite + Unpin,
file_name: &str, file_name: &str,
@ -21,27 +18,13 @@ async fn write_file(
parent_folder: Uuid, parent_folder: Uuid,
pool: &Pool, pool: &Pool,
) -> bool { ) -> bool {
const BUF_CAP: usize = 64 * 1024 * 1024; // 64 MiB let (hash, size) = match crate::FileStorage::write_to_file(file, field).await {
let mut hash = sha2::Sha512::new(); Ok(values) => values,
let mut size: i64 = 0; Err(err) => {
let stream = field.map(|value| match value { tracing::warn!(%err);
Ok(bytes) => { return false;
hash.update(&bytes);
size = i64::try_from(bytes.len())
.ok()
.and_then(|part_size| size.checked_add(part_size))
.ok_or_else(|| io::Error::other(anyhow::anyhow!("Size calculation overflow")))?;
Ok(bytes)
} }
Err(err) => Err(io::Error::other(err)), };
});
let mut reader = StreamReader::new(stream);
let mut writer = BufWriter::with_capacity(BUF_CAP, file);
if let Err(err) = tokio::io::copy(&mut reader, &mut writer).await {
tracing::warn!(%err);
return false;
}
let hash = hash.finalize().to_vec();
db::file::insert(file_id, parent_folder, file_name, size, hash, pool) db::file::insert(file_id, parent_folder, file_name, size, hash, pool)
.await .await
.inspect_err(|err| tracing::warn!(%err)) .inspect_err(|err| tracing::warn!(%err))
@ -74,7 +57,7 @@ pub async fn upload(
tracing::warn!("Couldn't create uuid for new file"); tracing::warn!("Couldn't create uuid for new file");
continue; continue;
}; };
let is_success = write_file( let is_success = create_file(
file_id, file_id,
&mut file, &mut file,
&file_name, &file_name,

View File

@ -4,7 +4,14 @@ use std::{
sync::Arc, sync::Arc,
}; };
use tokio::fs; use axum::body::Bytes;
use futures::{Stream, StreamExt};
use sha2::Digest as _;
use tokio::{
fs,
io::{AsyncWrite, AsyncWriteExt, BufWriter},
};
use tokio_util::io::StreamReader;
use uuid::Uuid; use uuid::Uuid;
#[derive(Clone)] #[derive(Clone)]
@ -63,6 +70,14 @@ impl FileStorage {
.map_err(Into::into) .map_err(Into::into)
} }
pub async fn write(&self, file_id: Uuid) -> anyhow::Result<Option<impl tokio::io::AsyncWrite>> {
match fs::File::create(self.path_for_file(file_id)).await {
Ok(file) => Ok(Some(file)),
Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
Err(err) => Err(err.into()),
}
}
pub async fn delete(&self, file_id: Uuid) -> anyhow::Result<bool> { pub async fn delete(&self, file_id: Uuid) -> anyhow::Result<bool> {
match fs::remove_file(self.path_for_file(file_id)).await { match fs::remove_file(self.path_for_file(file_id)).await {
Ok(()) => Ok(true), Ok(()) => Ok(true),
@ -70,4 +85,30 @@ impl FileStorage {
Err(err) => Err(err.into()), Err(err) => Err(err.into()),
} }
} }
pub async fn write_to_file<F, S, E>(file: F, stream: S) -> io::Result<(Vec<u8>, i64)>
where
F: AsyncWrite + Unpin,
S: Stream<Item = Result<Bytes, E>> + Unpin,
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
const BUF_CAP: usize = 64 * 1024 * 1024; // 64 MiB
let mut hash = sha2::Sha512::new();
let mut size: i64 = 0;
let stream = stream.map(|value| {
let bytes = value.map_err(io::Error::other)?;
hash.update(&bytes);
size = i64::try_from(bytes.len())
.ok()
.and_then(|part_size| size.checked_add(part_size))
.ok_or_else(|| io::Error::other(anyhow::anyhow!("Size calculation overflow")))?;
io::Result::Ok(bytes)
});
let mut reader = StreamReader::new(stream);
let mut writer = BufWriter::with_capacity(BUF_CAP, file);
tokio::io::copy_buf(&mut reader, &mut writer).await?;
writer.flush().await?;
let hash = hash.finalize().to_vec();
Ok((hash, size))
}
} }

View File

@ -101,7 +101,8 @@ fn app(state: AppState) -> Router {
"/files", "/files",
get(file::download::download) get(file::download::download)
.post(file::upload::upload) .post(file::upload::upload)
.delete(file::delete::delete), .delete(file::delete::delete)
.patch(file::modify::modify),
) )
.route( .route(
"/folders", "/folders",