Initial commit

2024-06-27 15:04:57 +03:00
commit e8114c515d
40 changed files with 4180 additions and 0 deletions

31
src/auth.rs Normal file

@@ -0,0 +1,31 @@
use axum::{
extract::{FromRequestParts, Query},
http::{request::Parts, StatusCode},
RequestPartsExt,
};
use serde::Deserialize;
use crate::AppState;
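// NOTE: the claims are deserialized straight from the query string (e.g. `?user_id=1`),
// so this extractor currently trusts the client instead of verifying any credential.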
#[derive(Deserialize, Debug)]
pub struct Claims {
pub user_id: i32,
}
#[axum::async_trait]
impl FromRequestParts<AppState> for Claims {
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut Parts,
_state: &AppState,
) -> Result<Self, Self::Rejection> {
match parts.extract().await {
Ok(Query(claims)) => Ok(claims),
Err(err) => {
tracing::debug!(%err, "Authorization failed");
Err(StatusCode::UNAUTHORIZED)
}
}
}
}

65
src/db/file.rs Normal file

@@ -0,0 +1,65 @@
use uuid::Uuid;
use crate::Pool;
use super::permissions::PermissionType;
pub async fn insert(
file_id: Uuid,
parent_folder: Uuid,
name: &str,
size: i64,
hash: Vec<u8>,
pool: &Pool,
) -> sqlx::Result<()> {
sqlx::query!("INSERT INTO files(file_id, folder_id, file_name, file_size, sha512) VALUES ($1, $2, $3, $4, $5)", file_id, parent_folder, name, size, hash)
.execute(pool)
.await
.map(|_| ())
}
#[derive(Debug, serde::Serialize)]
#[allow(clippy::struct_field_names, clippy::module_name_repetitions)]
pub struct FileWithoutParentId {
file_id: Uuid,
file_name: String,
file_size: i64,
sha512: Vec<u8>,
created_at: chrono::NaiveDateTime,
updated_at: chrono::NaiveDateTime,
}
pub async fn get_files(folder_id: Uuid, pool: &Pool) -> sqlx::Result<Vec<FileWithoutParentId>> {
sqlx::query_as!(FileWithoutParentId, "SELECT file_id, file_name, file_size, sha512, created_at, updated_at FROM files WHERE folder_id = $1", folder_id)
.fetch_all(pool)
.await
}
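// Write access is derived purely from ownership of the containing folder; any other
// user falls through to the default variant, PermissionType::NoPermission.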
pub async fn get_permissions(
file_id: Uuid,
user_id: i32,
pool: &Pool,
) -> sqlx::Result<PermissionType> {
let record = sqlx::query!(
"SELECT file_id FROM files JOIN folders ON files.folder_id = folders.folder_id WHERE file_id = $1 AND owner_id = $2",
file_id,
user_id
)
.fetch_optional(pool)
.await?;
Ok(record.map(|_| PermissionType::Write).unwrap_or_default())
}
pub async fn get_name(file_id: Uuid, pool: &Pool) -> sqlx::Result<Option<String>> {
let record = sqlx::query!("SELECT file_name FROM files WHERE file_id = $1", file_id)
.fetch_optional(pool)
.await?;
Ok(record.map(|record| record.file_name))
}
pub async fn delete(file_id: Uuid, pool: &Pool) -> sqlx::Result<bool> {
sqlx::query!("DELETE FROM files WHERE file_id = $1", file_id)
.execute(pool)
.await
.map(|result| result.rows_affected() > 0)
}

115
src/db/folder.rs Normal file

@@ -0,0 +1,115 @@
use std::collections::HashSet;
use futures::TryStreamExt;
use uuid::Uuid;
use crate::Pool;
use super::permissions::PermissionType;
pub async fn get_permissions(
folder_id: Uuid,
user_id: i32,
pool: &Pool,
) -> sqlx::Result<PermissionType> {
let permission = sqlx::query!(
"SELECT folder_id FROM folders WHERE folder_id = $1 AND owner_id = $2",
folder_id,
user_id
)
.fetch_optional(pool)
.await?
.map(|_| PermissionType::Write)
.unwrap_or_default();
Ok(permission)
}
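// Collects the names of both subfolders and files under a folder in a single UNION
// query (used to reject duplicate names on upload); sqlx cannot prove the aliased
// column is non-null, hence the Option and the unwrap below.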
pub async fn get_names(folder_id: Uuid, pool: &Pool) -> sqlx::Result<HashSet<String>> {
sqlx::query!("SELECT folder_name as name FROM folders WHERE parent_folder_id = $1 UNION SELECT file_name as name FROM files WHERE folder_id = $1", folder_id)
.fetch(pool)
.map_ok(|record| record.name.unwrap())
.try_collect::<HashSet<String>>()
.await
}
pub async fn get_root(user_id: i32, pool: &Pool) -> sqlx::Result<Uuid> {
sqlx::query!(
"SELECT folder_id FROM folders WHERE owner_id = $1 AND parent_folder_id IS null",
user_id
)
.fetch_one(pool)
.await
.map(|row| row.folder_id)
}
pub async fn get_by_id(id: Option<Uuid>, user_id: i32, pool: &Pool) -> sqlx::Result<Option<Uuid>> {
match id {
Some(id) => get_permissions(id, user_id, pool)
.await
.map(|permissions| permissions.can_read().then_some(id)),
None => get_root(user_id, pool).await.map(Some),
}
}
#[derive(Debug, serde::Serialize)]
#[allow(clippy::struct_field_names, clippy::module_name_repetitions)]
pub struct FolderWithoutParentId {
folder_id: Uuid,
owner_id: i32,
folder_name: String,
created_at: chrono::NaiveDateTime,
}
pub async fn get_folders(
parent_folder_id: Uuid,
pool: &Pool,
) -> sqlx::Result<Vec<FolderWithoutParentId>> {
sqlx::query_as!(
FolderWithoutParentId,
"SELECT folder_id, owner_id, folder_name, created_at FROM folders WHERE parent_folder_id = $1",
parent_folder_id,
)
.fetch_all(pool)
.await
}
pub async fn exists_by_name(
parent_folder_id: Uuid,
folder_name: &str,
pool: &Pool,
) -> sqlx::Result<bool> {
sqlx::query!(
"SELECT EXISTS(SELECT folder_id FROM folders WHERE parent_folder_id = $1 AND folder_name = $2)",
parent_folder_id,
folder_name
)
.fetch_one(pool)
.await
.and_then(|row| {
row.exists.ok_or(sqlx::Error::RowNotFound)
})
}
pub async fn insert(
parent_folder_id: Uuid,
user_id: i32,
folder_name: &str,
pool: &Pool,
) -> sqlx::Result<Uuid> {
sqlx::query!("INSERT INTO folders(parent_folder_id, owner_id, folder_name) VALUES ($1, $2, $3) RETURNING folder_id",
parent_folder_id,
user_id,
folder_name
)
.fetch_one(pool)
.await
.map(|record| record.folder_id)
}
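// sql/delete_folder.sql (not part of this hunk) is expected to drop the folder
// subtree and return the file_id of every deleted file so callers can purge the blobs.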
pub async fn delete(folder_id: Uuid, pool: &Pool) -> sqlx::Result<Vec<Uuid>> {
sqlx::query_file!("sql/delete_folder.sql", folder_id)
.fetch(pool)
.map_ok(|row| row.file_id)
.try_collect()
.await
}

3
src/db/mod.rs Normal file

@@ -0,0 +1,3 @@
pub mod file;
pub mod folder;
pub mod permissions;

59
src/db/permissions.rs Normal file

@@ -0,0 +1,59 @@
use axum::http::StatusCode;
#[derive(sqlx::Type)]
#[sqlx(type_name = "permission", rename_all = "lowercase")]
pub(super) enum PermissionRaw {
Read,
Write,
Manage,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum PermissionType {
#[default]
NoPermission = 1,
Read,
Write,
Manage,
}
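// The derived Ord follows declaration order (NoPermission < Read < Write < Manage),
// which is exactly what the comparison-based guards below rely on.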
impl From<Option<PermissionRaw>> for PermissionType {
fn from(value: Option<PermissionRaw>) -> PermissionType {
use PermissionRaw as PR;
match value {
Some(PR::Read) => PermissionType::Read,
Some(PR::Write) => PermissionType::Write,
Some(PR::Manage) => PermissionType::Manage,
None => PermissionType::NoPermission,
}
}
}
impl PermissionType {
pub fn can_read(self) -> bool {
self >= PermissionType::Read
}
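// A failed read check is reported as 404 so callers cannot probe for existence;
// write/manage failures on a readable resource are reported as 403 instead.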
pub fn can_read_guard(self) -> Result<(), StatusCode> {
if !self.can_read() {
return Err(StatusCode::NOT_FOUND);
}
Ok(())
}
pub fn can_write_guard(self) -> Result<(), StatusCode> {
self.can_read_guard()?;
if self < PermissionType::Write {
return Err(StatusCode::FORBIDDEN);
}
Ok(())
}
pub fn can_manage_guard(self) -> Result<(), StatusCode> {
self.can_read_guard()?;
if self < PermissionType::Manage {
return Err(StatusCode::FORBIDDEN);
}
Ok(())
}
}

32
src/endpoints/file/delete.rs Normal file

@@ -0,0 +1,32 @@
pub use crate::prelude::*;
#[derive(Deserialize, Debug)]
pub struct Params {
file_id: Uuid,
}
pub async fn delete(
Query(params): Query<Params>,
State(state): State<AppState>,
claims: Claims,
) -> Result<StatusCode, StatusCode> {
db::file::get_permissions(params.file_id, claims.user_id, &state.pool)
.await
.handle_internal()?
.can_write_guard()?;
let deleted = db::file::delete(params.file_id, &state.pool)
.await
.handle_internal()?;
if !deleted {
return Err(StatusCode::NOT_FOUND); // Unlikely: the write guard above already saw the file, so this means it was deleted concurrently
}
state
.storage
.delete(params.file_id)
.await
.handle_internal()?;
Ok(StatusCode::NO_CONTENT)
}

41
src/endpoints/file/download.rs Normal file

@@ -0,0 +1,41 @@
use axum::{body::Body, http::header, response::IntoResponse};
use tokio_util::io::ReaderStream;
use crate::prelude::*;
#[derive(Deserialize, Debug)]
pub struct Params {
file_id: Uuid,
}
pub async fn download(
Query(params): Query<Params>,
State(state): State<AppState>,
claims: Claims,
) -> Result<impl IntoResponse, StatusCode> {
db::file::get_permissions(params.file_id, claims.user_id, &state.pool)
.await
.handle_internal()?
.can_read_guard()?;
let mut name = db::file::get_name(params.file_id, &state.pool)
.await
.handle_internal()?
.ok_or(StatusCode::NOT_FOUND)?;
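// Escape backslashes and double quotes so the name can be embedded safely in the
// quoted Content-Disposition filename below.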
name = name
.chars()
.fold(String::with_capacity(name.len()), |mut result, char| {
if ['\\', '"'].contains(&char) {
result.push('\\');
}
result.push(char);
result
});
let file = state.storage.read(params.file_id).await.handle_internal()?;
let body = Body::from_stream(ReaderStream::new(file));
let disposition = format!("attachment; filename=\"{name}\"");
let headers = [(header::CONTENT_DISPOSITION, disposition)];
Ok((headers, body))
}

3
src/endpoints/file/mod.rs Normal file

@@ -0,0 +1,3 @@
pub mod delete;
pub mod download;
pub mod upload;

94
src/endpoints/file/upload.rs Normal file

@@ -0,0 +1,94 @@
use std::collections::HashMap;
use std::io;
use axum::extract::multipart::{self, Multipart};
use sha2::Digest as _;
use tokio::io::{AsyncWrite, BufWriter};
use tokio_util::io::StreamReader;
use crate::prelude::*;
#[derive(Debug, Deserialize)]
pub struct Params {
parent_folder: Uuid,
}
async fn write_file(
file_id: Uuid,
file: impl AsyncWrite + Unpin,
file_name: &str,
field: &mut multipart::Field<'_>,
parent_folder: Uuid,
pool: &Pool,
) -> bool {
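// Stream the multipart field to `file`, updating the SHA-512 hash and the running
// size (overflow-checked) per chunk; returns false so the caller can clean up the
// partially written file on any failure.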
const BUF_CAP: usize = 64 * 1024 * 1024; // 64 MiB
let mut hash = sha2::Sha512::new();
let mut size: i64 = 0;
let stream = field.map(|value| match value {
Ok(bytes) => {
hash.update(&bytes);
size = i64::try_from(bytes.len())
.ok()
.and_then(|part_size| size.checked_add(part_size))
.ok_or_else(|| io::Error::other(anyhow::anyhow!("Size calculation overflow")))?;
Ok(bytes)
}
Err(err) => Err(io::Error::other(err)),
});
let mut reader = StreamReader::new(stream);
let mut writer = BufWriter::with_capacity(BUF_CAP, file);
if let Err(err) = tokio::io::copy(&mut reader, &mut writer).await {
tracing::warn!(%err);
return false;
}
let hash = hash.finalize().to_vec();
db::file::insert(file_id, parent_folder, file_name, size, hash, pool)
.await
.inspect_err(|err| tracing::warn!(%err))
.is_ok()
}
pub async fn upload(
Query(params): Query<Params>,
State(state): State<AppState>,
claims: Claims,
mut multi: Multipart,
) -> Result<Json<HashMap<String, Uuid>>, StatusCode> {
db::folder::get_permissions(params.parent_folder, claims.user_id, &state.pool)
.await
.handle_internal()?
.can_write_guard()?;
let existing_names = db::folder::get_names(params.parent_folder, &state.pool)
.await
.handle_internal()?;
let mut result = HashMap::new();
while let Ok(Some(mut field)) = multi.next_field().await {
let Some(file_name) = field.file_name().map(ToOwned::to_owned) else {
continue;
};
if existing_names.contains(&file_name) {
continue;
}
let Ok((file_id, mut file)) = state.storage.create().await else {
tracing::warn!("Couldn't create uuid for new file");
continue;
};
let is_success = write_file(
file_id,
&mut file,
&file_name,
&mut field,
params.parent_folder,
&state.pool,
)
.await;
if !is_success {
let _ = state.storage.delete(file_id).await;
continue;
}
result.insert(file_name, file_id);
}
Ok(Json(result))
}

37
src/endpoints/folder/create.rs Normal file

@@ -0,0 +1,37 @@
use crate::prelude::*;
#[derive(Deserialize, Debug)]
pub struct Params {
folder_name: String,
parent_folder_id: Uuid,
}
pub async fn create(
State(state): State<AppState>,
claims: Claims,
Json(params): Json<Params>,
) -> Result<Json<Uuid>, StatusCode> {
db::folder::get_permissions(params.parent_folder_id, claims.user_id, &state.pool)
.await
.handle_internal()?
.can_write_guard()?;
let exists =
db::folder::exists_by_name(params.parent_folder_id, &params.folder_name, &state.pool)
.await
.handle_internal()?;
if exists {
return Err(StatusCode::CONFLICT);
}
let id = db::folder::insert(
params.parent_folder_id,
claims.user_id,
&params.folder_name,
&state.pool,
)
.await
.handle_internal()?;
Ok(Json(id))
}

35
src/endpoints/folder/delete.rs Normal file

@@ -0,0 +1,35 @@
use crate::prelude::*;
#[derive(Deserialize, Debug)]
pub struct Params {
folder_id: Uuid,
}
pub async fn delete(
State(state): State<AppState>,
claims: Claims,
Json(params): Json<Params>,
) -> Result<(), StatusCode> {
let root = db::folder::get_root(claims.user_id, &state.pool)
.await
.handle_internal()?;
if params.folder_id == root {
return Err(StatusCode::BAD_REQUEST);
}
db::folder::get_permissions(params.folder_id, claims.user_id, &state.pool)
.await
.handle_internal()?
.can_write_guard()?;
let files_to_delete = db::folder::delete(params.folder_id, &state.pool)
.await
.handle_internal()?;
let storage = &state.storage;
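// Purge the stored blobs of every file that lived in the deleted subtree, with at
// most five deletions in flight at a time.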
futures::stream::iter(files_to_delete)
.for_each_concurrent(5, |file| async move {
let _ = storage.delete(file).await;
})
.await;
Ok(())
}

38
src/endpoints/folder/list.rs Normal file

@@ -0,0 +1,38 @@
use tokio::try_join;
use crate::prelude::*;
#[derive(Debug, Deserialize)]
pub struct Params {
folder_id: Option<Uuid>,
}
#[derive(Debug, Serialize)]
pub struct Response {
folder_id: Uuid,
files: Vec<db::file::FileWithoutParentId>,
folders: Vec<db::folder::FolderWithoutParentId>,
}
pub async fn list(
Query(params): Query<Params>,
State(state): State<AppState>,
claims: Claims,
) -> Result<Json<Response>, StatusCode> {
let folder_id = db::folder::get_by_id(params.folder_id, claims.user_id, &state.pool)
.await
.handle_internal()?
.ok_or(StatusCode::NOT_FOUND)?;
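// folder_id has already fallen back to the user's root when none was supplied;
// fetch its files and subfolders concurrently.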
let (files, folders) = try_join!(
db::file::get_files(folder_id, &state.pool),
db::folder::get_folders(folder_id, &state.pool)
)
.handle_internal()?;
Ok(Json(Response {
folder_id,
files,
folders,
}))
}

3
src/endpoints/folder/mod.rs Normal file

@@ -0,0 +1,3 @@
pub mod create;
pub mod delete;
pub mod list;

2
src/endpoints/mod.rs Normal file

@@ -0,0 +1,2 @@
pub mod file;
pub mod folder;

28
src/errors.rs Normal file

@@ -0,0 +1,28 @@
use axum::http::StatusCode;
type BoxError = Box<dyn std::error::Error>;
pub fn handle_error(error: impl Into<BoxError>) {
let error: BoxError = error.into();
tracing::error!(error);
}
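// Extension trait that logs an error and converts it into an HTTP status code;
// handle_internal is the common "500 Internal Server Error" shorthand.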
pub trait ErrorHandlingExt<T, E>
where
Self: Sized,
{
fn handle(self, code: StatusCode) -> Result<T, StatusCode>;
fn handle_internal(self) -> Result<T, StatusCode> {
self.handle(StatusCode::INTERNAL_SERVER_ERROR)
}
}
impl<T, E: Into<BoxError>> ErrorHandlingExt<T, E> for Result<T, E> {
fn handle(self, code: StatusCode) -> Result<T, StatusCode> {
self.map_err(|err| {
handle_error(err);
code
})
}
}

73
src/file_storage.rs Normal file

@@ -0,0 +1,73 @@
use std::{
env, io,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::fs;
use uuid::Uuid;
#[derive(Clone)]
pub struct FileStorage(Arc<Path>);
impl FileStorage {
pub fn new() -> anyhow::Result<Self> {
let var = env::var("DRIVE_STORAGE_PATH");
let path_str = match var {
Ok(ref string) => string,
Err(err) => {
tracing::info!(
%err,
"Error getting DRIVE_STORAGE_PATH variable. Defaulting to ./files"
);
"./files"
}
};
let path = Path::new(path_str);
match path.metadata() {
Ok(meta) => anyhow::ensure!(meta.is_dir(), "Expected path to a directory"),
Err(err) if err.kind() == io::ErrorKind::NotFound => {
std::fs::create_dir_all(path)?;
}
Err(err) => return Err(err.into()),
};
Ok(FileStorage(path.into()))
}
fn path_for_file(&self, file_id: Uuid) -> PathBuf {
let file_name = file_id.as_hyphenated().to_string();
self.0.join(file_name)
}
async fn create_inner(&self, file_id: Uuid) -> anyhow::Result<impl tokio::io::AsyncWrite> {
fs::File::create_new(self.path_for_file(file_id))
.await
.map_err(Into::into)
}
pub async fn create(&self) -> anyhow::Result<(Uuid, impl tokio::io::AsyncWrite)> {
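// Try a handful of random UUIDs: create_new fails if the path already exists, so a
// collision (or any other I/O error) simply rolls into the next attempt.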
let mut error = anyhow::anyhow!("Error creating a file");
for _ in 0..3 {
let file_id = Uuid::new_v4();
match self.create_inner(file_id).await {
Ok(file) => return Ok((file_id, file)),
Err(err) => error = error.context(err),
}
}
Err(error)
}
pub async fn read(&self, file_id: Uuid) -> anyhow::Result<impl tokio::io::AsyncRead> {
fs::File::open(self.path_for_file(file_id))
.await
.map_err(Into::into)
}
pub async fn delete(&self, file_id: Uuid) -> anyhow::Result<bool> {
match fs::remove_file(self.path_for_file(file_id)).await {
Ok(()) => Ok(true),
Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(false),
Err(err) => Err(err.into()),
}
}
}

114
src/main.rs Normal file

@@ -0,0 +1,114 @@
mod auth;
mod db;
mod endpoints;
mod errors;
mod file_storage;
mod prelude;
use std::{env, net::Ipv4Addr};
use axum::{routing::get, Router};
use file_storage::FileStorage;
use tokio::net::TcpListener;
type Pool = sqlx::postgres::PgPool;
#[derive(Clone)]
struct AppState {
pool: Pool,
storage: FileStorage,
}
async fn create_user(user_name: &str, user_email: &str, pool: &Pool) -> anyhow::Result<i32> {
let id = sqlx::query!(
"INSERT INTO users(username, email) VALUES ($1, $2) RETURNING user_id",
user_name,
user_email
)
.fetch_one(pool)
.await?
.user_id;
sqlx::query!(
"INSERT INTO folders(owner_id, folder_name) VALUES ($1, $2)",
id,
"ROOT"
)
.execute(pool)
.await?;
Ok(id)
}
async fn create_debug_users(pool: &Pool) -> anyhow::Result<()> {
let count = sqlx::query!("SELECT count(user_id) FROM users")
.fetch_one(pool)
.await?
.count
.unwrap_or(0);
if count > 0 {
return Ok(());
}
tokio::try_join!(
create_user("Test1", "test1@example.com", pool),
create_user("Test2", "test2@example.com", pool)
)?;
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// TODO: add utoipa and utoipauto for swagger
let _ = dotenvy::dotenv();
tracing_subscriber::fmt::init();
let pool = match env::var("DATABASE_URL") {
Ok(url) => Pool::connect(&url).await?,
Err(err) => anyhow::bail!("Error getting database url: {err}"),
};
sqlx::migrate!().run(&pool).await?;
create_debug_users(&pool).await?;
let storage = file_storage::FileStorage::new()?;
let state = AppState { pool, storage };
let router = app(state);
let addr = (Ipv4Addr::UNSPECIFIED, 3000);
let listener = TcpListener::bind(addr).await?;
axum::serve(listener, router).await?;
Ok(())
}
fn app(state: AppState) -> Router {
use axum::http::header;
use endpoints::{file, folder};
use tower_http::ServiceBuilderExt as _;
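// Mark auth headers as sensitive so downstream layers (e.g. the trace layer) avoid
// exposing them, then add HTTP tracing and response compression to every route.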
let sensitive_headers = [header::AUTHORIZATION, header::COOKIE];
let middleware = tower::ServiceBuilder::new()
.sensitive_headers(sensitive_headers)
.trace_for_http()
.compression();
// Build route service
Router::new()
.route(
"/files",
get(file::download::download)
.post(file::upload::upload)
.delete(file::delete::delete),
)
.route(
"/folders",
get(folder::list::list)
.post(folder::create::create)
.delete(folder::delete::delete),
)
.layer(middleware)
.with_state(state)
}

8
src/prelude.rs Normal file

@@ -0,0 +1,8 @@
pub(crate) use crate::{auth::Claims, db, errors::ErrorHandlingExt as _, AppState, Pool};
pub use axum::{
extract::{Json, Query, State},
http::StatusCode,
};
pub use futures::StreamExt as _;
pub use serde::{Deserialize, Serialize};
pub use uuid::Uuid;