Now checking that user_id from claims exists

This commit is contained in:
StNicolay 2024-08-03 20:15:08 +03:00
parent 9f36d8e663
commit c4ff602ec7
Signed by: StNicolay
GPG Key ID: 9693D04DCD962B0D
3 changed files with 36 additions and 10 deletions

View File

@ -1,7 +1,7 @@
use std::{array::TryFromSliceError, sync::LazyLock}; use std::{array::TryFromSliceError, sync::LazyLock};
use axum::{ use axum::{
extract::FromRequestParts, extract::{FromRef, FromRequestParts},
http::{request::Parts, StatusCode}, http::{request::Parts, StatusCode},
response::IntoResponse, response::IntoResponse,
RequestPartsExt, RequestPartsExt,
@ -15,7 +15,7 @@ use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use crate::{db, Pool}; use crate::{db, errors::handle_error, Pool};
pub const HASH_LENGTH: usize = 64; pub const HASH_LENGTH: usize = 64;
pub const SALT_LENGTH: usize = 64; pub const SALT_LENGTH: usize = 64;
@ -141,6 +141,7 @@ impl Claims {
pub enum Error { pub enum Error {
WrongCredentials, WrongCredentials,
TokenCreation, TokenCreation,
Validation,
InvalidToken, InvalidToken,
} }
@ -149,6 +150,7 @@ impl IntoResponse for Error {
let (status, error_message) = match self { let (status, error_message) = match self {
Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
Error::Validation => (StatusCode::INTERNAL_SERVER_ERROR, "Token validation error"),
Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
}; };
(status, error_message).into_response() (status, error_message).into_response()
@ -156,17 +158,30 @@ impl IntoResponse for Error {
} }
#[axum::async_trait] #[axum::async_trait]
impl<T> FromRequestParts<T> for Claims { impl<T> FromRequestParts<T> for Claims
where
Pool: FromRef<T>,
T: Sync,
{
type Rejection = Error; type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &T) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &T) -> Result<Self, Self::Rejection> {
let pool = Pool::from_ref(state);
let TypedHeader(Authorization(bearer)) = parts let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>() .extract::<TypedHeader<Authorization<Bearer>>>()
.await .await
.map_err(|_| Error::InvalidToken)?; .map_err(|_| Error::InvalidToken)?;
let token_data = decode(bearer.token(), &KEYS.decoding_key, &Validation::default()) let claims: Claims = decode(bearer.token(), &KEYS.decoding_key, &Validation::default())
.map_err(|_| Error::InvalidToken)?; .map_err(|_| Error::InvalidToken)?
Ok(token_data.claims) .claims;
match db::users::exists(claims.user_id, &pool).await {
Ok(true) => Ok(claims),
Ok(false) => Err(Error::WrongCredentials),
Err(err) => {
handle_error(err);
Err(Error::Validation)
}
}
} }
} }

View File

@ -61,13 +61,23 @@ pub async fn update(
.await .await
} }
pub async fn get(user_id: i32, pool: &Pool) -> sqlx::Result<UserInfo> { pub async fn exists(user_id: i32, pool: &Pool) -> sqlx::Result<bool> {
sqlx::query!(
"SELECT EXISTS(SELECT user_id FROM users WHERE user_id = $1)",
user_id
)
.fetch_one(pool)
.await
.map(|record| record.exists.unwrap_or(false))
}
pub async fn get(user_id: i32, pool: &Pool) -> sqlx::Result<Option<UserInfo>> {
sqlx::query_as!( sqlx::query_as!(
UserInfo, UserInfo,
"SELECT user_id, username, email FROM users WHERE user_id = $1", "SELECT user_id, username, email FROM users WHERE user_id = $1",
user_id user_id
) )
.fetch_one(pool) .fetch_optional(pool)
.await .await
} }

View File

@ -10,7 +10,8 @@ type Response = Result<Json<db::users::UserInfo>, StatusCode>;
pub async fn get(State(pool): State<Pool>, Query(params): Query<Params>) -> Response { pub async fn get(State(pool): State<Pool>, Query(params): Query<Params>) -> Response {
let info = db::users::get(params.user_id, &pool) let info = db::users::get(params.user_id, &pool)
.await .await
.handle_internal()?; .handle_internal()?
.ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(info)) Ok(Json(info))
} }