diff --git a/undercover/db.py b/undercover/db.py index d69d519..78c27e5 100644 --- a/undercover/db.py +++ b/undercover/db.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from typing import Optional import bcrypt import psycopg @@ -43,7 +44,9 @@ def connected(action): def login(user_email: str, password: str): pw_bytes: bytes = password.encode('utf-8') user = __get_user(user_email) - return bcrypt.checkpw(pw_bytes, user.password_hash.encode('utf-8')) + if user: + return bcrypt.checkpw(pw_bytes, user.password_hash.encode('utf-8')) + return False def add_user(username: str, password: str): @@ -95,12 +98,14 @@ def get_user_letters(user_id: int) -> [Letter]: return list(map(lambda row: Letter(row[0], row[1], row[2]), cur.fetchall())) -def get_user(email: str) -> User: +def get_user(email: str) -> Optional[User]: user = __get_user(email) - return User(user.id, user.email) + if user: + return User(user.id, user.email) + return None -def __get_user(email: str) -> UserWithHash: +def __get_user(email: str) -> Optional[UserWithHash]: """ :param email: :return: User without their password_hash @@ -108,8 +113,11 @@ def __get_user(email: str) -> UserWithHash: with connect() as con: cur = con.cursor() cur.execute("SELECT id, password FROM users WHERE users.email = %s", (email,)) - user_id, password = cur.fetchone() - return UserWithHash(user_id, email, password) + row = cur.fetchone() + if row: + user_id, password = row + return UserWithHash(user_id, email, password) + return None def get_users() -> [UserWithHash]: