2024-08-03 12:49:40 +00:00
|
|
|
use std::{array::TryFromSliceError, sync::LazyLock};
|
|
|
|
|
2024-06-27 12:04:57 +00:00
|
|
|
use axum::{
|
2024-08-03 17:15:08 +00:00
|
|
|
extract::{FromRef, FromRequestParts},
|
2024-08-06 13:00:38 +00:00
|
|
|
http::request::Parts,
|
2024-06-27 12:04:57 +00:00
|
|
|
RequestPartsExt,
|
|
|
|
};
|
2024-08-03 12:49:40 +00:00
|
|
|
use axum_extra::{
|
|
|
|
headers::{authorization::Bearer, Authorization},
|
|
|
|
TypedHeader,
|
|
|
|
};
|
2024-08-04 06:48:41 +00:00
|
|
|
use chrono::{TimeDelta, Utc};
|
2024-08-03 12:49:40 +00:00
|
|
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
|
|
|
use rand::{rngs::OsRng, RngCore};
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use subtle::ConstantTimeEq;
|
|
|
|
|
2024-08-06 13:00:38 +00:00
|
|
|
use crate::prelude::*;
|
2024-08-03 12:49:40 +00:00
|
|
|
|
|
|
|
/// Length in bytes of the scrypt output produced by `hash_scrypt` and stored in `HashedBytes::hash`.
pub const HASH_LENGTH: usize = 64;
/// Length in bytes of the random salt stored in `HashedBytes::salt`.
pub const SALT_LENGTH: usize = 64;
|
|
|
|
|
|
|
|
/// scrypt cost parameters: log2(N) = 14, r = 8, p = 1, with a `HASH_LENGTH`-byte output.
///
/// These parameters are NOT part of the serialized hash (`HashedBytes::as_bytes` emits only
/// hash and salt), so changing them makes every previously stored hash fail verification.
static PARAMS: LazyLock<scrypt::Params> = LazyLock::new(|| {
    scrypt::Params::new(14, 8, 1, HASH_LENGTH).expect("hard-coded scrypt parameters are valid")
});
|
|
|
|
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
|
|
|
|
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
|
|
|
Keys::from_secret(secret.as_bytes())
|
|
|
|
});
|
|
|
|
|
|
|
|
struct Keys {
|
|
|
|
encoding_key: EncodingKey,
|
|
|
|
decoding_key: DecodingKey,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Keys {
|
|
|
|
fn from_secret(secret: &[u8]) -> Self {
|
|
|
|
Self {
|
|
|
|
encoding_key: EncodingKey::from_secret(secret),
|
|
|
|
decoding_key: DecodingKey::from_secret(secret),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-06-27 12:04:57 +00:00
|
|
|
|
2024-08-03 12:49:40 +00:00
|
|
|
/// Forces the evaluation of the keys. They will be created upon first use otherwise
|
|
|
|
pub fn force_init_keys() {
|
|
|
|
LazyLock::force(&KEYS);
|
|
|
|
}
|
|
|
|
|
2024-08-03 13:44:34 +00:00
|
|
|
/// Hashes the bytes using Scrypt with the given salt
|
2024-08-03 12:49:40 +00:00
|
|
|
#[must_use]
|
|
|
|
fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] {
|
|
|
|
let mut hash = [0; HASH_LENGTH];
|
|
|
|
scrypt::scrypt(bytes, salt, &PARAMS, &mut hash).unwrap();
|
|
|
|
hash
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Verifieble scrypt hashed bytes
|
2024-08-06 13:00:38 +00:00
|
|
|
#[cfg_attr(test, derive(PartialEq))] // == OPERATOR MUSTN'T BE USED OUTSIZE OF TESTS
|
2024-08-03 12:49:40 +00:00
|
|
|
pub struct HashedBytes {
|
|
|
|
pub hash: [u8; HASH_LENGTH],
|
|
|
|
pub salt: [u8; SALT_LENGTH],
|
|
|
|
}
|
|
|
|
|
|
|
|
impl HashedBytes {
|
|
|
|
/// Hashes the bytes
|
|
|
|
#[must_use]
|
|
|
|
pub fn hash_bytes(bytes: &[u8]) -> Self {
|
|
|
|
let mut salt = [0; 64];
|
|
|
|
OsRng.fill_bytes(&mut salt);
|
|
|
|
Self {
|
|
|
|
hash: hash_scrypt(bytes, &salt),
|
|
|
|
salt,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses the bytes where the first `HASH_LENGTH` bytes are the hash and the latter `SALT_LENGTH` bytes are the salt
|
|
|
|
pub fn from_bytes(bytes: &[u8]) -> Result<Self, TryFromSliceError> {
|
|
|
|
let (hash, salt) = bytes.split_at(HASH_LENGTH);
|
|
|
|
let result = Self {
|
|
|
|
hash: hash.try_into()?,
|
|
|
|
salt: salt.try_into()?,
|
|
|
|
};
|
|
|
|
Ok(result)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[must_use]
|
|
|
|
pub fn verify(&self, bytes: &[u8]) -> bool {
|
|
|
|
let hash = hash_scrypt(bytes, self.salt.as_ref());
|
|
|
|
hash.ct_eq(self.hash.as_ref()).into()
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn as_bytes(&self) -> Vec<u8> {
|
|
|
|
let mut result = Vec::with_capacity(self.hash.len() + self.salt.len());
|
|
|
|
result.extend_from_slice(&self.hash);
|
|
|
|
result.extend_from_slice(&self.salt);
|
|
|
|
result
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn authenticate_user(
|
|
|
|
username: &str,
|
|
|
|
password: &str,
|
|
|
|
pool: &Pool,
|
|
|
|
) -> anyhow::Result<Option<i32>> {
|
|
|
|
let Some((user_id, hash)) = db::users::get_hash(username, pool).await? else {
|
|
|
|
return Ok(None);
|
|
|
|
};
|
|
|
|
let hash = HashedBytes::from_bytes(&hash)?;
|
|
|
|
Ok(hash.verify(password.as_bytes()).then_some(user_id))
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Serializable access token produced by `Claims::encode`.
// NOTE(review): the derived `Debug` prints `access_token` verbatim; avoid logging this value.
#[derive(Debug, Serialize)]
pub struct Token {
    /// The encoded, signed JWT.
    access_token: String,
    /// Token scheme; `Claims::encode` always sets this to "Bearer".
    token_type: &'static str,
}
|
|
|
|
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
2024-06-27 12:04:57 +00:00
|
|
|
pub struct Claims {
|
|
|
|
pub user_id: i32,
|
2024-08-03 12:49:40 +00:00
|
|
|
pub exp: i64,
|
|
|
|
}
|
|
|
|
|
2024-08-06 13:00:38 +00:00
|
|
|
const JWT_ALGORITHM: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::HS256;
|
|
|
|
|
2024-08-03 12:49:40 +00:00
|
|
|
impl Claims {
|
2024-08-04 06:48:41 +00:00
|
|
|
pub fn new(user_id: i32) -> Self {
|
|
|
|
Self {
|
|
|
|
user_id,
|
|
|
|
exp: (Utc::now() + TimeDelta::days(30)).timestamp(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-08-06 13:00:38 +00:00
|
|
|
pub fn encode(self) -> Result<Token, GeneralError> {
|
|
|
|
let access_token = encode(&Header::new(JWT_ALGORITHM), &self, &KEYS.encoding_key)
|
|
|
|
.handle_internal("Token creation error")?;
|
2024-08-03 12:49:40 +00:00
|
|
|
let token = Token {
|
|
|
|
access_token,
|
|
|
|
token_type: "Bearer",
|
|
|
|
};
|
|
|
|
Ok(token)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-27 12:04:57 +00:00
|
|
|
#[axum::async_trait]
|
2024-08-03 17:15:08 +00:00
|
|
|
impl<T> FromRequestParts<T> for Claims
|
|
|
|
where
|
|
|
|
Pool: FromRef<T>,
|
|
|
|
T: Sync,
|
|
|
|
{
|
2024-08-06 13:00:38 +00:00
|
|
|
type Rejection = GeneralError;
|
2024-06-27 12:04:57 +00:00
|
|
|
|
2024-08-03 17:15:08 +00:00
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &T) -> Result<Self, Self::Rejection> {
|
2024-08-06 13:00:38 +00:00
|
|
|
const INVALID_TOKEN: GeneralError =
|
|
|
|
GeneralError::const_message(StatusCode::UNAUTHORIZED, "Invalid token");
|
|
|
|
|
2024-08-03 17:15:08 +00:00
|
|
|
let pool = Pool::from_ref(state);
|
2024-08-03 12:49:40 +00:00
|
|
|
let TypedHeader(Authorization(bearer)) = parts
|
|
|
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
|
|
|
.await
|
2024-08-06 13:00:38 +00:00
|
|
|
.map_err(|_| INVALID_TOKEN)?;
|
|
|
|
|
|
|
|
let claims: Claims = decode(
|
|
|
|
bearer.token(),
|
|
|
|
&KEYS.decoding_key,
|
|
|
|
&Validation::new(JWT_ALGORITHM),
|
|
|
|
)
|
|
|
|
.map_err(|_| INVALID_TOKEN)?
|
|
|
|
.claims;
|
|
|
|
|
|
|
|
db::users::exists(claims.user_id, &pool)
|
|
|
|
.await
|
|
|
|
.handle_internal("Token validation error")?
|
|
|
|
.then_some(claims)
|
|
|
|
.ok_or(GeneralError::const_message(
|
|
|
|
StatusCode::UNAUTHORIZED,
|
|
|
|
"Wrong credentials",
|
|
|
|
))
|
2024-06-27 12:04:57 +00:00
|
|
|
}
|
|
|
|
}
|
2024-08-03 13:44:34 +00:00
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{HashedBytes, HASH_LENGTH, SALT_LENGTH};

    const PASSWORD: &str = "Password12313#!#4)$*!#";

    /// Serializing then re-parsing yields the same hash and salt.
    #[test]
    fn test_hash_conversion() {
        let bytes = HashedBytes::hash_bytes(PASSWORD.as_bytes());
        let serialized = bytes.as_bytes();
        assert_eq!(serialized.len(), HASH_LENGTH + SALT_LENGTH);
        let bytes2 = HashedBytes::from_bytes(&serialized).unwrap();
        assert!(bytes == bytes2);
    }

    /// The correct password verifies against its own hash.
    #[test]
    fn test_hash() {
        assert!(HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(PASSWORD.as_bytes()));
    }

    /// A different password does not verify.
    #[test]
    fn test_different_hash() {
        assert!(!HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(b"Different Password"));
    }

    /// Each hash gets a fresh random salt, so equal passwords produce different hashes.
    #[test]
    fn test_salt_is_random() {
        let first = HashedBytes::hash_bytes(PASSWORD.as_bytes());
        let second = HashedBytes::hash_bytes(PASSWORD.as_bytes());
        assert_ne!(first.salt, second.salt);
        assert_ne!(first.hash, second.hash);
    }

    /// Parsing accepts exactly `HASH_LENGTH + SALT_LENGTH` bytes and rejects trailing data.
    #[test]
    fn test_from_bytes_length() {
        assert!(HashedBytes::from_bytes(&[0; HASH_LENGTH + SALT_LENGTH]).is_ok());
        assert!(HashedBytes::from_bytes(&[0; HASH_LENGTH + SALT_LENGTH + 1]).is_err());
    }
}
|