From c4ff602ec7f1a2c17bdc133e58b079c8114263b9 Mon Sep 17 00:00:00 2001 From: StNicolay Date: Sat, 3 Aug 2024 20:15:08 +0300 Subject: [PATCH] Now checking that user_id from claims exists --- src/auth.rs | 29 ++++++++++++++++++++++------- src/db/users.rs | 14 ++++++++++++-- src/endpoints/users/get.rs | 3 ++- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/auth.rs b/src/auth.rs index 6fab2ef..189540a 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,7 +1,7 @@ use std::{array::TryFromSliceError, sync::LazyLock}; use axum::{ - extract::FromRequestParts, + extract::{FromRef, FromRequestParts}, http::{request::Parts, StatusCode}, response::IntoResponse, RequestPartsExt, @@ -15,7 +15,7 @@ use rand::{rngs::OsRng, RngCore}; use serde::{Deserialize, Serialize}; use subtle::ConstantTimeEq; -use crate::{db, Pool}; +use crate::{db, errors::handle_error, Pool}; pub const HASH_LENGTH: usize = 64; pub const SALT_LENGTH: usize = 64; @@ -141,6 +141,7 @@ impl Claims { pub enum Error { WrongCredentials, TokenCreation, + Validation, InvalidToken, } @@ -149,6 +150,7 @@ impl IntoResponse for Error { let (status, error_message) = match self { Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), + Error::Validation => (StatusCode::INTERNAL_SERVER_ERROR, "Token validation error"), Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), }; (status, error_message).into_response() @@ -156,17 +158,30 @@ impl IntoResponse for Error { } #[axum::async_trait] -impl FromRequestParts for Claims { +impl FromRequestParts for Claims +where + Pool: FromRef, + T: Sync, +{ type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, _state: &T) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &T) -> Result { + let pool = Pool::from_ref(state); let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| Error::InvalidToken)?; - let token_data = decode(bearer.token(), &KEYS.decoding_key, &Validation::default()) - .map_err(|_| Error::InvalidToken)?; - Ok(token_data.claims) + let claims: Claims = decode(bearer.token(), &KEYS.decoding_key, &Validation::default()) + .map_err(|_| Error::InvalidToken)? + .claims; + match db::users::exists(claims.user_id, &pool).await { + Ok(true) => Ok(claims), + Ok(false) => Err(Error::WrongCredentials), + Err(err) => { + handle_error(err); + Err(Error::Validation) + } + } } } diff --git a/src/db/users.rs b/src/db/users.rs index afe3b77..add0e8e 100644 --- a/src/db/users.rs +++ b/src/db/users.rs @@ -61,13 +61,23 @@ pub async fn update( .await } -pub async fn get(user_id: i32, pool: &Pool) -> sqlx::Result { +pub async fn exists(user_id: i32, pool: &Pool) -> sqlx::Result { + sqlx::query!( + "SELECT EXISTS(SELECT user_id FROM users WHERE user_id = $1)", + user_id + ) + .fetch_one(pool) + .await + .map(|record| record.exists.unwrap_or(false)) +} + +pub async fn get(user_id: i32, pool: &Pool) -> sqlx::Result> { sqlx::query_as!( UserInfo, "SELECT user_id, username, email FROM users WHERE user_id = $1", user_id ) - .fetch_one(pool) + .fetch_optional(pool) .await } diff --git a/src/endpoints/users/get.rs b/src/endpoints/users/get.rs index 9d66233..074a2f9 100644 --- a/src/endpoints/users/get.rs +++ b/src/endpoints/users/get.rs @@ -10,7 +10,8 @@ type Response = Result, StatusCode>; pub async fn get(State(pool): State, Query(params): Query) -> Response { let info = db::users::get(params.user_id, &pool) .await - .handle_internal()?; + .handle_internal()? + .ok_or(StatusCode::NOT_FOUND)?; Ok(Json(info)) }