From 3ebe4240bf5d182fc65ebb687ba3477eaaa9d79b Mon Sep 17 00:00:00 2001 From: StNicolay <103897650+StNicolay@users.noreply.github.com> Date: Fri, 14 Oct 2022 15:38:03 +0300 Subject: [PATCH] Changed database scripts --- src/__init__.py | 22 +++++++----- src/cryptography/other_accounts.py | 12 +++---- src/database/__init__.py | 4 +-- src/database/add.py | 54 ++++++++++++++---------------- src/database/delete.py | 15 --------- src/database/get.py | 49 +++++++++++++++------------ src/database/models.py | 32 ++++++++++++++++++ src/database/prepare.py | 41 ++++++----------------- 8 files changed, 116 insertions(+), 113 deletions(-) create mode 100644 src/database/models.py diff --git a/src/__init__.py b/src/__init__.py index 3e5bada..81bd10e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,20 +1,24 @@ import os -import mariadb from dotenv import load_dotenv +from sqlalchemy.future import Engine from . import bot, cryptography, database - __all__ = ["bot", "cryptography", "database"] +engine: Engine def main() -> None: + global engine + load_dotenv("./.env") - con = mariadb.connect( - os.getenv("DB_HOST"), - os.getenv("DB_USER"), - os.getenv("DB_PASS"), - os.getenv("DB_NAME"), - ) - database.prepare(con) + engine = database.prepare.get_engine( + host=os.getenv("DB_HOST"), + user=os.getenv("DB_USER"), + passwd=os.getenv("DB_PASS"), + db=os.getenv("DB_NAME"), + ) # type: ignore + database.prepare.prepare(engine) + bot_ = bot.create_bot(os.getenv("TG_TOKEN"), con) # type: ignore + bot_.infinity_polling() diff --git a/src/cryptography/other_accounts.py b/src/cryptography/other_accounts.py index 958c46f..e0f6fd5 100644 --- a/src/cryptography/other_accounts.py +++ b/src/cryptography/other_accounts.py @@ -8,7 +8,7 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -def _generate_key(salt: bytes, master_pass_hash: bytes) -> bytes: +def _generate_key(salt: bytes, master_pass: bytes) -> bytes: kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, @@ -16,17 +16,17 @@ def _generate_key(salt: bytes, master_pass_hash: bytes) -> bytes: iterations=100000, backend=default_backend(), ) - key = base64.urlsafe_b64encode(kdf.derive(master_pass_hash)) + key = base64.urlsafe_b64encode(kdf.derive(master_pass)) return key def encrypt_account_info( - login: str, passwd: str, master_pass_hash: bytes + login: str, passwd: str, master_pass: bytes ) -> tuple[bytes, bytes, bytes]: """Encrypts login and password of a user using hash of their master password as a key. Returns a tuple of encrypted login password and salt""" salt = bcrypt.gensalt() - key = _generate_key(salt, master_pass_hash) + key = _generate_key(salt, master_pass) f = Fernet(key) enc_login = f.encrypt(login.encode("utf-8")) enc_passwd = f.encrypt(passwd.encode("utf-8")) @@ -34,9 +34,9 @@ def encrypt_account_info( def decrypt_account_info( - enc_login: bytes, enc_pass: bytes, master_pass_hash: bytes, salt: bytes + enc_login: bytes, enc_pass: bytes, master_pass: bytes, salt: bytes ) -> tuple[str, str]: - key = _generate_key(salt, master_pass_hash) + key = _generate_key(salt, master_pass) f = Fernet(key) login_bytes = f.decrypt(enc_login) pass_bytes = f.decrypt(enc_pass) diff --git a/src/database/__init__.py b/src/database/__init__.py index 55723e0..2b16a00 100644 --- a/src/database/__init__.py +++ b/src/database/__init__.py @@ -1,3 +1,3 @@ -from . import add, delete, get, prepare +from . import add, delete, get, models, prepare -__all__ = ["add", "delete", "get", "prepare"] +__all__ = ["add", "delete", "get", "models", "prepare"] diff --git a/src/database/add.py b/src/database/add.py index d4e5942..6d534b3 100644 --- a/src/database/add.py +++ b/src/database/add.py @@ -1,42 +1,38 @@ -import traceback - +import sqlmodel import mariadb +from sqlalchemy.future import Engine - -def add_master_pass( - id: int, hashed_passwd: bytes, salt: bytes, con: mariadb.Connection -) -> bool: - cursor = con.cursor() - try: - cursor.execute( - "INSERT INTO master_pass (user_id, salt, passwd) VALUES (?, ?, ?)", - [id, hashed_passwd, salt], - ) - cursor.close() - except Exception: - traceback.print_exc() - return False - else: - return True +from . import models def add_account( - id: int, - acc_name: str, + engine: Engine, + user_id: int, + name: str, salt: bytes, enc_login: bytes, - enc_passwd: bytes, - con: mariadb.Connection, + enc_pass: bytes, ) -> bool: - cursor = con.cursor() + account = models.Account( + user_id=user_id, name=name, salt=salt, enc_login=enc_login, enc_pass=enc_pass + ) try: - cursor.execute( - "INSERT INTO accounts (user_id, acc_name, salt, enc_login, enc_pass) VALUES (?, ?, ?, ?, ?, ?)", - [id, acc_name, salt, enc_login, enc_passwd], - ) - cursor.close() + with sqlmodel.Session(engine) as session: + session.add(account) + session.commit() + except Exception: + return False + else: + return True + + +def add_master_pass(engine: Engine, user_id: int, salt: bytes, passwd: bytes) -> bool: + master_pass = models.MasterPass(user_id=user_id, salt=salt, passwd=passwd) + try: + with sqlmodel.Session(engine) as session: + session.add(master_pass) + session.commit() except Exception: - traceback.print_exc() return False else: return True diff --git a/src/database/delete.py b/src/database/delete.py index 4f08d94..e69de29 100644 --- a/src/database/delete.py +++ b/src/database/delete.py @@ -1,15 +0,0 @@ -import mariadb - - -def delete_master_pass(con: mariadb.Connection, user_id: int) -> None: - cursor = con.cursor() - cursor.execute("DELETE FROM master_pass WHERE user_id=?", [user_id]) - cursor.close() - - -def delete_account(con: mariadb.Connection, user_id: int, account: str): - cursor = con.cursor() - cursor.execute( - "DELETE FROM accounts WHERE user_id = ? AND acc_name = ?", [user_id, account] - ) - cursor.close() diff --git a/src/database/get.py b/src/database/get.py index a16fd66..cb90d1e 100644 --- a/src/database/get.py +++ b/src/database/get.py @@ -1,31 +1,36 @@ -import mariadb +import sqlmodel +from sqlalchemy.future import Engine + +from . import models -def get_master_pass(con: mariadb.Connection, id: int) -> tuple[bytes, bytes]: - """Returns tuple of salt and hashed master password""" - cursor = con.cursor() - cursor.execute("SELECT salt, passwd FROM master_pass IF user_id = ?", [id]) - result = cursor.fetchone() - cursor.close() - return result +def get_master_pass(engine: Engine, user_id: int) -> tuple[bytes, bytes] | None: + statement = sqlmodel.select(models.MasterPass).where( + models.MasterPass.user_id == user_id + ) + with sqlmodel.Session(engine) as session: + result = session.exec(statement).first() + print(result) + if result is None: + return + return (result.salt, result.passwd) -def get_accounts(con: mariadb.Connection, id: int) -> list[str]: - """Returns list of user accounts""" - cursor = con.cursor() - cursor.execute("SELECT acc_name FROM accounts IF user_id = ?", [id]) - return [i[0] for i in cursor.fetchall()] +def get_accounts(engine: Engine, user_id: int) -> list[str]: + statement = sqlmodel.select(models.Account).where(models.Account.user_id == user_id) + with sqlmodel.Session(engine) as session: + result = session.exec(statement) + return [account.name for account in result] def get_account_info( - id: int, name: str, con: mariadb.Connection + engine: Engine, user_id: int, name: str ) -> tuple[bytes, bytes, bytes]: - """Returns tuple of salt, login and password""" - cursor = con.cursor() - cursor.execute( - """SELECT salt, enc_login, enc_pass IF user_id = ? AND acc_name = ?""", - [id, name], + statement = sqlmodel.select(models.Account).where( + models.Account.user_id == user_id and models.Account.name == name ) - result = cursor.fetchone() - cursor.close() - return result + with sqlmodel.Session(engine) as session: + result = session.exec(statement).first() + if result is None: + return + return (result.salt, result.enc_login, result.enc_pass) diff --git a/src/database/models.py b/src/database/models.py new file mode 100644 index 0000000..8dc2d6b --- /dev/null +++ b/src/database/models.py @@ -0,0 +1,32 @@ +from typing import Optional + +import sqlmodel + + +class MasterPass(sqlmodel.SQLModel, table=True): + __tablename__ = "master_passwords" + id: Optional[int] = sqlmodel.Field(primary_key=True) + user_id: int = sqlmodel.Field(nullable=False, index=True, unique=True) + salt: bytes = sqlmodel.Field( + sa_column=sqlmodel.Column(type_=sqlmodel.VARBINARY(255), nullable=False) + ) + passwd: bytes = sqlmodel.Field( + sa_column=sqlmodel.Column(type_=sqlmodel.VARBINARY(255), nullable=False) + ) + + +class Account(sqlmodel.SQLModel, table=True): + __tablename__ = "accounts" + __table_args__ = (sqlmodel.UniqueConstraint("user_id", "name"),) + id: Optional[int] = sqlmodel.Field(primary_key=True) + user_id: int = sqlmodel.Field(nullable=False, index=True) + name: str = sqlmodel.Field(nullable=False, index=True, max_length=255) + salt: bytes = sqlmodel.Field( + sa_column=sqlmodel.Column(type_=sqlmodel.VARBINARY(255), nullable=False) + ) + enc_login: bytes = sqlmodel.Field( + sa_column=sqlmodel.Column(type_=sqlmodel.VARBINARY(255), nullable=False) + ) + enc_pass: bytes = sqlmodel.Field( + sa_column=sqlmodel.Column(type_=sqlmodel.VARBINARY(255), nullable=False) + ) diff --git a/src/database/prepare.py b/src/database/prepare.py index 40a8adc..6d8e449 100644 --- a/src/database/prepare.py +++ b/src/database/prepare.py @@ -1,35 +1,16 @@ -import mariadb +import sqlmodel +from sqlalchemy.future import Engine + +from . import models -def _create_tables(con: mariadb.Connection) -> None: - cursor = con.cursor() - cursor.execute( - """CREATE TABLE IF NOT EXISTS master_pass (user_id INT, - salt BINARY(64), - passwd BINARY(64), - PRIMARY KEY(user_id) - )""" +def get_engine(host: str, user: str, passwd: str, db: str) -> Engine: + engine = sqlmodel.create_engine( + f"mariadb+mariadbconnector://{user}:{passwd}@{host}/{db}" ) - cursor.execute( - """CREATE TABLE IF NOT EXISTS accounts(user_id INT, - acc_name VARCHAR(255), - salt BINARY(64), - enc_login BINARY(64), - enc_pass BINARY(64), - UNIQUE(acc_name, user_id) - )""" - ) - cursor.close() + print(type(engine)) + return engine -def _create_index(con: mariadb.Connection) -> None: - cursor = con.cursor() - cursor.execute( - """CREATE INDEX IF NOT EXISTS user_id_to_acc on accounts(user_id) - """ - ) - - -def prepare(con: mariadb.Connection) -> None: - _create_tables(con) - _create_index(con) +def prepare(engine: Engine) -> None: + sqlmodel.SQLModel.metadata.create_all(engine)