use std::{array::TryFromSliceError, sync::LazyLock};

use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    response::IntoResponse,
    RequestPartsExt,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;

use crate::{db, Pool};

/// Length in bytes of a scrypt hash produced by this module.
pub const HASH_LENGTH: usize = 64;
/// Length in bytes of the random salt stored next to each hash.
pub const SALT_LENGTH: usize = 64;

/// Scrypt cost parameters: log2(N) = 14, r = 8, p = 1, output length `HASH_LENGTH`.
static PARAMS: LazyLock<scrypt::Params> = LazyLock::new(|| {
    scrypt::Params::new(14, 8, 1, HASH_LENGTH)
        .expect("constant scrypt parameters must be valid")
});

/// JWT signing/verification keys derived from the `JWT_SECRET` environment variable.
///
/// # Panics
///
/// On first access, if `JWT_SECRET` is unset (or not valid Unicode) or empty.
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
    let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    // An empty HMAC secret would make every token trivially forgeable.
    assert!(!secret.is_empty(), "JWT_SECRET must not be empty");
    Keys::from_secret(secret.as_bytes())
});

/// Symmetric (HMAC) key pair used to sign and verify JWTs.
struct Keys {
    encoding_key: EncodingKey,
    decoding_key: DecodingKey,
}

impl Keys {
    /// Builds both the encoding and decoding key from the same shared secret.
    fn from_secret(secret: &[u8]) -> Self {
        Self {
            encoding_key: EncodingKey::from_secret(secret),
            decoding_key: DecodingKey::from_secret(secret),
        }
    }
}

/// Forces the evaluation of the keys.
They will be created upon first use otherwise pub fn force_init_keys() { LazyLock::force(&KEYS); } /// Hashes the bytes with Scrypt with the given salt #[must_use] fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] { let mut hash = [0; HASH_LENGTH]; scrypt::scrypt(bytes, salt, &PARAMS, &mut hash).unwrap(); hash } /// Verifieble scrypt hashed bytes pub struct HashedBytes { pub hash: [u8; HASH_LENGTH], pub salt: [u8; SALT_LENGTH], } impl HashedBytes { /// Hashes the bytes #[must_use] pub fn hash_bytes(bytes: &[u8]) -> Self { let mut salt = [0; 64]; OsRng.fill_bytes(&mut salt); Self { hash: hash_scrypt(bytes, &salt), salt, } } /// Parses the bytes where the first `HASH_LENGTH` bytes are the hash and the latter `SALT_LENGTH` bytes are the salt pub fn from_bytes(bytes: &[u8]) -> Result { let (hash, salt) = bytes.split_at(HASH_LENGTH); let result = Self { hash: hash.try_into()?, salt: salt.try_into()?, }; Ok(result) } #[must_use] pub fn verify(&self, bytes: &[u8]) -> bool { let hash = hash_scrypt(bytes, self.salt.as_ref()); hash.ct_eq(self.hash.as_ref()).into() } pub fn as_bytes(&self) -> Vec { let mut result = Vec::with_capacity(self.hash.len() + self.salt.len()); result.extend_from_slice(&self.hash); result.extend_from_slice(&self.salt); result } } pub async fn authenticate_user( username: &str, password: &str, pool: &Pool, ) -> anyhow::Result> { let Some((user_id, hash)) = db::users::get_hash(username, pool).await? 
else { return Ok(None); }; let hash = HashedBytes::from_bytes(&hash)?; Ok(hash.verify(password.as_bytes()).then_some(user_id)) } #[derive(Debug, Serialize)] pub struct Token { access_token: String, token_type: &'static str, } #[derive(Serialize, Deserialize, Debug)] pub struct Claims { pub user_id: i32, pub exp: i64, } impl Claims { pub fn encode(self) -> Result { let access_token = encode( &Header::new(jsonwebtoken::Algorithm::HS256), &self, &KEYS.encoding_key, ) .map_err(|_| Error::TokenCreation)?; let token = Token { access_token, token_type: "Bearer", }; Ok(token) } } #[derive(Debug)] pub enum Error { WrongCredentials, TokenCreation, InvalidToken, } impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { let (status, error_message) = match self { Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), }; (status, error_message).into_response() } } #[axum::async_trait] impl FromRequestParts for Claims { type Rejection = Error; async fn from_request_parts(parts: &mut Parts, _state: &T) -> Result { let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| Error::InvalidToken)?; // Decode the user data let token_data = decode::(bearer.token(), &KEYS.decoding_key, &Validation::default()) .map_err(|_| Error::InvalidToken)?; Ok(token_data.claims) } }