This repository has been archived on 2024-08-23. You can view files and clone it, but cannot push or open issues or pull requests.
project/src/auth.rs
2024-08-06 16:44:49 +03:00

205 lines
5.5 KiB
Rust

use std::{array::TryFromSliceError, sync::LazyLock};
use axum::{
extract::{FromRef, FromRequestParts},
http::request::Parts,
RequestPartsExt,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use chrono::{TimeDelta, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use crate::prelude::*;
/// Size in bytes of every scrypt digest produced by this module.
pub const HASH_LENGTH: usize = 64;
/// Size in bytes of the salt stored next to each digest.
pub const SALT_LENGTH: usize = 64;
/// Shared scrypt cost parameters: log2(N) = 14, r = 8, p = 1, output length `HASH_LENGTH`.
static PARAMS: LazyLock<scrypt::Params> = LazyLock::new(|| {
    let (log_n, block_size, parallelism) = (14, 8, 1);
    scrypt::Params::new(log_n, block_size, parallelism, HASH_LENGTH).unwrap()
});
/// Process-wide JWT signing/verification keys derived from the `JWT_SECRET` environment
/// variable.
///
/// Built lazily on first access, or eagerly via `force_init_keys`.
///
/// # Panics
/// Panics on first access if `JWT_SECRET` is unset, is not valid Unicode, or is empty.
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
    let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    // An empty HMAC key is trivially guessable and would let anyone forge valid tokens, so a
    // set-but-empty variable is treated as a misconfiguration rather than accepted silently.
    // NOTE(review): RFC 7518 §3.2 requires an HS256 key of at least 256 bits; consider also
    // enforcing `secret.len() >= 32` once existing deployments have been checked.
    assert!(!secret.is_empty(), "JWT_SECRET must not be empty");
    Keys::from_secret(secret.as_bytes())
});
/// Signing and verification keys for JWTs, both derived from one shared secret.
struct Keys {
    encoding_key: EncodingKey,
    decoding_key: DecodingKey,
}

impl Keys {
    /// Builds the signing (`encoding_key`) and verification (`decoding_key`) keys from the same
    /// raw secret bytes.
    fn from_secret(secret: &[u8]) -> Self {
        let encoding_key = EncodingKey::from_secret(secret);
        let decoding_key = DecodingKey::from_secret(secret);
        Self {
            encoding_key,
            decoding_key,
        }
    }
}
/// Forces the evaluation of the keys. They will be created upon first use otherwise
pub fn force_init_keys() {
LazyLock::force(&KEYS);
}
/// Computes the scrypt digest of `bytes` salted with `salt`, using the shared `PARAMS`.
///
/// # Panics
/// Panics if `scrypt::scrypt` returns an error for the `HASH_LENGTH`-byte output buffer.
#[must_use]
fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] {
    let mut digest = [0_u8; HASH_LENGTH];
    let outcome = scrypt::scrypt(bytes, salt, &PARAMS, &mut digest);
    outcome.unwrap();
    digest
}
/// A scrypt hash together with the salt it was computed with, so that a candidate input can
/// later be checked against it with `HashedBytes::verify`.
// `==` is derived only under `cfg(test)`: the derived comparison is not guaranteed to be
// constant-time, so production code must compare through `verify` instead.
#[cfg_attr(test, derive(PartialEq))]
pub struct HashedBytes {
    /// Scrypt digest of the original input.
    pub hash: [u8; HASH_LENGTH],
    /// Salt that was mixed into `hash`.
    pub salt: [u8; SALT_LENGTH],
}
impl HashedBytes {
    /// Hashes `bytes` with a freshly generated `SALT_LENGTH`-byte salt drawn from `OsRng`.
    #[must_use]
    pub fn hash_bytes(bytes: &[u8]) -> Self {
        let mut salt = [0; SALT_LENGTH];
        OsRng.fill_bytes(&mut salt);
        Self {
            hash: hash_scrypt(bytes, &salt),
            salt,
        }
    }

