Compare commits
	
		
			15 Commits
		
	
	
		
			1.3
			...
			972c5577f4
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 972c5577f4 | |||
| 671286dc39 | |||
| 5dbf93013a | |||
| c051c14f1f | |||
| 3686195396 | |||
| d79b57b1f0 | |||
| 157c2c4aa2 | |||
| f4a5f51b23 | |||
| 6bc8eb1413 | |||
| 9f64305050 | |||
| 4954f39a91 | |||
| 3edeb86b6c | |||
| c7675c231f | |||
| ae88fccf13 | |||
| e29eefe40b | 
@@ -35,7 +35,7 @@ class _Accounts(pydantic.BaseModel):
 | 
				
			|||||||
def _accounts_list_to_json(accounts: Iterable[DecryptedAccount]) -> str:
 | 
					def _accounts_list_to_json(accounts: Iterable[DecryptedAccount]) -> str:
 | 
				
			||||||
    result = _Accounts(
 | 
					    result = _Accounts(
 | 
				
			||||||
        accounts=[_Account.from_usual_account(i) for i in accounts],
 | 
					        accounts=[_Account.from_usual_account(i) for i in accounts],
 | 
				
			||||||
    ).json(ensure_ascii=False)
 | 
					    ).json(ensure_ascii=False, indent=2)
 | 
				
			||||||
    return result
 | 
					    return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,13 +3,13 @@ import functools
 | 
				
			|||||||
from sqlalchemy.future import Engine
 | 
					from sqlalchemy.future import Engine
 | 
				
			||||||
