use std::{array::TryFromSliceError, sync::LazyLock};

use axum::{
    extract::{FromRef, FromRequestParts},
    http::{request::Parts, StatusCode},
    response::IntoResponse,
    RequestPartsExt,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;

use crate::{db, errors::handle_error, Pool};

/// Length in bytes of a stored scrypt hash.
pub const HASH_LENGTH: usize = 64;
/// Length in bytes of a stored scrypt salt.
pub const SALT_LENGTH: usize = 64;

/// scrypt cost parameters (`log_n = 14`, `r = 8`, `p = 1`) producing `HASH_LENGTH` bytes.
static PARAMS: LazyLock<scrypt::Params> = LazyLock::new(|| {
    scrypt::Params::new(14, 8, 1, HASH_LENGTH).expect("hard-coded scrypt parameters are valid")
});

/// JWT signing/verification keys derived from the `JWT_SECRET` environment variable.
///
/// # Panics
/// Panics on first access if `JWT_SECRET` is unset, not valid Unicode, or empty.
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
    let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    // Tokens are signed with HS256 (an HMAC); an empty secret would let anyone forge them,
    // so treat it as a misconfiguration and fail loudly.
    assert!(!secret.is_empty(), "JWT_SECRET must not be empty");
    Keys::from_secret(secret.as_bytes())
});

/// Encoding/decoding key pair built from one shared symmetric secret.
struct Keys {
    encoding_key: EncodingKey,
    decoding_key: DecodingKey,
}

impl Keys {
    /// Builds both the encoding and the decoding key from the same `secret`.
    fn from_secret(secret: &[u8]) -> Self {
        Self {
            encoding_key: EncodingKey::from_secret(secret),
            decoding_key: DecodingKey::from_secret(secret),
        }
    }
}

/// Forces the evaluation of the keys.
They will be created upon first use otherwise pub fn force_init_keys() { LazyLock::force(&KEYS); } /// Hashes the bytes using Scrypt with the given salt #[must_use] fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] { let mut hash = [0; HASH_LENGTH]; scrypt::scrypt(bytes, salt, &PARAMS, &mut hash).unwrap(); hash } /// Verifieble scrypt hashed bytes #[cfg_attr(test, derive(PartialEq))] pub struct HashedBytes { pub hash: [u8; HASH_LENGTH], pub salt: [u8; SALT_LENGTH], } impl HashedBytes { /// Hashes the bytes #[must_use] pub fn hash_bytes(bytes: &[u8]) -> Self { let mut salt = [0; 64]; OsRng.fill_bytes(&mut salt); Self { hash: hash_scrypt(bytes, &salt), salt, } } /// Parses the bytes where the first `HASH_LENGTH` bytes are the hash and the latter `SALT_LENGTH` bytes are the salt pub fn from_bytes(bytes: &[u8]) -> Result { let (hash, salt) = bytes.split_at(HASH_LENGTH); let result = Self { hash: hash.try_into()?, salt: salt.try_into()?, }; Ok(result) } #[must_use] pub fn verify(&self, bytes: &[u8]) -> bool { let hash = hash_scrypt(bytes, self.salt.as_ref()); hash.ct_eq(self.hash.as_ref()).into() } pub fn as_bytes(&self) -> Vec { let mut result = Vec::with_capacity(self.hash.len() + self.salt.len()); result.extend_from_slice(&self.hash); result.extend_from_slice(&self.salt); result } } pub async fn authenticate_user( username: &str, password: &str, pool: &Pool, ) -> anyhow::Result> { let Some((user_id, hash)) = db::users::get_hash(username, pool).await? 
else { return Ok(None); }; let hash = HashedBytes::from_bytes(&hash)?; Ok(hash.verify(password.as_bytes()).then_some(user_id)) } #[derive(Debug, Serialize)] pub struct Token { access_token: String, token_type: &'static str, } #[derive(Serialize, Deserialize, Debug)] pub struct Claims { pub user_id: i32, pub exp: i64, } impl Claims { pub fn encode(self) -> Result { let access_token = encode( &Header::new(jsonwebtoken::Algorithm::HS256), &self, &KEYS.encoding_key, ) .map_err(|_| Error::TokenCreation)?; let token = Token { access_token, token_type: "Bearer", }; Ok(token) } } #[derive(Debug)] pub enum Error { WrongCredentials, TokenCreation, Validation, InvalidToken, } impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { let (status, error_message) = match self { Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), Error::Validation => (StatusCode::INTERNAL_SERVER_ERROR, "Token validation error"), Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), }; (status, error_message).into_response() } } #[axum::async_trait] impl FromRequestParts for Claims where Pool: FromRef, T: Sync, { type Rejection = Error; async fn from_request_parts(parts: &mut Parts, state: &T) -> Result { let pool = Pool::from_ref(state); let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| Error::InvalidToken)?; let claims: Claims = decode(bearer.token(), &KEYS.decoding_key, &Validation::default()) .map_err(|_| Error::InvalidToken)? 
.claims; match db::users::exists(claims.user_id, &pool).await { Ok(true) => Ok(claims), Ok(false) => Err(Error::WrongCredentials), Err(err) => { handle_error(err); Err(Error::Validation) } } } } #[cfg(test)] mod tests { use super::HashedBytes; const PASSWORD: &str = "Password12313#!#4)$*!#"; #[test] fn test_hash_conversion() { let bytes = HashedBytes::hash_bytes(PASSWORD.as_bytes()); let bytes2 = HashedBytes::from_bytes(&bytes.as_bytes()).unwrap(); assert!(bytes == bytes2); } #[test] fn test_hash() { assert!(HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(PASSWORD.as_bytes())); } #[test] fn test_different_hash() { assert!(!HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(b"Different Password")); } }