This commit is contained in:
2024-08-03 15:49:40 +03:00
parent f6ed06de48
commit 9217ae46cb
15 changed files with 350 additions and 549 deletions

View File

@ -1,26 +1,173 @@
use std::{array::TryFromSliceError, sync::LazyLock};
use axum::{
extract::{FromRequestParts, Query},
extract::FromRequestParts,
http::{request::Parts, StatusCode},
response::IntoResponse,
RequestPartsExt,
};
use serde::Deserialize;
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
#[derive(Deserialize, Debug)]
use crate::{db, Pool};
pub const HASH_LENGTH: usize = 64;
pub const SALT_LENGTH: usize = 64;
static PARAMS: LazyLock<scrypt::Params> =
LazyLock::new(|| scrypt::Params::new(14, 8, 1, HASH_LENGTH).unwrap());
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
Keys::from_secret(secret.as_bytes())
});
struct Keys {
encoding_key: EncodingKey,
decoding_key: DecodingKey,
}
impl Keys {
fn from_secret(secret: &[u8]) -> Self {
Self {
encoding_key: EncodingKey::from_secret(secret),
decoding_key: DecodingKey::from_secret(secret),
}
}
}
/// Forces the evaluation of the keys. They will be created upon first use otherwise
pub fn force_init_keys() {
LazyLock::force(&KEYS);
}
/// Hashes the bytes with Scrypt with the given salt
#[must_use]
fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] {
let mut hash = [0; HASH_LENGTH];
scrypt::scrypt(bytes, salt, &PARAMS, &mut hash).unwrap();
hash
}
/// Verifieble scrypt hashed bytes
pub struct HashedBytes {
pub hash: [u8; HASH_LENGTH],
pub salt: [u8; SALT_LENGTH],
}
impl HashedBytes {
/// Hashes the bytes
#[must_use]
pub fn hash_bytes(bytes: &[u8]) -> Self {
let mut salt = [0; 64];
OsRng.fill_bytes(&mut salt);
Self {
hash: hash_scrypt(bytes, &salt),
salt,
}
}
/// Parses the bytes where the first `HASH_LENGTH` bytes are the hash and the latter `SALT_LENGTH` bytes are the salt
pub fn from_bytes(bytes: &[u8]) -> Result<Self, TryFromSliceError> {
let (hash, salt) = bytes.split_at(HASH_LENGTH);
let result = Self {
hash: hash.try_into()?,
salt: salt.try_into()?,
};
Ok(result)
}
#[must_use]
pub fn verify(&self, bytes: &[u8]) -> bool {
let hash = hash_scrypt(bytes, self.salt.as_ref());
hash.ct_eq(self.hash.as_ref()).into()
}
pub fn as_bytes(&self) -> Vec<u8> {
let mut result = Vec::with_capacity(self.hash.len() + self.salt.len());
result.extend_from_slice(&self.hash);
result.extend_from_slice(&self.salt);
result
}
}
pub async fn authenticate_user(
username: &str,
password: &str,
pool: &Pool,
) -> anyhow::Result<Option<i32>> {
let Some((user_id, hash)) = db::users::get_hash(username, pool).await? else {
return Ok(None);
};
let hash = HashedBytes::from_bytes(&hash)?;
Ok(hash.verify(password.as_bytes()).then_some(user_id))
}
#[derive(Debug, Serialize)]
pub struct Token {
access_token: String,
token_type: &'static str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Claims {
pub user_id: i32,
pub exp: i64,
}
impl Claims {
pub fn encode(self) -> Result<Token, Error> {
let access_token = encode(
&Header::new(jsonwebtoken::Algorithm::HS256),
&self,
&KEYS.encoding_key,
)
.map_err(|_| Error::TokenCreation)?;
let token = Token {
access_token,
token_type: "Bearer",
};
Ok(token)
}
}
#[derive(Debug)]
pub enum Error {
WrongCredentials,
TokenCreation,
InvalidToken,
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
let (status, error_message) = match self {
Error::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
Error::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
Error::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
(status, error_message).into_response()
}
}
#[axum::async_trait]
impl<T> FromRequestParts<T> for Claims {
type Rejection = StatusCode;
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &T) -> Result<Self, Self::Rejection> {
match parts.extract().await {
Ok(Query(claims)) => Ok(claims),
Err(err) => {
tracing::debug!(%err, "Autharization failed");
Err(StatusCode::UNAUTHORIZED)
}
}
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| Error::InvalidToken)?;
// Decode the user data
let token_data =
decode::<Claims>(bearer.token(), &KEYS.decoding_key, &Validation::default())
.map_err(|_| Error::InvalidToken)?;
Ok(token_data.claims)
}
}

