177 lines
5.1 KiB
Python
177 lines
5.1 KiB
Python
|
import os
import sys
import threading
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4, UUID

import bcrypt
import psycopg

from app.fallback import MockConnection
|
||
|
|
||
|
|
||
|
@dataclass
class Letter:
    """One letter belonging to a user, as read from the ``letter_data`` table."""

    # ``letter_data.id`` of the row (what edit_letter() targets).
    id: int
    # Stored in the ``letter_name`` column.
    title: str
    # Stored in the ``letter_data`` column.
    contents: str
|
||
|
|
||
|
|
||
|
@dataclass
class User:
    """An account from the ``users`` table, without its password hash."""

    id: int
    email: str
    tier: int

    # Tier value that in_free_tier() recognises as the free tier.
    # (No annotation, so dataclass does not treat it as a field.)
    FREE_TIER = 1

    def in_free_tier(self) -> bool:
        """Report whether this account is on the free tier."""
        return self.tier == self.FREE_TIER
|
||
|
|
||
|
|
||
|
@dataclass
class UserWithHash:
    """A ``User`` together with the password hash stored for that account."""

    user: User
    # bcrypt hash text as stored in ``users.password``.
    password_hash: str
|
||
|
|
||
|
|
||
|
# PostgreSQL connection settings, read once at import time.
host = os.environ.get('UNDERCOVER_POSTGRES_HOST')
db_name = os.environ.get('UNDERCOVER_POSTGRES_DBNAME')
port = os.environ.get('UNDERCOVER_POSTGRES_PORT')
db_user = os.environ.get('UNDERCOVER_POSTGRES_USER')

# True only when every setting (including the password) is non-empty.
# Coerced to a real bool: a bare ``a and b and ...`` chain evaluates to its last
# operand, which would make this "flag" hold the database password itself.
db_available = all((host, db_name, port, db_user, os.environ.get('UNDERCOVER_POSTGRES_PASSWORD')))
|
||
|
|
||
|
if not db_available:
    # Without full credentials, warn once at import and fall back to a mock.
    sys.stderr.write('Database login not configured: DB access is disabled.\n')
    sys.stderr.write(' To enable, ensure UNDERCOVER_POSTGRES_{HOST,DBNAME,PORT,USER,PASSWORD} are set.\n')

    def connect() -> object:
        """Return a ``MockConnection`` stand-in because no database is configured."""
        return MockConnection()
else:
    def connect() -> psycopg.Connection:
        """Open a new psycopg connection from the UNDERCOVER_POSTGRES_* settings.

        The password is looked up from the environment on every call rather
        than being kept in a module-level variable.
        """
        settings = {
            'host': host,
            'dbname': db_name,
            'port': port,
            'user': db_user,
            'password': os.environ.get('UNDERCOVER_POSTGRES_PASSWORD'),
        }
        return psycopg.connect(**settings)
|
||
|
|
||
|
|
||
|
def login(user_email: str, password: str) -> bool:
    """Verify ``password`` against the bcrypt hash stored for ``user_email``.

    :param user_email: email used to look the account up.
    :param password: plaintext password to check.
    :return: True on a matching password; False if no account matches or the
        password is wrong.
    """
    candidate = password.encode('utf-8')
    record = __get_user_with_hash(user_email)
    if record is None:
        return False
    stored_hash = record.password_hash.encode('utf-8')
    return bcrypt.checkpw(candidate, stored_hash)
|
||
|
|
||
|
|
||
|
def __gen_pw_hash(password: str) -> str:
    """Hash ``password`` with bcrypt under a fresh salt and return it as text."""
    hashed = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    return hashed.decode('utf-8')
|
||
|
|
||
|
|
||
|
def add_user(username: str, password: str) -> None:
    """Create an account; ``username`` is stored in the ``email`` column.

    The password is stored only as a bcrypt hash.
    """
    hashed = __gen_pw_hash(password)
    with connect() as conn:
        cursor = conn.cursor()
        cursor.execute("INSERT INTO users(email, password) VALUES (%s, %s)", (username, hashed))
        conn.commit()
|
||
|
|
||
|
|
||
|
def delete_user(username: str) -> None:
    """Delete the account(s) whose ``email`` column equals ``username`` (``=`` comparison)."""
    params = (username,)
    with connect() as conn:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM users WHERE email = %s", params)
        conn.commit()
|
||
|
|
||
|
|
||
|
def add_letter(user_id: int, letter_title: str, letter_content: str) -> None:
    """Store a new letter owned by ``user_id``.

    The title goes into ``letter_name`` and the body into ``letter_data``.
    """
    row_values = (user_id, letter_title, letter_content)
    with connect() as conn:
        cursor = conn.cursor()
        cursor.execute("INSERT INTO letter_data(user_id, letter_name, letter_data) VALUES (%s, %s, %s)",
                       row_values)
        conn.commit()
|
||
|
|
||
|
|
||
|
def edit_letter(letter_id: int, letter_title: str, letter_content: str) -> None:
    """Overwrite the title and body of the letter whose id is ``letter_id``.

    NOTE(review): the update is keyed on the letter id only; any ownership check
    has to happen in the caller.
    """
    new_values = (letter_title, letter_content, letter_id)
    with connect() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE letter_data SET letter_name = %s, letter_data = %s WHERE id = %s",
                       new_values)
        conn.commit()
