From 75afab933d5f93dbf353e6280ce89a59298c817b Mon Sep 17 00:00:00 2001 From: StNicolay Date: Tue, 6 Aug 2024 16:00:38 +0300 Subject: [PATCH] More error handling improvements --- compose.yaml | 1 + src/auth.rs | 74 +++++++++++---------------- src/db/permissions.rs | 7 +-- src/endpoints/file/download.rs | 2 +- src/endpoints/file/modify.rs | 2 +- src/endpoints/folder/get_structure.rs | 4 +- src/endpoints/permissions/set.rs | 2 +- src/endpoints/users/login.rs | 11 ++-- src/endpoints/users/put.rs | 2 +- src/endpoints/users/register.rs | 19 +++---- src/errors.rs | 43 ++++++++++++---- src/prelude.rs | 5 +- 12 files changed, 87 insertions(+), 85 deletions(-) diff --git a/compose.yaml b/compose.yaml index eccc763..bf29e94 100644 --- a/compose.yaml +++ b/compose.yaml @@ -9,6 +9,7 @@ services: - 5432:5432 volumes: - postgres_data:/var/lib/postgresql/data + restart: unless-stopped volumes: postgres_data: diff --git a/src/auth.rs b/src/auth.rs index 56218b2..89133a1 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -2,8 +2,7 @@ use std::{array::TryFromSliceError, sync::LazyLock}; use axum::{ extract::{FromRef, FromRequestParts}, - http::{request::Parts, StatusCode}, - response::IntoResponse, + http::request::Parts, RequestPartsExt, }; use axum_extra::{ @@ -16,7 +15,7 @@ use rand::{rngs::OsRng, RngCore}; use serde::{Deserialize, Serialize}; use subtle::ConstantTimeEq; -use crate::{db, Pool}; +use crate::prelude::*; pub const HASH_LENGTH: usize = 64; pub const SALT_LENGTH: usize = 64; @@ -56,7 +55,7 @@ fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] { } /// Verifieble scrypt hashed bytes -#[cfg_attr(test, derive(PartialEq))] +#[cfg_attr(test, derive(PartialEq))] // == OPERATOR MUSTN'T BE USED OUTSIZE OF TESTS pub struct HashedBytes { pub hash: [u8; HASH_LENGTH], pub salt: [u8; SALT_LENGTH], @@ -122,6 +121,8 @@ pub struct Claims { pub exp: i64, } +const JWT_ALGORITHM: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::HS256; + impl Claims { pub fn new(user_id: i32) -> Self { Self { @@ -130,13 +131,9 @@ impl Claims { } } - pub fn encode(self) -> Result { - let access_token = encode( - &Header::new(jsonwebtoken::Algorithm::HS256), - &self, - &KEYS.encoding_key, - ) - .map_err(|_| Error::TokenCreation)?; + pub fn encode(self) -> Result { + let access_token = encode(&Header::new(JWT_ALGORITHM), &self, &KEYS.encoding_key) + .handle_internal("Token creation error")?; let token = Token { access_token, token_type: "Bearer", @@ -145,51 +142,40 @@ impl Claims { } } -#[derive(Debug)] -pub enum Error { - WrongCredentials, - TokenCreation, - Validation, - InvalidToken, -} - -impl IntoResponse for Error { - fn into_response(self) -> axum::response::Response { - let (status, error_message) = match self { - Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), - Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), - Error::Validation => (StatusCode::INTERNAL_SERVER_ERROR, "Token validation error"), - Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), - }; - (status, error_message).into_response() - } -} - #[axum::async_trait] impl FromRequestParts for Claims where Pool: FromRef, T: Sync, { - type Rejection = Error; + type Rejection = GeneralError; async fn from_request_parts(parts: &mut Parts, state: &T) -> Result { + const INVALID_TOKEN: GeneralError = + GeneralError::const_message(StatusCode::UNAUTHORIZED, "Invalid token"); + let pool = Pool::from_ref(state); let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await - .map_err(|_| Error::InvalidToken)?; - let claims: Claims = decode(bearer.token(), &KEYS.decoding_key, &Validation::default()) - .map_err(|_| Error::InvalidToken)? - .claims; - match db::users::exists(claims.user_id, &pool).await { - Ok(true) => Ok(claims), - Ok(false) => Err(Error::WrongCredentials), - Err(err) => { - tracing::error!(%err); - Err(Error::Validation) - } - } + .map_err(|_| INVALID_TOKEN)?; + + let claims: Claims = decode( + bearer.token(), + &KEYS.decoding_key, + &Validation::new(JWT_ALGORITHM), + ) + .map_err(|_| INVALID_TOKEN)? + .claims; + + db::users::exists(claims.user_id, &pool) + .await + .handle_internal("Token validation error")? + .then_some(claims) + .ok_or(GeneralError::const_message( + StatusCode::UNAUTHORIZED, + "Wrong credentials", + )) } } diff --git a/src/db/permissions.rs b/src/db/permissions.rs index 12630bd..65836fc 100644 --- a/src/db/permissions.rs +++ b/src/db/permissions.rs @@ -38,12 +38,7 @@ impl PermissionType { } fn can_read_guard(self) -> GeneralResult<()> { - if !self.can_read() { - return Err(GeneralError::message( - StatusCode::NOT_FOUND, - "Item not found", - )); - } + self.can_read().then_some(()).item_not_found()?; Ok(()) } diff --git a/src/endpoints/file/download.rs b/src/endpoints/file/download.rs index 405fea6..71766d0 100644 --- a/src/endpoints/file/download.rs +++ b/src/endpoints/file/download.rs @@ -20,7 +20,7 @@ pub async fn download( let mut name = db::file::get_name(params.file_id, &state.pool) .await .handle_internal("Error getting file info")? - .ok_or_else(GeneralError::item_not_found)?; + .item_not_found()?; name = name .chars() .fold(String::with_capacity(name.len()), |mut result, char| { diff --git a/src/endpoints/file/modify.rs b/src/endpoints/file/modify.rs index dc9f754..faacf4f 100644 --- a/src/endpoints/file/modify.rs +++ b/src/endpoints/file/modify.rs @@ -36,7 +36,7 @@ pub async fn modify( .write(params.file_id) .await .handle_internal("Error writing to the file")? - .ok_or_else(GeneralError::item_not_found)?; + .item_not_found()?; let (hash, size) = crate::FileStorage::write_to_file(&mut file, &mut field) .await diff --git a/src/endpoints/folder/get_structure.rs b/src/endpoints/folder/get_structure.rs index dca16dc..da67ab2 100644 --- a/src/endpoints/folder/get_structure.rs +++ b/src/endpoints/folder/get_structure.rs @@ -29,12 +29,12 @@ pub async fn structure( let folder_id = db::folder::process_id(params.folder_id, claims.user_id, &pool) .await .handle_internal("Error processing id")? - .ok_or_else(GeneralError::item_not_found)?; + .item_not_found()?; let folder = db::folder::get_by_id(folder_id, &pool) .await .handle_internal("Error getting folder info")? - .ok_or_else(GeneralError::item_not_found)?; + .item_not_found()?; let mut response: FolderStructure = folder.into(); let mut stack = vec![&mut response]; diff --git a/src/endpoints/permissions/set.rs b/src/endpoints/permissions/set.rs index 55da2cb..768618c 100644 --- a/src/endpoints/permissions/set.rs +++ b/src/endpoints/permissions/set.rs @@ -29,7 +29,7 @@ pub async fn set( let folder_info = db::folder::get_by_id(params.folder_id, &pool) .await .handle_internal("Error getting folder info")? - .ok_or_else(GeneralError::item_not_found)?; + .item_not_found()?; if folder_info.owner_id == params.user_id { return Err(GeneralError::message( StatusCode::BAD_REQUEST, diff --git a/src/endpoints/users/login.rs b/src/endpoints/users/login.rs index 4f7c382..ba1288e 100644 --- a/src/endpoints/users/login.rs +++ b/src/endpoints/users/login.rs @@ -1,7 +1,7 @@ use axum::Form; use crate::{ - auth::{authenticate_user, Error, Token}, + auth::{authenticate_user, Token}, prelude::*, }; @@ -14,10 +14,13 @@ pub struct Params { pub async fn login( State(pool): State, Form(payload): Form, -) -> Result, Error> { +) -> GeneralResult> { let user_id = authenticate_user(&payload.username, &payload.password, &pool) .await - .map_err(|_| Error::WrongCredentials)? - .ok_or(Error::WrongCredentials)?; + .handle_internal("Error getting user from database")? + .handle( + StatusCode::NOT_FOUND, + "User with this name and password doesn't exist", + )?; Claims::new(user_id).encode().map(Json) } diff --git a/src/endpoints/users/put.rs b/src/endpoints/users/put.rs index 9c611ca..763317f 100644 --- a/src/endpoints/users/put.rs +++ b/src/endpoints/users/put.rs @@ -15,7 +15,7 @@ pub async fn put( claims: Claims, Json(params): Json, ) -> GeneralResult> { - params.validate().map_err(GeneralError::validation)?; + params.validate().handle_validation()?; db::users::update(claims.user_id, ¶ms.username, ¶ms.email, &pool) .await .handle_internal("Error updating the user") diff --git a/src/endpoints/users/register.rs b/src/endpoints/users/register.rs index a092c11..81447be 100644 --- a/src/endpoints/users/register.rs +++ b/src/endpoints/users/register.rs @@ -1,10 +1,9 @@ use axum::Form; -use axum_extra::either::Either; use itertools::Itertools; use validator::{Validate, ValidationError}; use crate::{ - auth::{Error, HashedBytes, Token}, + auth::{HashedBytes, Token}, prelude::*, }; @@ -48,23 +47,17 @@ fn validate_password(password: &str) -> Result<(), ValidationError> { pub async fn register( State(pool): State, Form(params): Form, -) -> Result, Either> { - params - .validate() - .map_err(GeneralError::validation) - .map_err(Either::E1)?; +) -> GeneralResult> { + params.validate().handle_validation()?; let password = HashedBytes::hash_bytes(params.password.as_bytes()).as_bytes(); let id = db::users::create_user(¶ms.username, ¶ms.email, &password, &pool) .await - .handle_internal("Error creating the user") - .map_err(Either::E1)? + .handle_internal("Error creating the user")? .handle( StatusCode::BAD_REQUEST, "The username or the email are taken", - ) - .map_err(Either::E1)?; + )?; - let token = Claims::new(id).encode().map_err(Either::E2)?; - Ok(Json(token)) + Claims::new(id).encode().map(Json) } diff --git a/src/errors.rs b/src/errors.rs index 5803512..74a9d69 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -4,14 +4,16 @@ use axum::{http::StatusCode, response::IntoResponse}; type BoxError = Box; +/// Common error type for the project pub struct GeneralError { + /// Response status code pub status_code: StatusCode, + /// Message to send to the user pub message: Cow<'static, str>, + /// Error to log pub error: Option, } -pub type GeneralResult = Result; - impl GeneralError { pub fn message(status_code: StatusCode, message: impl Into>) -> Self { Self { @@ -21,15 +23,10 @@ impl GeneralError { } } - #[allow(clippy::needless_pass_by_value)] - pub fn validation(error: validator::ValidationErrors) -> Self { - Self::message(StatusCode::BAD_REQUEST, error.to_string()) - } - - pub const fn item_not_found() -> Self { - GeneralError { - status_code: StatusCode::NOT_FOUND, - message: Cow::Borrowed("Item not found"), + pub const fn const_message(status_code: StatusCode, message: &'static str) -> Self { + Self { + status_code, + message: Cow::Borrowed(message), error: None, } } @@ -44,6 +41,8 @@ impl IntoResponse for GeneralError { } } +pub type GeneralResult = Result; + pub trait ErrorHandlingExt where Self: Sized, @@ -86,3 +85,25 @@ impl ErrorHandlingExt for Option { }) } } + +pub trait ItemNotFoundExt { + fn item_not_found(self) -> Result; +} + +impl ItemNotFoundExt for Option { + fn item_not_found(self) -> GeneralResult { + const ITEM_NOT_FOUND_ERROR: GeneralError = + GeneralError::const_message(StatusCode::NOT_FOUND, "Item not found"); + self.ok_or(ITEM_NOT_FOUND_ERROR) + } +} + +pub trait ValidationExt { + fn handle_validation(self) -> GeneralResult; +} + +impl ValidationExt for Result { + fn handle_validation(self) -> GeneralResult { + self.map_err(|err| GeneralError::message(StatusCode::BAD_REQUEST, err.to_string())) + } +} diff --git a/src/prelude.rs b/src/prelude.rs index c2ee2e4..92d18f5 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,7 +1,10 @@ pub(crate) use crate::{ auth::Claims, db::{self, permissions::PermissionExt as _}, - errors::{ErrorHandlingExt as _, GeneralError, GeneralResult}, + errors::{ + ErrorHandlingExt as _, GeneralError, GeneralResult, ItemNotFoundExt as _, + ValidationExt as _, + }, AppState, Pool, }; pub use axum::{