use std::{array::TryFromSliceError, sync::LazyLock}; use axum::{ extract::{FromRef, FromRequestParts}, http::request::Parts, RequestPartsExt, }; use axum_extra::{ headers::{authorization::Bearer, Authorization}, TypedHeader, }; use chrono::{TimeDelta, Utc}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use rand::{rngs::OsRng, RngCore}; use serde::{Deserialize, Serialize}; use subtle::ConstantTimeEq; use crate::prelude::*; pub const HASH_LENGTH: usize = 64; pub const SALT_LENGTH: usize = 64; static PARAMS: LazyLock = LazyLock::new(|| scrypt::Params::new(14, 8, 1, HASH_LENGTH).unwrap()); static KEYS: LazyLock = LazyLock::new(|| { let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); Keys::from_secret(secret.as_bytes()) }); struct Keys { encoding_key: EncodingKey, decoding_key: DecodingKey, } impl Keys { fn from_secret(secret: &[u8]) -> Self { Self { encoding_key: EncodingKey::from_secret(secret), decoding_key: DecodingKey::from_secret(secret), } } } /// Forces the evaluation of the keys. They will be created upon first use otherwise pub fn force_init_keys() { LazyLock::force(&KEYS); } /// Hashes the bytes using Scrypt with the given salt #[must_use] fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] { let mut hash = [0; HASH_LENGTH]; scrypt::scrypt(bytes, salt, &PARAMS, &mut hash).unwrap(); hash } /// Verifieble scrypt hashed bytes #[cfg_attr(test, derive(PartialEq))] // == OPERATOR MUSTN'T BE USED OUTSIZE OF TESTS pub struct HashedBytes { pub hash: [u8; HASH_LENGTH], pub salt: [u8; SALT_LENGTH], } impl HashedBytes { /// Hashes the bytes #[must_use] pub fn hash_bytes(bytes: &[u8]) -> Self { let mut salt = [0; 64]; OsRng.fill_bytes(&mut salt); Self { hash: hash_scrypt(bytes, &salt), salt, } } /// Parses the bytes where the first `HASH_LENGTH` bytes are the hash and the latter `SALT_LENGTH` bytes are the salt pub fn from_bytes(bytes: &[u8]) -> Result { let (hash, salt) = bytes.split_at(HASH_LENGTH); let result = Self { hash: hash.try_into()?, salt: salt.try_into()?, }; Ok(result) } #[must_use] pub fn verify(&self, bytes: &[u8]) -> bool { let hash = hash_scrypt(bytes, self.salt.as_ref()); hash.ct_eq(self.hash.as_ref()).into() } pub fn as_bytes(&self) -> Vec { let mut result = Vec::with_capacity(self.hash.len() + self.salt.len()); result.extend_from_slice(&self.hash); result.extend_from_slice(&self.salt); result } } pub async fn authenticate_user( username: &str, password: &str, pool: &Pool, ) -> anyhow::Result> { let Some((user_id, hash)) = db::users::get_hash(username, pool).await? else { return Ok(None); }; let hash = HashedBytes::from_bytes(&hash)?; Ok(hash.verify(password.as_bytes()).then_some(user_id)) } #[derive(Debug, Serialize)] pub struct Token { access_token: String, token_type: &'static str, } #[derive(Serialize, Deserialize, Debug)] pub struct Claims { pub user_id: i32, pub exp: i64, } const JWT_ALGORITHM: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::HS256; impl Claims { pub fn new(user_id: i32) -> Self { Self { user_id, exp: (Utc::now() + TimeDelta::days(30)).timestamp(), } } pub fn encode(self) -> Result { let access_token = encode(&Header::new(JWT_ALGORITHM), &self, &KEYS.encoding_key) .handle_internal("Token creation error")?; let token = Token { access_token, token_type: "Bearer", }; Ok(token) } } #[axum::async_trait] impl FromRequestParts for Claims where Pool: FromRef, T: Sync, { type Rejection = GeneralError; async fn from_request_parts(parts: &mut Parts, state: &T) -> Result { const INVALID_TOKEN: GeneralError = GeneralError::const_message(StatusCode::UNAUTHORIZED, "Invalid token"); let pool = Pool::from_ref(state); let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| INVALID_TOKEN)?; let claims: Claims = decode( bearer.token(), &KEYS.decoding_key, &Validation::new(JWT_ALGORITHM), ) .map_err(|_| INVALID_TOKEN)? .claims; db::users::exists(claims.user_id, &pool) .await .handle_internal("Token validation error")? .then_some(claims) .ok_or(GeneralError::const_message( StatusCode::UNAUTHORIZED, "Wrong credentials", )) } } #[cfg(test)] mod tests { use super::HashedBytes; const PASSWORD: &str = "Password12313#!#4)$*!#"; #[test] fn test_hash_conversion() { let bytes = HashedBytes::hash_bytes(PASSWORD.as_bytes()); let bytes2 = HashedBytes::from_bytes(&bytes.as_bytes()).unwrap(); assert!(bytes == bytes2); } #[test] fn test_hash() { assert!(HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(PASSWORD.as_bytes())); } #[test] fn test_different_hash() { assert!(!HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(b"Different Password")); } }