Made cryptography and entity modules
Coupling was rising so it just makes sense
This commit is contained in:
@ -10,9 +10,11 @@ crate::export_handlers!(
|
||||
change_locale
|
||||
);
|
||||
|
||||
use crate::{errors::InvalidCommand, handle_error, locales::LocaleTypeExt};
|
||||
use crate::{
|
||||
entity::locale::LocaleType, errors::handle_error, errors::InvalidCommand,
|
||||
locales::LocaleTypeExt,
|
||||
};
|
||||
use base64::{engine::general_purpose::STANDARD_NO_PAD as B64_ENGINE, Engine as _};
|
||||
use entity::locale::LocaleType;
|
||||
use std::str::FromStr;
|
||||
use teloxide::types::CallbackQuery;
|
||||
|
||||
|
163
src/cryptography/account.rs
Normal file
163
src/cryptography/account.rs
Normal file
@ -0,0 +1,163 @@
|
||||
use crate::entity::account::Account;
|
||||
use chacha20poly1305::{AeadCore, AeadInPlace, ChaCha20Poly1305, KeyInit};
|
||||
use pbkdf2::pbkdf2_hmac_array;
|
||||
use rand::{rngs::OsRng, RngCore};
|
||||
use sha2::Sha256;
|
||||
|
||||
pub struct Cipher {
|
||||
chacha: ChaCha20Poly1305,
|
||||
}
|
||||
|
||||
impl Cipher {
|
||||
/// Creates a new cipher from a master password and the salt
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn new(password: &[u8], salt: &[u8]) -> Self {
|
||||
let key = pbkdf2_hmac_array::<Sha256, 32>(password, salt, 480_000);
|
||||
|
||||
Self {
|
||||
chacha: ChaCha20Poly1305::new(&key.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypts the value with the current cipher. The 12 byte nonce is appended to the result
|
||||
#[inline]
|
||||
#[allow(clippy::missing_panics_doc)]
|
||||
pub fn encrypt(&self, value: &mut Vec<u8>) {
|
||||
let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
|
||||
self.chacha.encrypt_in_place(&nonce, b"", value).unwrap();
|
||||
value.extend_from_slice(&nonce);
|
||||
}
|
||||
|
||||
/// Decrypts the value with the current cipher. The 12 byte nonce is expected to be at the end of the value
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the tag doesn't match the ciphertext
|
||||
#[inline]
|
||||
#[allow(clippy::missing_panics_doc)]
|
||||
pub fn decrypt(&self, value: &mut Vec<u8>) -> super::Result<()> {
|
||||
if value.len() <= 12 {
|
||||
return Err(super::Error::InvalidInputLength);
|
||||
}
|
||||
let nonce: [u8; 12] = value[value.len() - 12..].try_into().unwrap();
|
||||
value.truncate(value.len() - 12);
|
||||
|
||||
self.chacha
|
||||
.decrypt_in_place(nonce.as_slice().into(), b"", value)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug)]
|
||||
pub struct Decrypted {
|
||||
pub name: String,
|
||||
pub login: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl Decrypted {
|
||||
/// Constructs `DecryptedAccount` by decrypting the provided account
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the tag doesn't match the ciphertext or if the decrypted data isn't valid UTF-8
|
||||
#[inline]
|
||||
pub fn from_account(mut account: Account, master_pass: &str) -> super::Result<Self> {
|
||||
let cipher = Cipher::new(master_pass.as_bytes(), &account.salt);
|
||||
cipher.decrypt(&mut account.enc_login)?;
|
||||
cipher.decrypt(&mut account.enc_password)?;
|
||||
|
||||
Ok(Self {
|
||||
name: account.name,
|
||||
login: String::from_utf8(account.enc_login)?,
|
||||
password: String::from_utf8(account.enc_password)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Constructs `ActiveModel` with eath field Set by encrypting `self`
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn into_account(self, user_id: u64, master_pass: &str) -> Account {
|
||||
let mut enc_login = self.login.into_bytes();
|
||||
let mut enc_password = self.password.into_bytes();
|
||||
let mut salt = vec![0; 64];
|
||||
OsRng.fill_bytes(&mut salt);
|
||||
|
||||
let cipher = Cipher::new(master_pass.as_bytes(), &salt);
|
||||
cipher.encrypt(&mut enc_login);
|
||||
cipher.encrypt(&mut enc_password);
|
||||
|
||||
Account {
|
||||
user_id,
|
||||
name: self.name,
|
||||
salt,
|
||||
enc_login,
|
||||
enc_password,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the account's fields are valid
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn validate(&self) -> bool {
|
||||
[
|
||||
self.name.as_str(),
|
||||
self.login.as_str(),
|
||||
self.password.as_str(),
|
||||
]
|
||||
.into_iter()
|
||||
.all(super::validate_field)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
const TESTING_MASTER_PASSWORD: &str = "VeryStr^n#M@$terP@$$!word";
|
||||
static CIPHER: Lazy<Cipher> = Lazy::new(|| {
|
||||
let mut salt = [0; 64];
|
||||
OsRng.fill_bytes(&mut salt);
|
||||
|
||||
Cipher::new(TESTING_MASTER_PASSWORD.as_bytes(), &salt)
|
||||
});
|
||||
|
||||
#[test]
|
||||
fn cipher_test() -> crate::cryptography::Result<()> {
|
||||
const ORIGINAL: &[u8] = b"Data to protect";
|
||||
let mut data = ORIGINAL.to_owned();
|
||||
|
||||
CIPHER.encrypt(&mut data);
|
||||
CIPHER.decrypt(&mut data)?;
|
||||
|
||||
assert_eq!(ORIGINAL, data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn account_encryption() -> crate::cryptography::Result<()> {
|
||||
let original = Decrypted {
|
||||
name: "Account Name".to_owned(),
|
||||
login: "StrongLogin@mail.com".to_owned(),
|
||||
password: "StrongP@$$word!".to_owned(),
|
||||
};
|
||||
let account = original.clone().into_account(1, TESTING_MASTER_PASSWORD);
|
||||
let decrypted = Decrypted::from_account(account, TESTING_MASTER_PASSWORD)?;
|
||||
|
||||
assert_eq!(original, decrypted);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_invalid_input_length() {
|
||||
let mut bytes = vec![0];
|
||||
|
||||
assert!(matches!(
|
||||
CIPHER.decrypt(&mut bytes),
|
||||
Err(crate::cryptography::Error::InvalidInputLength)
|
||||
));
|
||||
}
|
||||
}
|
93
src/cryptography/hashing.rs
Normal file
93
src/cryptography/hashing.rs
Normal file
@ -0,0 +1,93 @@
|
||||
use crate::entity::master_pass::MasterPass;
|
||||
use once_cell::sync::Lazy;
|
||||
use rand::{rngs::OsRng, RngCore};
|
||||
use scrypt::{scrypt, Params};
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
pub const HASH_LENGTH: usize = 64;
|
||||
pub const SALT_LENGTH: usize = 64;
|
||||
|
||||
static PARAMS: Lazy<Params> = Lazy::new(|| Params::new(14, 8, 1, HASH_LENGTH).unwrap());
|
||||
|
||||
/// Hashes the bytes with Scrypt with the given salt
|
||||
#[inline]
|
||||
#[must_use]
|
||||
#[allow(clippy::missing_panics_doc)]
|
||||
pub fn hash_scrypt(bytes: &[u8], salt: &[u8]) -> [u8; HASH_LENGTH] {
|
||||
let mut hash = [0; HASH_LENGTH];
|
||||
scrypt(bytes, salt, &PARAMS, &mut hash).unwrap();
|
||||
hash
|
||||
}
|
||||
|
||||
/// Verifieble scrypt hashed bytes
|
||||
pub struct HashedBytes<T, U>
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
U: AsRef<[u8]>,
|
||||
{
|
||||
pub hash: T,
|
||||
pub salt: U,
|
||||
}
|
||||
|
||||
impl HashedBytes<[u8; HASH_LENGTH], [u8; SALT_LENGTH]> {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn new(bytes: &[u8]) -> Self {
|
||||
let mut salt = [0; 64];
|
||||
OsRng.fill_bytes(&mut salt);
|
||||
Self {
|
||||
hash: hash_scrypt(bytes, &salt),
|
||||
salt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> HashedBytes<T, U>
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
U: AsRef<[u8]>,
|
||||
{
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn verify(&self, bytes: &[u8]) -> bool {
|
||||
let hash = hash_scrypt(bytes, self.salt.as_ref());
|
||||
hash.ct_eq(self.hash.as_ref()).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a MasterPass> for HashedBytes<&'a [u8], &'a [u8]> {
|
||||
#[inline]
|
||||
fn from(value: &'a MasterPass) -> Self {
|
||||
HashedBytes {
|
||||
hash: &value.password_hash,
|
||||
salt: &value.salt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MasterPass> for HashedBytes<Vec<u8>, Vec<u8>> {
|
||||
fn from(value: MasterPass) -> Self {
|
||||
Self {
|
||||
hash: value.password_hash,
|
||||
salt: value.salt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn params_valid() {
|
||||
#[allow(clippy::no_effect_underscore_binding)]
|
||||
let _params: &Params = &PARAMS; // Initializes the PARAMS, which might panic if the passed in values are invalid
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hashing_test() {
|
||||
const ORIGINAL: &[u8] = b"Important data";
|
||||
|
||||
assert!(HashedBytes::new(ORIGINAL).verify(ORIGINAL));
|
||||
}
|
||||
}
|
31
src/cryptography/mod.rs
Normal file
31
src/cryptography/mod.rs
Normal file
@ -0,0 +1,31 @@
|
||||
//! Functions to encrypt the database models
|
||||
|
||||
pub mod account;
|
||||
pub mod hashing;
|
||||
pub mod passwords;
|
||||
|
||||
/// Returns true if the field is valid
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn validate_field(field: &str) -> bool {
|
||||
if field.len() > 255 {
|
||||
return false;
|
||||
}
|
||||
field
|
||||
.chars()
|
||||
.all(|char| !['`', '\\', '\n', '\t'].contains(&char))
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Invalid input length")]
|
||||
InvalidInputLength,
|
||||
|
||||
#[error(transparent)]
|
||||
ChaCha(#[from] chacha20poly1305::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
InvalidUTF8(#[from] std::string::FromUtf8Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, Error>;
|
112
src/cryptography/passwords.rs
Normal file
112
src/cryptography/passwords.rs
Normal file
@ -0,0 +1,112 @@
|
||||
use arrayvec::ArrayString;
|
||||
use rand::{seq::SliceRandom, thread_rng, CryptoRng, Rng};
|
||||
use std::array;
|
||||
|
||||
const CHARS: &[u8] = br##"!"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]^_abcdefghijklmnopqrstuvwxyz{|}~"##;
|
||||
|
||||
bitflags::bitflags! {
|
||||
struct PasswordFlags: u8 {
|
||||
const LOWERCASE = 0b0001;
|
||||
const UPPERCASE = 0b0010;
|
||||
const NUMBER = 0b0100;
|
||||
const SPECIAL_CHARACTER = 0b1000;
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub struct PasswordValidity: u8 {
|
||||
const NO_LOWERCASE = 0b00001;
|
||||
const NO_UPPERCASE = 0b00010;
|
||||
const NO_NUMBER = 0b00100;
|
||||
const NO_SPECIAL_CHARACTER = 0b01000;
|
||||
const TOO_SHORT = 0b10000;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the generated master password is valid.
|
||||
/// It checks that it has at least one lowercase, one uppercase, one number and one punctuation char
|
||||
#[inline]
|
||||
#[must_use]
|
||||
fn check_generated_password<const LENGTH: usize>(password: &[u8; LENGTH]) -> bool {
|
||||
let mut flags = PasswordFlags::empty();
|
||||
for &byte in password {
|
||||
match byte {
|
||||
b'a'..=b'z' => flags |= PasswordFlags::LOWERCASE,
|
||||
b'A'..=b'Z' => flags |= PasswordFlags::UPPERCASE,
|
||||
b'0'..=b'9' => flags |= PasswordFlags::NUMBER,
|
||||
b'!'..=b'/' | b':'..=b'@' | b'['..=b'`' | b'{'..=b'~' => {
|
||||
flags |= PasswordFlags::SPECIAL_CHARACTER;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
if flags.is_all() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Continuously generates the password until it passes the checks
|
||||
#[inline]
|
||||
#[must_use]
|
||||
fn generate_password<R, const LENGTH: usize>(rng: &mut R) -> ArrayString<LENGTH>
|
||||
where
|
||||
R: Rng + CryptoRng,
|
||||
{
|
||||
loop {
|
||||
let password = array::from_fn(|_| *CHARS.choose(rng).unwrap());
|
||||
if check_generated_password(&password) {
|
||||
return ArrayString::from_byte_string(&password).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub fn generate_passwords<const AMOUNT: usize, const LENGTH: usize>(
|
||||
) -> [ArrayString<LENGTH>; AMOUNT] {
|
||||
let mut rng = thread_rng();
|
||||
array::from_fn(|_| generate_password(&mut rng))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn check_master_pass(password: &str) -> PasswordValidity {
|
||||
let mut count = 0;
|
||||
let mut chars = password.chars();
|
||||
let mut flags = PasswordValidity::all();
|
||||
|
||||
for char in &mut chars {
|
||||
count += 1;
|
||||
if char.is_lowercase() {
|
||||
flags.remove(PasswordValidity::NO_LOWERCASE);
|
||||
} else if char.is_uppercase() {
|
||||
flags.remove(PasswordValidity::NO_UPPERCASE);
|
||||
} else if char.is_ascii_digit() {
|
||||
flags.remove(PasswordValidity::NO_NUMBER);
|
||||
} else if char.is_ascii_punctuation() {
|
||||
flags.remove(PasswordValidity::NO_SPECIAL_CHARACTER);
|
||||
}
|
||||
|
||||
if flags == PasswordValidity::TOO_SHORT {
|
||||
count += chars.count();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if count >= 8 {
|
||||
flags.remove(PasswordValidity::TOO_SHORT);
|
||||
}
|
||||
|
||||
flags
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::CHARS;
|
||||
|
||||
#[test]
|
||||
fn chars_must_be_ascii() {
|
||||
assert!(CHARS.is_ascii());
|
||||
}
|
||||
}
|
180
src/entity/account.rs
Normal file
180
src/entity/account.rs
Normal file
@ -0,0 +1,180 @@
|
||||
use super::Pool;
|
||||
use futures::{Stream, TryStreamExt};
|
||||
use sqlx::{query, query_as, Executor, FromRow, MySql};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, FromRow, Default)]
|
||||
pub struct Account {
|
||||
pub user_id: u64,
|
||||
pub name: String,
|
||||
pub salt: Vec<u8>,
|
||||
pub enc_login: Vec<u8>,
|
||||
pub enc_password: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Account {
|
||||
// Inserts the account into DB
|
||||
#[inline]
|
||||
pub async fn insert(&self, pool: &Pool) -> super::Result<()> {
|
||||
query!(
|
||||
"INSERT INTO account VALUES (?, ?, ?, ?, ?)",
|
||||
self.user_id,
|
||||
self.name,
|
||||
self.salt,
|
||||
self.enc_login,
|
||||
self.enc_password
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
/// Gets all user's account from DB
|
||||
#[inline]
|
||||
pub fn get_all(user_id: u64, pool: &Pool) -> impl Stream<Item = super::Result<Self>> + '_ {
|
||||
query_as("SELECT * FROM account WHERE user_id = ?")
|
||||
.bind(user_id)
|
||||
.fetch(pool)
|
||||
}
|
||||
|
||||
/// Streams the names of the user accounts
|
||||
#[inline]
|
||||
pub fn get_names(user_id: u64, pool: &Pool) -> impl Stream<Item = super::Result<String>> + '_ {
|
||||
query_as::<_, (String,)>("SELECT name FROM account WHERE user_id = ? ORDER BY name")
|
||||
.bind(user_id)
|
||||
.fetch(pool)
|
||||
.map_ok(|(name,)| name)
|
||||
}
|
||||
|
||||
/// Checks if the account exists
|
||||
#[inline]
|
||||
pub async fn exists(user_id: u64, account_name: &str, pool: &Pool) -> super::Result<bool> {
|
||||
query_as::<_, (bool,)>(
|
||||
"SELECT EXISTS(SELECT * FROM account WHERE user_id = ? AND name = ? LIMIT 1) as value",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(account_name)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map(|(exists,)| exists)
|
||||
}
|
||||
|
||||
/// Gets the account from the DB
|
||||
#[inline]
|
||||
pub async fn get(user_id: u64, account_name: &str, pool: &Pool) -> super::Result<Option<Self>> {
|
||||
query_as("SELECT * FROM account WHERE user_id = ? AND name = ?")
|
||||
.bind(user_id)
|
||||
.bind(account_name)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
// Deletes the account from DB
|
||||
#[inline]
|
||||
pub async fn delete(user_id: u64, name: &str, pool: &Pool) -> super::Result<()> {
|
||||
query!(
|
||||
"DELETE FROM account WHERE user_id = ? AND name = ?",
|
||||
user_id,
|
||||
name
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
/// Deletes all the user's accounts from DB
|
||||
#[inline]
|
||||
pub async fn delete_all(
|
||||
user_id: u64,
|
||||
pool: impl Executor<'_, Database = MySql>,
|
||||
) -> super::Result<()> {
|
||||
query!("DELETE FROM account WHERE user_id = ?", user_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
/// Gets a name by a SHA256 hash of the name
|
||||
#[inline]
|
||||
pub async fn get_name_by_hash(
|
||||
user_id: u64,
|
||||
hash: &[u8],
|
||||
pool: &Pool,
|
||||
) -> super::Result<Option<String>> {
|
||||
let hash = hex::encode(hash);
|
||||
let name = query_as::<_, (String,)>(
|
||||
"SELECT `name` FROM `account` WHERE SHA2(`name`, 256) = ? AND `user_id` = ?;",
|
||||
)
|
||||
.bind(hash)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(name.map(|(name,)| name))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn get_salt(user_id: u64, name: &str, pool: &Pool) -> super::Result<Option<Vec<u8>>> {
|
||||
let salt =
|
||||
query_as::<_, (Vec<u8>,)>("SELECT salt FROM account WHERE user_id = ? AND name = ?")
|
||||
.bind(user_id)
|
||||
.bind(name)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(salt.map(|(salt,)| salt))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn update_name(
|
||||
user_id: u64,
|
||||
original_name: &str,
|
||||
new_name: &str,
|
||||
pool: &Pool,
|
||||
) -> super::Result<bool> {
|
||||
query!(
|
||||
"UPDATE account SET name = ? WHERE user_id = ? AND name = ?",
|
||||
new_name,
|
||||
user_id,
|
||||
original_name
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|result| result.rows_affected() != 0)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn update_login(
|
||||
user_id: u64,
|
||||
name: &str,
|
||||
login: Vec<u8>,
|
||||
pool: &Pool,
|
||||
) -> super::Result<bool> {
|
||||
query!(
|
||||
"UPDATE account SET enc_login = ? WHERE user_id = ? AND name = ?",
|
||||
login,
|
||||
user_id,
|
||||
name
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|result| result.rows_affected() != 0)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn update_password(
|
||||
user_id: u64,
|
||||
name: &str,
|
||||
password: Vec<u8>,
|
||||
pool: &Pool,
|
||||
) -> super::Result<bool> {
|
||||
query!(
|
||||
"UPDATE account SET enc_password = ? WHERE user_id = ? AND name = ?",
|
||||
password,
|
||||
user_id,
|
||||
name
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|result| result.rows_affected() != 0)
|
||||
}
|
||||
}
|
50
src/entity/locale.rs
Normal file
50
src/entity/locale.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use crate::prelude::*;
|
||||
use sqlx::{mysql::MySqlQueryResult as QueryResult, query, query_as};
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub enum LocaleType {
|
||||
#[default]
|
||||
Eng = 1,
|
||||
Ru = 2,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for LocaleType {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
1 => Ok(Self::Eng),
|
||||
2 => Ok(Self::Ru),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LocaleType> for u8 {
|
||||
fn from(value: LocaleType) -> Self {
|
||||
value as Self
|
||||
}
|
||||
}
|
||||
|
||||
impl LocaleType {
|
||||
pub async fn get_from_db(user_id: u64, db: &Pool) -> super::Result<Option<Self>> {
|
||||
let result: Option<(u8,)> = query_as("SELECT locale FROM master_pass WHERE user_id = ?")
|
||||
.bind(user_id)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
Ok(result.and_then(|val| val.0.try_into().ok()))
|
||||
}
|
||||
|
||||
pub async fn update(self, user_id: u64, db: &Pool) -> super::Result<bool> {
|
||||
let result: QueryResult = query!(
|
||||
"UPDATE master_pass SET locale = ? WHERE user_id = ?",
|
||||
u8::from(self),
|
||||
user_id
|
||||
)
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected() == 1)
|
||||
}
|
||||
}
|
59
src/entity/master_pass.rs
Normal file
59
src/entity/master_pass.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use super::{locale::LocaleType, Pool};
|
||||
use sqlx::{prelude::FromRow, query, query_as, Executor, MySql};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, FromRow, Eq)]
|
||||
pub struct MasterPass {
|
||||
pub user_id: u64,
|
||||
pub salt: Vec<u8>,
|
||||
pub password_hash: Vec<u8>,
|
||||
}
|
||||
|
||||
impl MasterPass {
|
||||
// Inserts the master password into DB
|
||||
#[inline]
|
||||
pub async fn insert(&self, pool: &Pool, locale: LocaleType) -> super::Result<()> {
|
||||
let locale: u8 = locale.into();
|
||||
query!(
|
||||
"INSERT INTO master_pass VALUES (?, ?, ?, ?)",
|
||||
self.user_id,
|
||||
self.salt,
|
||||
self.password_hash,
|
||||
locale
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
/// Gets the master password from the database
|
||||
#[inline]
|
||||
pub async fn get(user_id: u64, pool: &Pool) -> super::Result<Option<Self>> {
|
||||
query_as("SELECT user_id, salt, password_hash FROM master_pass WHERE user_id = ?")
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Checks if the master password for the user exists
|
||||
#[inline]
|
||||
pub async fn exists(user_id: u64, pool: &Pool) -> super::Result<bool> {
|
||||
query_as::<_, (bool,)>(
|
||||
"SELECT EXISTS(SELECT * FROM master_pass WHERE user_id = ? LIMIT 1) as value",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map(|(exists,)| exists)
|
||||
}
|
||||
|
||||
/// Removes a master password of the user from the database
|
||||
pub async fn remove(
|
||||
user_id: u64,
|
||||
pool: impl Executor<'_, Database = MySql>,
|
||||
) -> super::Result<()> {
|
||||
query!("DELETE FROM master_pass WHERE user_id = ?", user_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
}
|
12
src/entity/mod.rs
Normal file
12
src/entity/mod.rs
Normal file
@ -0,0 +1,12 @@
|
||||
// This is fine, because all errors can only be caused by the database errors and the docs would get repetative very quickly
|
||||
#![allow(clippy::missing_errors_doc)]
|
||||
|
||||
pub mod account;
|
||||
pub mod locale;
|
||||
pub mod master_pass;
|
||||
|
||||
pub use sqlx::{mysql::MySqlPool as Pool, Result};
|
||||
|
||||
pub async fn migrate(pool: &Pool) -> Result<(), sqlx::migrate::MigrateError> {
|
||||
sqlx::migrate!().run(pool).await
|
||||
}
|
37
src/main.rs
37
src/main.rs
@ -1,7 +1,9 @@
|
||||
mod callbacks;
|
||||
mod commands;
|
||||
mod cryptography;
|
||||
mod default;
|
||||
mod delete_mesage_handler;
|
||||
mod entity;
|
||||
mod errors;
|
||||
mod filter_user_info;
|
||||
mod locales;
|
||||
@ -13,18 +15,25 @@ mod prelude;
|
||||
mod state;
|
||||
|
||||
use anyhow::{Error, Result};
|
||||
use dotenvy::dotenv;
|
||||
use prelude::*;
|
||||
use std::{env, sync::Arc};
|
||||
use teloxide::{adaptors::throttle::Limits, dispatching::dialogue::InMemStorage, filter_command};
|
||||
|
||||
use crate::callbacks::CallbackCommand;
|
||||
|
||||
fn get_dispatcher(
|
||||
token: String,
|
||||
db: Pool,
|
||||
) -> Dispatcher<Throttle<Bot>, crate::Error, teloxide::dispatching::DefaultKey> {
|
||||
use dptree::{case, deps};
|
||||
db: prelude::Pool,
|
||||
) -> teloxide::prelude::Dispatcher<
|
||||
teloxide::adaptors::Throttle<teloxide::Bot>,
|
||||
crate::Error,
|
||||
teloxide::dispatching::DefaultKey,
|
||||
> {
|
||||
use callbacks::CallbackCommand;
|
||||
use commands::Command;
|
||||
use state::State;
|
||||
use teloxide::{
|
||||
adaptors::throttle::Limits,
|
||||
dispatching::dialogue::InMemStorage,
|
||||
dptree::{case, deps},
|
||||
filter_command,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
let bot = Bot::new(token).throttle(Limits::default());
|
||||
|
||||
@ -73,28 +82,30 @@ fn get_dispatcher(
|
||||
.branch(case![CallbackCommand::ChangeLocale(locale)].endpoint(callbacks::change_locale));
|
||||
|
||||
let handler = dptree::entry()
|
||||
.map_async(Locale::from_update)
|
||||
.map_async(locales::Locale::from_update)
|
||||
.enter_dialogue::<Update, InMemStorage<State>, State>()
|
||||
.branch(message_handler)
|
||||
.branch(callback_handler);
|
||||
|
||||
Dispatcher::builder(bot, handler)
|
||||
.dependencies(deps![db, InMemStorage::<State>::new()])
|
||||
.error_handler(Arc::from(errors::ErrorHandler))
|
||||
.error_handler(std::sync::Arc::from(errors::ErrorHandler))
|
||||
.enable_ctrlc_handler()
|
||||
.build()
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let _ = dotenv();
|
||||
use std::env;
|
||||
|
||||
let _ = dotenvy::dotenv();
|
||||
errors::init_logger();
|
||||
|
||||
locales::LocaleStore::init();
|
||||
|
||||
let token = env::var("TOKEN").expect("expected TOKEN in the enviroment");
|
||||
let database_url = env::var("DATABASE_URL").expect("expected DATABASE_URL in the enviroment");
|
||||
let pool = Pool::connect(&database_url).await?;
|
||||
let pool = entity::Pool::connect(&database_url).await?;
|
||||
|
||||
entity::migrate(&pool).await?;
|
||||
|
||||
|
@ -1,13 +1,14 @@
|
||||
pub use crate::{
|
||||
commands::Command,
|
||||
pub(crate) use crate::cryptography::{
|
||||
self, account::Decrypted as DecryptedAccount, validate_field,
|
||||
};
|
||||
pub(crate) use crate::entity::{self, account::Account, master_pass::MasterPass, Pool};
|
||||
pub(crate) use crate::{
|
||||
errors::{handle_error, NoUserInfo},
|
||||
first_handler, handler,
|
||||
locales::{Locale, LocaleRef},
|
||||
locales::LocaleRef,
|
||||
markups::*,
|
||||
models::*,
|
||||
state::{Handler, MainDialogue, MessageIds, PackagedHandler, State},
|
||||
};
|
||||
pub use cryptography::prelude::*;
|
||||
pub use entity::{prelude::*, Pool};
|
||||
pub use futures::{StreamExt, TryStreamExt};
|
||||
pub use teloxide::{adaptors::Throttle, prelude::*};
|
||||
pub(crate) use futures::{StreamExt, TryStreamExt};
|
||||
pub(crate) use teloxide::{adaptors::Throttle, prelude::*};
|
||||
|
Reference in New Issue
Block a user