211 lines
5.8 KiB
Rust
211 lines
5.8 KiB
Rust
use std::{array::TryFromSliceError, sync::LazyLock};
|
|
|
|
use axum::{
|
|
extract::{FromRef, FromRequestParts},
|
|
http::{request::Parts, StatusCode},
|
|
response::IntoResponse,
|
|
RequestPartsExt,
|
|
};
|
|
use axum_extra::{
|
|
headers::{authorization::Bearer, Authorization},
|
|
TypedHeader,
|
|
};
|
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
|
use rand::{rngs::OsRng, RngCore};
|
|
use serde::{Deserialize, Serialize};
|
|
use subtle::ConstantTimeEq;
|
|
|
|
use crate::{db, errors::handle_error, Pool};
|
|
|
|
pub const HASH_LENGTH: usize = 64;
|
|
pub const SALT_LENGTH: usize = 64;
|
|
|
|
static PARAMS: LazyLock<scrypt::Params> =
|
|
LazyLock::new(|| scrypt::Params::new(14, 8, 1, HASH_LENGTH).unwrap());
|
|
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
|
|
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
|
Keys::from_secret(secret.as_bytes())
|
|
});
|
|
|
|
struct Keys {
|
|
encoding_key: EncodingKey,
|
|
decoding_key: DecodingKey,
|
|
}
|
|
|
|
impl Keys {
|
|
fn from_secret(secret: &[u8]) -> Self {
|
|
Self {
|
|
encoding_key: EncodingKey::from_secret(secret),
|
|
decoding_key: DecodingKey::from_secret(secret),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Forces the evaluation of the keys. They will be created upon first use otherwise
|
|
pub fn force_init_keys() {
|
|
LazyLock::force(&KEYS);
|
|
}
|
|
|
|
/// Hashes the bytes using Scrypt with the given salt
|
|
#[must_use]
|
|
fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] {
|
|
let mut hash = [0; HASH_LENGTH];
|
|
scrypt::scrypt(bytes, salt, &PARAMS, &mut hash).unwrap();
|
|
hash
|
|
}
|
|
|
|
/// Verifieble scrypt hashed bytes
|
|
#[cfg_attr(test, derive(PartialEq))]
|
|
pub struct HashedBytes {
|
|
pub hash: [u8; HASH_LENGTH],
|
|
pub salt: [u8; SALT_LENGTH],
|
|
}
|
|
|
|
impl HashedBytes {
|
|
/// Hashes the bytes
|
|
#[must_use]
|
|
pub fn hash_bytes(bytes: &[u8]) -> Self {
|
|
let mut salt = [0; 64];
|
|
OsRng.fill_bytes(&mut salt);
|
|
Self {
|
|
hash: hash_scrypt(bytes, &salt),
|
|
salt,
|
|
}
|
|
}
|
|
|
|
/// Parses the bytes where the first `HASH_LENGTH` bytes are the hash and the latter `SALT_LENGTH` bytes are the salt
|
|
pub fn from_bytes(bytes: &[u8]) -> Result<Self, TryFromSliceError> {
|
|
let (hash, salt) = bytes.split_at(HASH_LENGTH);
|
|
let result = Self {
|
|
hash: hash.try_into()?,
|
|
salt: salt.try_into()?,
|
|
};
|
|
Ok(result)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn verify(&self, bytes: &[u8]) -> bool {
|
|
let hash = hash_scrypt(bytes, self.salt.as_ref());
|
|
hash.ct_eq(self.hash.as_ref()).into()
|
|
}
|
|
|
|
pub fn as_bytes(&self) -> Vec<u8> {
|
|
let mut result = Vec::with_capacity(self.hash.len() + self.salt.len());
|
|
result.extend_from_slice(&self.hash);
|
|
result.extend_from_slice(&self.salt);
|
|
result
|
|
}
|
|
}
|
|
|
|
pub async fn authenticate_user(
|
|
username: &str,
|
|
password: &str,
|
|
pool: &Pool,
|
|
) -> anyhow::Result<Option<i32>> {
|
|
let Some((user_id, hash)) = db::users::get_hash(username, pool).await? else {
|
|
return Ok(None);
|
|
};
|
|
let hash = HashedBytes::from_bytes(&hash)?;
|
|
Ok(hash.verify(password.as_bytes()).then_some(user_id))
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct Token {
|
|
access_token: String,
|
|
token_type: &'static str,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
pub struct Claims {
|
|
pub user_id: i32,
|
|
pub exp: i64,
|
|
}
|
|
|
|
impl Claims {
|
|
pub fn encode(self) -> Result<Token, Error> {
|
|
let access_token = encode(
|
|
&Header::new(jsonwebtoken::Algorithm::HS256),
|
|
&self,
|
|
&KEYS.encoding_key,
|
|
)
|
|
.map_err(|_| Error::TokenCreation)?;
|
|
let token = Token {
|
|
access_token,
|
|
token_type: "Bearer",
|
|
};
|
|
Ok(token)
|
|
}
|
|
}
|
|
|
|
/// Failures while issuing or validating authentication tokens.
///
/// Each variant becomes an HTTP status plus a short plain-text message through
/// this type's `IntoResponse` implementation.
#[derive(Debug)]
pub enum Error {
    /// The credentials (or the token's subject) do not match a known user.
    WrongCredentials,
    /// Signing a new JWT failed.
    TokenCreation,
    /// The token's user could not be checked, e.g. because the database query failed.
    Validation,
    /// The bearer token was missing, malformed, or failed JWT validation.
    InvalidToken,
}
|
|
|
|
impl IntoResponse for Error {
|
|
fn into_response(self) -> axum::response::Response {
|
|
let (status, error_message) = match self {
|
|
Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
|
|
Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
|
|
Error::Validation => (StatusCode::INTERNAL_SERVER_ERROR, "Token validation error"),
|
|
Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
|
|
};
|
|
(status, error_message).into_response()
|
|
}
|
|
}
|
|
|
|
#[axum::async_trait]
|
|
impl<T> FromRequestParts<T> for Claims
|
|
where
|
|
Pool: FromRef<T>,
|
|
T: Sync,
|
|
{
|
|
type Rejection = Error;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &T) -> Result<Self, Self::Rejection> {
|
|
let pool = Pool::from_ref(state);
|
|
let TypedHeader(Authorization(bearer)) = parts
|
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
|
.await
|
|
.map_err(|_| Error::InvalidToken)?;
|
|
let claims: Claims = decode(bearer.token(), &KEYS.decoding_key, &Validation::default())
|
|
.map_err(|_| Error::InvalidToken)?
|
|
.claims;
|
|
match db::users::exists(claims.user_id, &pool).await {
|
|
Ok(true) => Ok(claims),
|
|
Ok(false) => Err(Error::WrongCredentials),
|
|
Err(err) => {
|
|
handle_error(err);
|
|
Err(Error::Validation)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::HashedBytes;

    /// Sample password with punctuation, shared by every test below.
    const PASSWORD: &str = "Password12313#!#4)$*!#";

    /// Hashes the fixture password with a fresh salt.
    fn hashed_password() -> HashedBytes {
        HashedBytes::hash_bytes(PASSWORD.as_bytes())
    }

    #[test]
    fn test_hash_conversion() {
        // Serializing then re-parsing must reproduce the same hash and salt.
        let original = hashed_password();
        let round_tripped = HashedBytes::from_bytes(&original.as_bytes()).unwrap();
        assert!(original == round_tripped);
    }

    #[test]
    fn test_hash() {
        let hashed = hashed_password();
        assert!(hashed.verify(PASSWORD.as_bytes()));
    }

    #[test]
    fn test_different_hash() {
        let hashed = hashed_password();
        let accepted = hashed.verify(b"Different Password");
        assert!(!accepted);
    }
}
|