|
||
|
|
||
|
|
||
|
def get_user_letters(user_id: int) -> list[Letter]:
    """Return all letters owned by ``user_id`` (an empty list when there are none)."""
    with connect() as conn:
        cursor = conn.cursor()
        # The id is bound as a string rather than an int.
        cursor.execute("SELECT id, letter_name, letter_data FROM letter_data WHERE user_id = %s", (str(user_id),))
        rows = cursor.fetchall()
        return [Letter(row[0], row[1], row[2]) for row in rows]
|
||
|
|
||
|
|
||
|
def get_user(email: str) -> Optional[User]:
    """Look up the account for ``email``, leaving the password hash out.

    :return: the ``User``, or None when no account matches.
    """
    record = __get_user_with_hash(email)
    return record.user if record else None
|
||
|
|
||
|
|
||
|
def escape_like_pattern(value: str) -> str:
    """Escape LIKE/ILIKE metacharacters so ``value`` only matches itself.

    PostgreSQL treats ``%`` and ``_`` as wildcards in LIKE/ILIKE patterns and
    uses backslash as the default escape character. Backslash is escaped
    first so the escapes added for ``%`` and ``_`` are not doubled.
    """
    return value.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')


def __get_user_with_hash(email: str) -> Optional[UserWithHash]:
    """Fetch a user and their stored password hash by email, ignoring case.

    :param email: email to look up. It is matched literally (wildcards are
        escaped) but case-insensitively, via ILIKE.
    :return: the user paired with their password hash, or None when no row
        matches. ``User.email`` is set to the ``email`` argument as passed, not
        to the value stored in the database.
    """
    with connect() as con:
        cur = con.cursor()
        # ILIKE is used for case-insensitivity, but it is a pattern match: an
        # unescaped '_' or '%' in user input would match other accounts.
        cur.execute("SELECT id, password, tier FROM users WHERE users.email ILIKE %s",
                    (escape_like_pattern(email),))
        row = cur.fetchone()
        if row:
            user_id, password, tier = row
            return UserWithHash(User(user_id, email, tier), password)
    return None
|
||
|
|
||
|
|
||
|
# Password-reset validity window, stored as a NEGATIVE offset (-15 minutes):
# complete_reset() adds it to the current time to get the oldest reset_time it
# still accepts. Code that needs a positive delay must use its magnitude.
RESET_TIME = -timedelta(minutes=15)
|
||
|
|
||
|
|
||
|
def initiate_password_reset(email: str) -> Optional[UUID]:
    """Begin a password reset for the account with ``email``.

    Inserts a ``resets`` row stamped with the database's NOW() and schedules a
    background timer that deletes that row once the reset window has elapsed.

    :param email: email of the account to reset.
    :return: the new reset id, or None when no account has that email.
    """
    user = get_user(email)
    if not user:
        return None
    reset_id = uuid4()
    with connect() as con:
        cur = con.cursor()
        cur.execute(
            "INSERT INTO resets(user_id, id, reset_time) VALUES (%s, %s, NOW())",
            (user.id, reset_id)
        )
        con.commit()
    # RESET_TIME is a negative look-back window, and threading.Timer fires
    # immediately for a negative interval, so use the window's magnitude.
    # The timer is scheduled only after the insert has committed, so the
    # cleanup (which runs on its own connection) can actually see the row.
    cleanup = threading.Timer(abs(RESET_TIME.total_seconds()), delete_reset_row, [reset_id])
    # Daemon thread: a pending cleanup must not hold the interpreter open for the
    # whole window at shutdown. complete_reset() rejects expired ids on its own.
    cleanup.daemon = True
    cleanup.start()
    return reset_id
|
||
|
|
||
|
|
||
|
def delete_reset_row(reset_id: UUID) -> None:
    """Remove the ``resets`` row whose id is ``reset_id``, if it is still there."""
    with connect() as conn:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM resets WHERE id = %s", (reset_id,))
        conn.commit()
|
||
|
|
||
|
|
||
|
def complete_reset(reset_id: str, new_password: str) -> bool:
    """Finish a password reset by storing ``new_password`` if ``reset_id`` is valid.

    A reset is valid when a ``resets`` row with this id exists and its
    ``reset_time`` is newer than the window given by ``RESET_TIME``. On success,
    the reset row is deleted and the user's password hash is replaced in the
    same transaction.

    :param reset_id: id previously returned by ``initiate_password_reset``.
    :param new_password: plaintext password to hash and store.
    :return: True if the password was changed; False if the reset id is
        unknown or has expired.
    """
    with connect() as con:
        cur = con.cursor()
        cur.execute("SELECT reset_time, user_id FROM resets WHERE id = %s", (reset_id,))
        row = cur.fetchone()
        if not row or not row[0]:
            return False
        reset_time, user_id = row
        # Python refuses to order an aware datetime (what a timestamptz column
        # yields) against a naive one, so build "now" to match reset_time.
        # NOTE(review): naive values are treated as UTC, as before; if the
        # column is timestamp-without-time-zone, NOW() stores session-local
        # time, so confirm the DB session runs in UTC.
        now = datetime.now(timezone.utc)
        if reset_time.tzinfo is None:
            now = now.replace(tzinfo=None)
        # RESET_TIME is negative, so ``now + RESET_TIME`` is the oldest
        # reset_time that is still accepted.
        if reset_time > now + RESET_TIME:
            cur.execute("DELETE FROM resets WHERE id = %s", (reset_id,))
            password_hash = __gen_pw_hash(new_password)
            cur.execute("UPDATE users SET password = %s WHERE id = %s", (password_hash, user_id))
            con.commit()
            return True
    return False
|