    /// Parses the layout produced by `as_bytes`: the first `HASH_LENGTH` bytes are the hash and
    /// the following `SALT_LENGTH` bytes are the salt.
    ///
    /// # Errors
    /// Returns [`TryFromSliceError`] unless `bytes` is exactly `HASH_LENGTH + SALT_LENGTH`
    /// bytes long. Inputs that are too short return an error instead of panicking.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TryFromSliceError> {
        // `split_at` panics when the index is past the end, so clamp it. A too-short input then
        // yields a short `hash` slice, and its array conversion below fails with an error.
        let split = HASH_LENGTH.min(bytes.len());
        let (hash, salt) = bytes.split_at(split);
        Ok(Self {
            hash: hash.try_into()?,
            salt: salt.try_into()?,
        })
    }

    /// Returns `true` if hashing `bytes` with the stored salt reproduces the stored hash.
    ///
    /// The final comparison uses `subtle`'s constant-time equality, so the result does not leak
    /// how much of the digest matched through timing.
    #[must_use]
    pub fn verify(&self, bytes: &[u8]) -> bool {
        let hash = hash_scrypt(bytes, self.salt.as_ref());
        hash.ct_eq(self.hash.as_ref()).into()
    }

    /// Serializes as `hash` followed by `salt`, the layout accepted by `from_bytes`.
    #[must_use]
    pub fn as_bytes(&self) -> Vec<u8> {
        let mut result = Vec::with_capacity(self.hash.len() + self.salt.len());
        result.extend_from_slice(&self.hash);
        result.extend_from_slice(&self.salt);
        result
    }
}
/// Checks a `username`/`password` pair against the stored scrypt hash.
///
/// Returns `Ok(Some(user_id))` when the password matches. Returns `Ok(None)` both when the
/// user does not exist and when the password is wrong.
///
/// # Errors
/// Propagates errors from `db::users::get_hash`. Also fails if the stored hash bytes cannot be
/// parsed by `HashedBytes::from_bytes`.
pub async fn authenticate_user(
    username: &str,
    password: &str,
    pool: &Pool,
) -> anyhow::Result<Option<i32>> {
    // NOTE(review): scrypt is CPU- and memory-heavy and runs directly inside this async fn;
    // consider offloading it to a blocking thread so it does not stall the executor.
    let Some((user_id, hash)) = db::users::get_hash(username, pool).await? else {
        // Do the same single scrypt computation a real `verify` would, so response time does
        // not reveal whether the username exists (username enumeration). `black_box` keeps the
        // otherwise-unused work from being optimized away.
        let _ = std::hint::black_box(HashedBytes::hash_bytes(password.as_bytes()));
        return Ok(None);
    };
    let hash = HashedBytes::from_bytes(&hash)?;
    Ok(hash.verify(password.as_bytes()).then_some(user_id))
}
/// Serializable token payload produced by `Claims::encode`.
#[derive(Serialize, Debug)]
pub struct Token {
    /// The encoded JWT.
    access_token: String,
    /// Authorization scheme for `access_token`; `Claims::encode` sets it to `"Bearer"`.
    token_type: &'static str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Claims {
pub user_id: i32,
pub exp: i64,
}
const JWT_ALGORITHM: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::HS256;
impl Claims {
pub fn new(user_id: i32) -> Self {
Self {
user_id,
exp: (Utc::now() + TimeDelta::days(30)).timestamp(),
}
}
pub fn encode(self) -> Result<Token, GeneralError> {
let access_token = encode(&Header::new(JWT_ALGORITHM), &self, &KEYS.encoding_key)
.handle_internal("Token creation error")?;
let token = Token {
access_token,
token_type: "Bearer",
};
Ok(token)
}
}
/// Extracts and validates the bearer token of an incoming request.
///
/// The `Authorization: Bearer <token>` header is decoded with the shared keys and
/// `JWT_ALGORITHM`. The embedded `user_id` must still exist according to `db::users::exists`.
/// Rejections:
/// - `401 "Invalid token"` when the header is missing/malformed or the JWT fails to decode or
///   validate;
/// - `401 "Wrong credentials"` when the token decodes but its user does not exist;
/// - the error produced by `handle_internal` when the existence check itself fails.
#[axum::async_trait]
impl<T> FromRequestParts<T> for Claims
where
    Pool: FromRef<T>,
    T: Sync,
{
    type Rejection = GeneralError;

    async fn from_request_parts(parts: &mut Parts, state: &T) -> Result<Self, Self::Rejection> {
        const INVALID_TOKEN: GeneralError =
            GeneralError::const_message(StatusCode::UNAUTHORIZED, "Invalid token");
        const WRONG_CREDENTIALS: GeneralError =
            GeneralError::const_message(StatusCode::UNAUTHORIZED, "Wrong credentials");

        let pool = Pool::from_ref(state);

        let Ok(TypedHeader(Authorization(bearer))) =
            parts.extract::<TypedHeader<Authorization<Bearer>>>().await
        else {
            return Err(INVALID_TOKEN);
        };

        let validation = Validation::new(JWT_ALGORITHM);
        let Ok(token_data) = decode::<Claims>(bearer.token(), &KEYS.decoding_key, &validation)
        else {
            return Err(INVALID_TOKEN);
        };
        let claims = token_data.claims;

        let user_exists = db::users::exists(claims.user_id, &pool)
            .await
            .handle_internal("Token validation error")?;

        if user_exists {
            Ok(claims)
        } else {
            Err(WRONG_CREDENTIALS)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{HashedBytes, HASH_LENGTH, SALT_LENGTH};

    const PASSWORD: &str = "Password12313#!#4)$*!#";

    /// Serializing with `as_bytes` and parsing back with `from_bytes` must round-trip.
    #[test]
    fn test_hash_conversion() {
        let bytes = HashedBytes::hash_bytes(PASSWORD.as_bytes());
        let bytes2 = HashedBytes::from_bytes(&bytes.as_bytes()).unwrap();
        assert!(bytes == bytes2);
    }

    /// A hash must verify against the input it was created from.
    #[test]
    fn test_hash() {
        assert!(HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(PASSWORD.as_bytes()));
    }

    /// A hash must not verify against a different input.
    #[test]
    fn test_different_hash() {
        assert!(!HashedBytes::hash_bytes(PASSWORD.as_bytes()).verify(b"Different Password"));
    }

    /// `from_bytes` accepts exactly `HASH_LENGTH + SALT_LENGTH` bytes and rejects other lengths
    /// with an error rather than a panic.
    #[test]
    fn test_from_bytes_rejects_wrong_length() {
        assert!(HashedBytes::from_bytes(&[0; HASH_LENGTH + SALT_LENGTH]).is_ok());
        // Hash present but the salt is missing entirely.
        assert!(HashedBytes::from_bytes(&[0; HASH_LENGTH]).is_err());
        // One trailing byte too many.
        assert!(HashedBytes::from_bytes(&[0; HASH_LENGTH + SALT_LENGTH + 1]).is_err());
    }
}