View File

@ -1,3 +1,4 @@
use futures::Stream;
use uuid::Uuid;
use crate::Pool;
@ -41,10 +42,12 @@ pub struct FileWithoutParentId {
pub updated_at: chrono::NaiveDateTime,
}
pub async fn get_files(folder_id: Uuid, pool: &Pool) -> sqlx::Result<Vec<FileWithoutParentId>> {
pub fn get_files(
folder_id: Uuid,
pool: &Pool,
) -> impl Stream<Item = sqlx::Result<FileWithoutParentId>> + '_ {
sqlx::query_as!(FileWithoutParentId, r#"SELECT file_id, file_name, file_size, encode(sha512, 'base64') as "sha512!", created_at, updated_at FROM files WHERE folder_id = $1"#, folder_id)
.fetch_all(pool)
.await
.fetch(pool)
}
async fn get_folder_id(file_id: Uuid, pool: &Pool) -> sqlx::Result<Option<Uuid>> {

View File

@ -5,11 +5,17 @@ use uuid::Uuid;
use crate::Pool;
/// Creates user and returns its id
pub async fn create_user(user_name: &str, user_email: &str, pool: &Pool) -> sqlx::Result<i32> {
pub async fn create_user(
user_name: &str,
user_email: &str,
hashed_password: &[u8],
pool: &Pool,
) -> sqlx::Result<i32> {
let id = sqlx::query!(
"INSERT INTO users(username, email) VALUES ($1, $2) RETURNING user_id",
"INSERT INTO users(username, email, hashed_password) VALUES ($1, $2, $3) RETURNING user_id",
user_name,
user_email
user_email,
hashed_password
)
.fetch_one(pool)
.await?
@ -46,7 +52,7 @@ pub async fn update(
) -> sqlx::Result<UserInfo> {
sqlx::query_as!(
UserInfo,
"UPDATE users SET username = $2, email = $3 WHERE user_id = $1 RETURNING *",
"UPDATE users SET username = $2, email = $3 WHERE user_id = $1 RETURNING user_id, username, email",
user_id,
username,
email
@ -65,6 +71,17 @@ pub async fn get(user_id: i32, pool: &Pool) -> sqlx::Result<UserInfo> {
.await
}
/// Gets the hashed password field by either the email or th username
pub async fn get_hash(search_string: &str, pool: &Pool) -> sqlx::Result<Option<(i32, Vec<u8>)>> {
let record = sqlx::query!(
"SELECT user_id, hashed_password FROM users WHERE username = $1 OR email = $1",
search_string
)
.fetch_optional(pool)
.await?;
Ok(record.map(|record| (record.user_id, record.hashed_password)))
}
pub fn search_for_user<'a>(
search_string: &str,
pool: &'a Pool,

View File

@ -0,0 +1,34 @@
use chrono::TimeDelta;
use crate::{
auth::{authenticate_user, Error, Token},
prelude::*,
};
#[derive(Deserialize, Debug)]
pub struct Params {
username: String,
password: String,
}
fn get_exp() -> i64 {
let mut time = chrono::Utc::now();
time += TimeDelta::minutes(30);
time.timestamp()
}
pub async fn post(
State(state): State<AppState>,
Json(payload): Json<Params>,
) -> Result<Json<Token>, Error> {
let user_id = authenticate_user(&payload.username, &payload.password, &state.pool)
.await
.map_err(|_| Error::WrongCredentials)?
.ok_or(Error::WrongCredentials)?;
Claims {
user_id,
exp: get_exp(),
}
.encode()
.map(Json)
}

View File

@ -0,0 +1 @@
pub mod auth_post;

View File

@ -45,10 +45,10 @@ pub async fn structure(
folder_id,
structure: folder.into(),
};
let mut stack: Vec<&mut FolderStructure> = vec![&mut response.structure];
let mut stack = vec![&mut response.structure];
while let Some(folder) = stack.pop() {
let (files, folders) = try_join!(
db::file::get_files(folder_id, &pool),
db::file::get_files(folder_id, &pool).try_collect(),
db::folder::get_folders(folder_id, claims.user_id, &pool)
.map_ok(Into::into)
.try_collect()

View File

@ -26,7 +26,7 @@ pub async fn list(
.ok_or(StatusCode::NOT_FOUND)?;
let (files, folders) = try_join!(
db::file::get_files(folder_id, &pool),
db::file::get_files(folder_id, &pool).try_collect(),
db::folder::get_folders(folder_id, claims.user_id, &pool).try_collect()
)
.handle_internal()?;

View File

@ -1,3 +1,4 @@
pub mod authorization;
pub mod file;
pub mod folder;
pub mod permissions;

View File

@ -1,17 +0,0 @@
use crate::prelude::*;
#[derive(Deserialize, Debug)]
pub struct Params {
username: String,
email: String,
}
pub async fn create(
State(pool): State<Pool>,
Json(params): Json<Params>,
) -> Result<Json<i32>, StatusCode> {
let id = db::users::create_user(&params.username, &params.email, &pool)
.await
.handle_internal()?;
Ok(Json(id))
}

View File

@ -1,4 +1,3 @@
pub mod create;
pub mod delete;
pub mod get;
pub mod put;

View File

@ -7,7 +7,8 @@ mod prelude;
use std::{env, net::Ipv4Addr};
use axum::{extract::FromRef, Router};
use auth::HashedBytes;
use axum::{extract::FromRef, routing::post, Router};
use file_storage::FileStorage;
use tokio::net::TcpListener;
@ -28,10 +29,12 @@ async fn create_test_users(pool: &Pool) -> anyhow::Result<()> {
if count > 0 {
return Ok(());
}
let hash1 = HashedBytes::hash_bytes(b"Password1").as_bytes();
let hash2 = HashedBytes::hash_bytes(b"Password2").as_bytes();
tokio::try_join!(
db::users::create_user("Test1", "test1@example.com", pool),
db::users::create_user("Test2", "test2@example.com", pool)
db::users::create_user("Test1", "test1@example.com", &hash1, pool),
db::users::create_user("Test2", "test2@example.com", &hash2, pool)
)?;
Ok(())
@ -44,6 +47,8 @@ async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
auth::force_init_keys();
let pool = match env::var("DATABASE_URL") {
Ok(url) => Pool::connect(&url).await?,
Err(err) => anyhow::bail!("Error getting database url: {err}"),
@ -70,7 +75,7 @@ async fn main() -> anyhow::Result<()> {
fn app(state: AppState) -> Router {
use axum::{http::header, routing::get};
use endpoints::{
file, folder,
authorization, file, folder,
permissions::{self, get_top_level::get_top_level},
users,
};
@ -112,11 +117,11 @@ fn app(state: AppState) -> Router {
.route(
"/users",
get(users::get::get)
.post(users::create::create)
.delete(users::delete::delete)
.put(users::put::put),
)
.route("/users/search", get(users::search::search))
.route("/authorize", post(authorization::auth_post::post))
.layer(middleware)
.with_state(state)
}