from telebot.async_telebot import AsyncTeleBot
 | 
					from telebot.async_telebot import AsyncTeleBot
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import callback_handlers, message_handlers
 | 
					from . import callback_handlers, exception_handler, message_handlers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = ["callback_handlers", "message_handlers"]
 | 
					__all__ = ["callback_handlers", "exception_handler", "message_handlers"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_bot(token: str, engine: Engine) -> AsyncTeleBot:
 | 
					def create_bot(token: str, engine: Engine) -> AsyncTeleBot:
 | 
				
			||||||
    bot = AsyncTeleBot(token)
 | 
					    bot = AsyncTeleBot(token, exception_handler=exception_handler.Handler)
 | 
				
			||||||
    bot.register_message_handler(
 | 
					    bot.register_message_handler(
 | 
				
			||||||
        functools.partial(message_handlers.set_master_password, bot, engine),
 | 
					        functools.partial(message_handlers.set_master_password, bot, engine),
 | 
				
			||||||
        commands=["set_master_pass"],
 | 
					        commands=["set_master_pass"],
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										8
									
								
								src/bot/exception_handler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								src/bot/exception_handler.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
				
			|||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					from typing import Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Handler:
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def handle(exc: Type[BaseException]) -> None:
 | 
				
			||||||
 | 
					        traceback.print_exception(exc)
 | 
				
			||||||
@@ -18,7 +18,7 @@ states: dict[tuple[int, int], Handler] = {}
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def register_state(
 | 
					def register_state(
 | 
				
			||||||
    message: Message,
 | 
					    message: Message,
 | 
				
			||||||
    handler: Callable[[Message], Any],
 | 
					    handler: Handler,
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    states[(message.chat.id, message.from_user.id)] = handler
 | 
					    states[(message.chat.id, message.from_user.id)] = handler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -40,9 +40,8 @@ async def delete_message(
 | 
				
			|||||||
    *,
 | 
					    *,
 | 
				
			||||||
    sleep_time: int = 0,
 | 
					    sleep_time: int = 0,
 | 
				
			||||||
) -> bool:
 | 
					) -> bool:
 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        if sleep_time != 0:
 | 
					 | 
				
			||||||
    await asyncio.sleep(sleep_time)
 | 
					    await asyncio.sleep(sleep_time)
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
        await bot.delete_message(mes.chat.id, mes.id)
 | 
					        await bot.delete_message(mes.chat.id, mes.id)
 | 
				
			||||||
    except telebot.apihelper.ApiException:
 | 
					    except telebot.apihelper.ApiException:
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import functools
 | 
					import functools
 | 
				
			||||||
import gc
 | 
					import gc
 | 
				
			||||||
 | 
					import itertools
 | 
				
			||||||
from concurrent.futures import ProcessPoolExecutor
 | 
					from concurrent.futures import ProcessPoolExecutor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import telebot
 | 
					import telebot
 | 
				
			||||||
@@ -57,21 +58,31 @@ async def get_accounts(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
async def delete_all(bot: AsyncTeleBot, engine: Engine, mes: Message) -> None:
 | 
					async def delete_all(bot: AsyncTeleBot, engine: Engine, mes: Message) -> None:
 | 
				
			||||||
    await base_handler(bot, mes)
 | 
					    await base_handler(bot, mes)
 | 
				
			||||||
 | 
					    master_pass = db.get.get_master_pass(engine, mes.from_user.id)
 | 
				
			||||||
 | 
					    if master_pass is None:
 | 
				
			||||||
 | 
					        await send_tmp_message(bot, mes.chat.id, "У вас нет мастер пароля")
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
    bot_mes = await bot.send_message(
 | 
					    bot_mes = await bot.send_message(
 | 
				
			||||||
        mes.chat.id,
 | 
					        mes.chat.id,
 | 
				
			||||||
        "Вы действительно хотите удалить все ваши аккаунты? Это действие "
 | 
					        "Вы действительно хотите удалить все ваши аккаунты? Это действие "
 | 
				
			||||||
        "нельзя отменить. "
 | 
					        "нельзя отменить. "
 | 
				
			||||||
        "Отправьте YES для подтверждения",
 | 
					        "Отправьте мастер пароль для подтверждения",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    register_state(
 | 
				
			||||||
 | 
					        mes, functools.partial(_delete_all2, bot, engine, master_pass, bot_mes)
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    register_state(mes, functools.partial(_delete_all2, bot, engine, bot_mes))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _delete_all2(
 | 
					async def _delete_all2(
 | 
				
			||||||
    bot: AsyncTeleBot, engine: Engine, prev_mes: Message, mes: Message
 | 
					    bot: AsyncTeleBot,
 | 
				
			||||||
 | 
					    engine: Engine,
 | 
				
			||||||
 | 
					    master_pass: db.models.MasterPass,
 | 
				
			||||||
 | 
					    prev_mes: Message,
 | 
				
			||||||
 | 
					    mes: Message,
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    await base_handler(bot, mes, prev_mes)
 | 
					    await base_handler(bot, mes, prev_mes)
 | 
				
			||||||
    text = mes.text.strip()
 | 
					    text = mes.text.strip()
 | 
				
			||||||
    if text == "YES":
 | 
					    if encryption.master_pass.check_master_pass(text, master_pass):
 | 
				
			||||||
        db.delete.purge_accounts(engine, mes.from_user.id)
 | 
					        db.delete.purge_accounts(engine, mes.from_user.id)
 | 
				
			||||||
        db.delete.delete_master_pass(engine, mes.from_user.id)
 | 
					        db.delete.delete_master_pass(engine, mes.from_user.id)
 | 
				
			||||||
        await send_tmp_message(
 | 
					        await send_tmp_message(
 | 
				
			||||||
@@ -84,7 +95,7 @@ async def _delete_all2(
 | 
				
			|||||||
        await send_tmp_message(
 | 
					        await send_tmp_message(
 | 
				
			||||||
            bot,
 | 
					            bot,
 | 
				
			||||||
            mes.chat.id,
 | 
					            mes.chat.id,
 | 
				
			||||||
            "Вы отправили не YES, ничего не удалено",
 | 
					            "Вы отправили не верный мастер пароль, ничего не удалено",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -133,7 +144,8 @@ async def reset_master_pass(
 | 
				
			|||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    await base_handler(bot, mes)
 | 
					    await base_handler(bot, mes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if db.get.get_master_pass(engine, mes.from_user.id) is None:
 | 
					    master_pass = db.get.get_master_pass(engine, mes.from_user.id)
 | 
				
			||||||
 | 
					    if master_pass is None:
 | 
				
			||||||
        return await send_tmp_message(
 | 
					        return await send_tmp_message(
 | 
				
			||||||
            bot,
 | 
					            bot,
 | 
				
			||||||
            mes.chat.id,
 | 
					            mes.chat.id,
 | 
				
			||||||
@@ -142,17 +154,48 @@ async def reset_master_pass(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    bot_mes = await bot.send_message(
 | 
					    bot_mes = await bot.send_message(
 | 
				
			||||||
        mes.chat.id,
 | 
					        mes.chat.id,
 | 
				
			||||||
        "Отправьте новый мастер пароль, осторожно, все текущие аккаунты "
 | 
					        "Отправьте текущий мастер пароль",
 | 
				
			||||||
        "будут удалены навсегда",
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    register_state(
 | 
					    register_state(
 | 
				
			||||||
        mes,
 | 
					        mes,
 | 
				
			||||||
        functools.partial(_reset_master_pass2, bot, engine, bot_mes),
 | 
					        functools.partial(
 | 
				
			||||||
 | 
					            _reset_master_pass2,
 | 
				
			||||||
 | 
					            bot,
 | 
				
			||||||
 | 
					            engine,
 | 
				
			||||||
 | 
					            master_pass,
 | 
				
			||||||
 | 
					            bot_mes,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _reset_master_pass2(
 | 
					async def _reset_master_pass2(
 | 
				
			||||||
 | 
					    bot: AsyncTeleBot,
 | 
				
			||||||
 | 
					    engine: Engine,
 | 
				
			||||||
 | 
					    master_pass: db.models.MasterPass,
 | 
				
			||||||
 | 
					    prev_mes: Message,
 | 
				
			||||||
 | 
					    mes: Message,
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
 | 
					    await base_handler(bot, mes, prev_mes)
 | 
				
			||||||
 | 
					    text = mes.text.strip()
 | 
				
			||||||
 | 
					    if text == "/cancel":
 | 
				
			||||||
 | 
					        await send_tmp_message(bot, mes.chat.id, "Успешная отмена")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not encryption.master_pass.check_master_pass(text, master_pass):
 | 
				
			||||||
 | 
					        await send_tmp_message(bot, mes.chat.id, "Неверный мастер пароль")
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bot_mes = await bot.send_message(
 | 
				
			||||||
 | 
					        mes.chat.id,
 | 
				
			||||||
 | 
					        "Отправьте новый мастер пароль. Осторожно, все аккаунты будут удалены",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    register_state(
 | 
				
			||||||
 | 
					        mes,
 | 
				
			||||||
 | 
					        functools.partial(_reset_master_pass3, bot, engine, bot_mes),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def _reset_master_pass3(
 | 
				
			||||||
    bot: AsyncTeleBot, engine: Engine, prev_mes: Message, mes: Message
 | 
					    bot: AsyncTeleBot, engine: Engine, prev_mes: Message, mes: Message
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    await base_handler(bot, mes, prev_mes)
 | 
					    await base_handler(bot, mes, prev_mes)
 | 
				
			||||||
@@ -427,13 +470,14 @@ async def delete_account(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    register_state(
 | 
					    register_state(
 | 
				
			||||||
        mes,
 | 
					        mes,
 | 
				
			||||||
        functools.partial(_delete_account2, bot, engine, bot_mes),
 | 
					        functools.partial(_delete_account2, bot, engine, master_pass, bot_mes),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _delete_account2(
 | 
					async def _delete_account2(
 | 
				
			||||||
    bot: AsyncTeleBot,
 | 
					    bot: AsyncTeleBot,
 | 
				
			||||||
    engine: Engine,
 | 
					    engine: Engine,
 | 
				
			||||||
 | 
					    master_pass: db.models.MasterPass,
 | 
				
			||||||
    prev_mes: Message,
 | 
					    prev_mes: Message,
 | 
				
			||||||
    mes: Message,
 | 
					    mes: Message,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
@@ -447,27 +491,36 @@ async def _delete_account2(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    bot_mes = await bot.send_message(
 | 
					    bot_mes = await bot.send_message(
 | 
				
			||||||
        mes.from_user.id,
 | 
					        mes.from_user.id,
 | 
				
			||||||
        f'Вы уверены, что хотите удалить аккаунт "{text}"?\nОтправьте YES для '
 | 
					        f'Вы уверены, что хотите удалить аккаунт "{text}"?\nОтправьте мастер '
 | 
				
			||||||
        "подтверждения",
 | 
					        "пароль для подтверждения",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    register_state(
 | 
					    register_state(
 | 
				
			||||||
        mes,
 | 
					        mes,
 | 
				
			||||||
        functools.partial(_delete_account3, bot, engine, bot_mes, text),
 | 
					        functools.partial(
 | 
				
			||||||
 | 
					            _delete_account3,
 | 
				
			||||||
 | 
					            bot,
 | 
				
			||||||
 | 
					            engine,
 | 
				
			||||||
 | 
					            master_pass,
 | 
				
			||||||
 | 
					            bot_mes,
 | 
				
			||||||
 | 
					            text,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _delete_account3(
 | 
					async def _delete_account3(
 | 
				
			||||||
    bot: AsyncTeleBot,
 | 
					    bot: AsyncTeleBot,
 | 
				
			||||||
    engine: Engine,
 | 
					    engine: Engine,
 | 
				
			||||||
 | 
					    master_pass: db.models.MasterPass,
 | 
				
			||||||
    prev_mes: Message,
 | 
					    prev_mes: Message,
 | 
				
			||||||
    account_name: str,
 | 
					    account_name: str,
 | 
				
			||||||
    mes: Message,
 | 
					    mes: Message,
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    await base_handler(bot, mes, prev_mes)
 | 
					    await base_handler(bot, mes, prev_mes)
 | 
				
			||||||
    text = mes.text.strip()
 | 
					    text = mes.text.strip()
 | 
				
			||||||
    if text != "YES":
 | 
					    if not encryption.master_pass.check_master_pass(text, master_pass):
 | 
				
			||||||
        return await send_tmp_message(bot, mes.chat.id, "Успешная отмена")
 | 
					        await send_tmp_message(bot, mes.chat.id, "Неверный пароль")
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    db.delete.delete_account(engine, mes.from_user.id, account_name)
 | 
					    db.delete.delete_account(engine, mes.from_user.id, account_name)
 | 
				
			||||||
    await send_tmp_message(bot, mes.chat.id, "Аккаунт удалён")
 | 
					    await send_tmp_message(bot, mes.chat.id, "Аккаунт удалён")
 | 
				
			||||||
@@ -540,7 +593,6 @@ async def _export2(
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
            tasks.append(loop.run_in_executor(pool, function))
 | 
					            tasks.append(loop.run_in_executor(pool, function))
 | 
				
			||||||
        accounts = await asyncio.gather(*tasks)
 | 
					        accounts = await asyncio.gather(*tasks)
 | 
				
			||||||
        accounts.sort(key=lambda account: account.name)
 | 
					 | 
				
			||||||
    json_io = accounts_to_json(accounts)
 | 
					    json_io = accounts_to_json(accounts)
 | 
				
			||||||
    await bot.send_document(
 | 
					    await bot.send_document(
 | 
				
			||||||
        mes.chat.id,
 | 
					        mes.chat.id,
 | 
				
			||||||
@@ -640,22 +692,35 @@ async def _import3(
 | 
				
			|||||||
    # List of names of accounts, which failed to be added to the database
 | 
					    # List of names of accounts, which failed to be added to the database
 | 
				
			||||||
    # or failed the tests
 | 
					    # or failed the tests
 | 
				
			||||||
    failed: list[str] = []
 | 
					    failed: list[str] = []
 | 
				
			||||||
 | 
					    tasks: list[asyncio.Future[db.models.Account]] = []
 | 
				
			||||||
 | 
					    loop = asyncio.get_running_loop()
 | 
				
			||||||
 | 
					    with ProcessPoolExecutor() as pool:
 | 
				
			||||||
        for account in accounts:
 | 
					        for account in accounts:
 | 
				
			||||||
            if not check_account(account):
 | 
					            if not check_account(account):
 | 
				
			||||||
                failed.append(account.name)
 | 
					                failed.append(account.name)
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
        account = encryption.accounts.encrypt(account, text)
 | 
					            function = functools.partial(
 | 
				
			||||||
        result = db.add.add_account(engine, account)
 | 
					                encryption.accounts.encrypt,
 | 
				
			||||||
        if not result:
 | 
					                account,
 | 
				
			||||||
            failed.append(account.name)
 | 
					                text,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            tasks.append(loop.run_in_executor(pool, function))
 | 
				
			||||||
 | 
					        enc_accounts: list[db.models.Account] = await asyncio.gather(*tasks)
 | 
				
			||||||
 | 
					    results = db.add.add_accounts(engine, enc_accounts)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    failed_accounts = itertools.compress(
 | 
				
			||||||
 | 
					        enc_accounts, (not result for result in results)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    failed.extend((account.name for account in failed_accounts))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if failed:
 | 
					    if failed:
 | 
				
			||||||
        mes_text = "Не удалось добавить:\n" + "\n".join(failed)
 | 
					        await send_deleteable_message(
 | 
				
			||||||
 | 
					            bot, mes.chat.id, "Не удалось добавить:\n" + "\n".join(failed)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        mes_text = "Успех"
 | 
					        await send_tmp_message(bot, mes.chat.id, "Успех")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await send_tmp_message(bot, mes.chat.id, mes_text, 10)
 | 
					    del text, mes, accounts, function, tasks, failed_accounts
 | 
				
			||||||
    del text, mes, accounts
 | 
					 | 
				
			||||||
    gc.collect()
 | 
					    gc.collect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -677,10 +742,20 @@ async def message_handler(bot: AsyncTeleBot, mes: Message) -> None:
 | 
				
			|||||||
        await delete_message(bot, mes)
 | 
					        await delete_message(bot, mes)
 | 
				
			||||||
        if mes.text.strip() == "/cancel":
 | 
					        if mes.text.strip() == "/cancel":
 | 
				
			||||||
            await send_tmp_message(bot, mes.chat.id, "Нет активного действия")
 | 
					            await send_tmp_message(bot, mes.chat.id, "Нет активного действия")
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
        await send_tmp_message(
 | 
					        await send_tmp_message(
 | 
				
			||||||
            bot,
 | 
					            bot,
 | 
				
			||||||
            mes.chat.id,
 | 
					            mes.chat.id,
 | 
				
			||||||
            "Вы отправили не корректное сообщение",
 | 
					            "Вы отправили не корректное сообщение",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
        await handler(mes)
 | 
					        await handler(mes)
 | 
				
			||||||
 | 
					    except Exception:
 | 
				
			||||||
 | 
					        await send_tmp_message(
 | 
				
			||||||
 | 
					            bot,
 | 
				
			||||||
 | 
					            mes.chat.id,
 | 
				
			||||||
 | 
					            "Произошла непредвиденная ошибка",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        raise
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,27 +5,42 @@ from sqlalchemy.future import Engine
 | 
				
			|||||||
from . import models
 | 
					from . import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def add_account(engine: Engine, account: models.Account) -> bool:
 | 
					def _add_model(
 | 
				
			||||||
    """Adds account to the database. Returns true on success,
 | 
					    session: sqlmodel.Session, model: models.Account | models.MasterPass
 | 
				
			||||||
 | 
					) -> bool:
 | 
				
			||||||
 | 
					    """Adds model to the session. Returns true on success,
 | 
				
			||||||
    false otherwise"""
 | 
					    false otherwise"""
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        with sqlmodel.Session(engine) as session:
 | 
					        session.add(model)
 | 
				
			||||||
            session.add(account)
 | 
					 | 
				
			||||||
            session.commit()
 | 
					 | 
				
			||||||
    except IntegrityError:
 | 
					    except IntegrityError:
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def add_account(engine: Engine, account: models.Account) -> bool:
 | 
				
			||||||
 | 
					    """Adds account to the database. Returns true on success,
 | 
				
			||||||
 | 
					    false otherwise"""
 | 
				
			||||||
 | 
					    with sqlmodel.Session(engine) as session:
 | 
				
			||||||
 | 
					        result = _add_model(session, account)
 | 
				
			||||||
 | 
					        session.commit()
 | 
				
			||||||
 | 
					    return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def add_master_pass(engine: Engine, master_pass: models.MasterPass) -> bool:
 | 
					def add_master_pass(engine: Engine, master_pass: models.MasterPass) -> bool:
 | 
				
			||||||
    """Adds master password the database. Returns true on success,
 | 
					    """Adds master password the database. Returns true on success,
 | 
				
			||||||
    false otherwise"""
 | 
					    false otherwise"""
 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
    with sqlmodel.Session(engine) as session:
 | 
					    with sqlmodel.Session(engine) as session:
 | 
				
			||||||
            session.add(master_pass)
 | 
					        result = _add_model(session, master_pass)
 | 
				
			||||||
        session.commit()
 | 
					        session.commit()
 | 
				
			||||||
    except IntegrityError:
 | 
					    return result
 | 
				
			||||||
        return False
 | 
					
 | 
				
			||||||
    else:
 | 
					
 | 
				
			||||||
        return True
 | 
					def add_accounts(
 | 
				
			||||||
 | 
					    engine: Engine,
 | 
				
			||||||
 | 
					    accounts: list[models.Account],
 | 
				
			||||||
 | 
					) -> list[bool]:
 | 
				
			||||||
 | 
					    with sqlmodel.Session(engine) as session:
 | 
				
			||||||
 | 
					        result = [_add_model(session, account) for account in accounts]
 | 
				
			||||||
 | 
					        session.commit()
 | 
				
			||||||
 | 
					    return result
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -36,9 +36,13 @@ def get_account_names(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def get_accounts(engine: Engine, user_id: int) -> list[models.Account]:
 | 
					def get_accounts(engine: Engine, user_id: int) -> list[models.Account]:
 | 
				
			||||||
    """Returns a list of accounts of a user"""
 | 
					    """Returns a list of accounts of a user"""
 | 
				
			||||||
    statement = sqlmodel.select(models.Account).where(
 | 
					    statement = (
 | 
				
			||||||
 | 
					        sqlmodel.select(models.Account)
 | 
				
			||||||
 | 
					        .where(
 | 
				
			||||||
            models.Account.user_id == user_id,
 | 
					            models.Account.user_id == user_id,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        .order_by(models.Account.name)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    with sqlmodel.Session(engine) as session:
 | 
					    with sqlmodel.Session(engine) as session:
 | 
				
			||||||
        result = session.exec(statement).fetchall()
 | 
					        result = session.exec(statement).fetchall()
 | 
				
			||||||
    return result
 | 
					    return result
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,13 +4,21 @@ import sqlmodel
 | 
				
			|||||||
class MasterPass(sqlmodel.SQLModel, table=True):
 | 
					class MasterPass(sqlmodel.SQLModel, table=True):
 | 
				
			||||||
    __tablename__ = "master_passwords"
 | 
					    __tablename__ = "master_passwords"
 | 
				
			||||||
    user_id: int = sqlmodel.Field(
 | 
					    user_id: int = sqlmodel.Field(
 | 
				
			||||||
        sa_column=sqlmodel.Column(sqlmodel.INT(), primary_key=True)
 | 
					        sa_column=sqlmodel.Column(
 | 
				
			||||||
 | 
					            sqlmodel.INT(),
 | 
				
			||||||
 | 
					            primary_key=True,
 | 
				
			||||||
 | 
					            autoincrement=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    salt: bytes = sqlmodel.Field(
 | 
					    salt: bytes = sqlmodel.Field(
 | 
				
			||||||
        sa_column=sqlmodel.Column(sqlmodel.BINARY(64), nullable=False)
 | 
					        sa_column=sqlmodel.Column(sqlmodel.BINARY(64), nullable=False),
 | 
				
			||||||
 | 
					        max_length=64,
 | 
				
			||||||
 | 
					        min_length=64,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    password_hash: bytes = sqlmodel.Field(
 | 
					    password_hash: bytes = sqlmodel.Field(
 | 
				
			||||||
        sa_column=sqlmodel.Column(sqlmodel.BINARY(128), nullable=False)
 | 
					        sa_column=sqlmodel.Column(sqlmodel.BINARY(128), nullable=False),
 | 
				
			||||||
 | 
					        max_length=128,
 | 
				
			||||||
 | 
					        min_length=128,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -18,13 +26,17 @@ class Account(sqlmodel.SQLModel, table=True):
 | 
				
			|||||||
    __tablename__ = "accounts"
 | 
					    __tablename__ = "accounts"
 | 
				
			||||||
    __table_args__ = (sqlmodel.PrimaryKeyConstraint("user_id", "name"),)
 | 
					    __table_args__ = (sqlmodel.PrimaryKeyConstraint("user_id", "name"),)
 | 
				
			||||||
    user_id: int = sqlmodel.Field()
 | 
					    user_id: int = sqlmodel.Field()
 | 
				
			||||||
    name: str = sqlmodel.Field(max_length=255)
 | 
					    name: str = sqlmodel.Field(max_length=256)
 | 
				
			||||||
    salt: bytes = sqlmodel.Field(
 | 
					    salt: bytes = sqlmodel.Field(
 | 
				
			||||||
        sa_column=sqlmodel.Column(sqlmodel.BINARY(64), nullable=False)
 | 
					        sa_column=sqlmodel.Column(sqlmodel.BINARY(64), nullable=False),
 | 
				
			||||||
 | 
					        max_length=64,
 | 
				
			||||||
 | 
					        min_length=64,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    enc_login: bytes = sqlmodel.Field(
 | 
					    enc_login: bytes = sqlmodel.Field(
 | 
				
			||||||
        sa_column=sqlmodel.Column(sqlmodel.VARBINARY(256), nullable=False)
 | 
					        sa_column=sqlmodel.Column(sqlmodel.VARBINARY(256), nullable=False),
 | 
				
			||||||
 | 
					        max_length=256,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    enc_password: bytes = sqlmodel.Field(
 | 
					    enc_password: bytes = sqlmodel.Field(
 | 
				
			||||||
        sa_column=sqlmodel.Column(sqlmodel.VARBINARY(256), nullable=False)
 | 
					        sa_column=sqlmodel.Column(sqlmodel.VARBINARY(256), nullable=False),
 | 
				
			||||||
 | 
					        max_length=256,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,26 +1,43 @@
 | 
				
			|||||||
import base64
 | 
					 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					from typing import Self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cryptography.fernet import Fernet
 | 
					 | 
				
			||||||
from cryptography.hazmat.backends import default_backend
 | 
					 | 
				
			||||||
from cryptography.hazmat.primitives import hashes
 | 
					from cryptography.hazmat.primitives import hashes
 | 
				
			||||||
 | 
					from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
 | 
				
			||||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 | 
					from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..db.models import Account
 | 
					from ..db.models import Account
 | 
				
			||||||
from ..decrypted_account import DecryptedAccount
 | 
					from ..decrypted_account import DecryptedAccount
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _generate_key(salt: bytes, master_pass: bytes) -> bytes:
 | 
					class Cipher:
 | 
				
			||||||
    """Generates key for fernet encryption"""
 | 
					    def __init__(self, key: bytes) -> None:
 | 
				
			||||||
 | 
					        self._chacha = ChaCha20Poly1305(key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def generate_cipher(cls, salt: bytes, password: bytes) -> Self:
 | 
				
			||||||
 | 
					        """Generates cipher which uses key derived from a given password"""
 | 
				
			||||||
        kdf = PBKDF2HMAC(
 | 
					        kdf = PBKDF2HMAC(
 | 
				
			||||||
            algorithm=hashes.SHA256(),
 | 
					            algorithm=hashes.SHA256(),
 | 
				
			||||||
            length=32,
 | 
					            length=32,
 | 
				
			||||||
            salt=salt,
 | 
					            salt=salt,
 | 
				
			||||||
            iterations=100000,
 | 
					            iterations=100000,
 | 
				
			||||||
        backend=default_backend(),
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    key = base64.urlsafe_b64encode(kdf.derive(master_pass))
 | 
					        return cls(kdf.derive(password))
 | 
				
			||||||
    return key
 | 
					
 | 
				
			||||||
 | 
					    def encrypt(self, data: bytes) -> bytes:
 | 
				
			||||||
 | 
					        nonce = os.urandom(12)
 | 
				
			||||||
 | 
					        return nonce + self._chacha.encrypt(
 | 
				
			||||||
 | 
					            nonce,
 | 
				
			||||||
 | 
					            data,
 | 
				
			||||||
 | 
					            associated_data=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decrypt(self, data: bytes) -> bytes:
 | 
				
			||||||
 | 
					        return self._chacha.decrypt(
 | 
				
			||||||
 | 
					            nonce=data[:12],
 | 
				
			||||||
 | 
					            data=data[12:],
 | 
				
			||||||
 | 
					            associated_data=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def encrypt(
 | 
					def encrypt(
 | 
				
			||||||
@@ -29,15 +46,10 @@ def encrypt(
 | 
				
			|||||||
) -> Account:
 | 
					) -> Account:
 | 
				
			||||||
    """Encrypts account using master password and returns Account object"""
 | 
					    """Encrypts account using master password and returns Account object"""
 | 
				
			||||||
    salt = os.urandom(64)
 | 
					    salt = os.urandom(64)
 | 
				
			||||||
    key = _generate_key(salt, master_pass.encode("utf-8"))
 | 
					    cipher = Cipher.generate_cipher(salt, master_pass.encode("utf-8"))
 | 
				
			||||||
    f = Fernet(key)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    enc_login = base64.urlsafe_b64decode(
 | 
					    enc_login = cipher.encrypt(account.login.encode("utf-8"))
 | 
				
			||||||
        f.encrypt(account.login.encode("utf-8")),
 | 
					    enc_password = cipher.encrypt(account.password.encode("utf-8"))
 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    enc_password = base64.urlsafe_b64decode(
 | 
					 | 
				
			||||||
        f.encrypt(account.password.encode("utf-8")),
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return Account(
 | 
					    return Account(
 | 
				
			||||||
        user_id=account.user_id,
 | 
					        user_id=account.user_id,
 | 
				
			||||||
@@ -54,15 +66,10 @@ def decrypt(
 | 
				
			|||||||
) -> DecryptedAccount:
 | 
					) -> DecryptedAccount:
 | 
				
			||||||
    """Decrypts account using master password and returns
 | 
					    """Decrypts account using master password and returns
 | 
				
			||||||
    DecryptedAccount object"""
 | 
					    DecryptedAccount object"""
 | 
				
			||||||
    key = _generate_key(account.salt, master_pass.encode("utf-8"))
 | 
					    cipher = Cipher.generate_cipher(account.salt, master_pass.encode("utf-8"))
 | 
				
			||||||
    f = Fernet(key)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    login = f.decrypt(
 | 
					    login = cipher.decrypt(account.enc_login).decode("utf-8")
 | 
				
			||||||
        base64.urlsafe_b64encode(account.enc_login),
 | 
					    password = cipher.decrypt(account.enc_password).decode("utf-8")
 | 
				
			||||||
    ).decode("utf-8")
 | 
					 | 
				
			||||||
    password = f.decrypt(
 | 
					 | 
				
			||||||
        base64.urlsafe_b64encode(account.enc_password),
 | 
					 | 
				
			||||||
    ).decode("utf-8")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return DecryptedAccount(
 | 
					    return DecryptedAccount(
 | 
				
			||||||
        user_id=account.user_id,
 | 
					        user_id=account.user_id,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user