UnderCover/undercover/db.py

177 lines
5.1 KiB
Python

import os
import sys
import threading
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional
from uuid import uuid4, UUID
import bcrypt
import psycopg
from undercover.fallback import MockConnection
@dataclass
class Letter:
    """A cover letter as read from the ``letter_data`` table.

    :ivar id: Primary key of the row.
    :ivar title: Value of the ``letter_name`` column.
    :ivar contents: Value of the ``letter_data`` column.
    """

    id: int
    title: str
    contents: str
@dataclass
class User:
    """An account as loaded from the ``users`` table (no password hash).

    :ivar id: Primary key of the user row.
    :ivar email: Email address associated with the account.
    :ivar tier: Numeric subscription tier.
    """

    id: int
    email: str
    tier: int

    def in_free_tier(self) -> bool:
        """Return ``True`` when the account's tier is 1, which is treated as the free tier."""
        is_free = 1 == self.tier
        return is_free
@dataclass
class UserWithHash:
    """A ``User`` bundled with the stored bcrypt hash from ``users.password``.

    Kept separate from ``User`` so the hash is only carried where it is
    needed (password verification).

    :ivar user: The account itself.
    :ivar password_hash: bcrypt hash string as stored in the database.
    """

    user: User
    password_hash: str
# Connection settings are read once, at import time. The password is not
# kept in a module global; the real connect() re-reads it on every call.
host = os.environ.get('UNDERCOVER_POSTGRES_HOST')
db_name = os.environ.get('UNDERCOVER_POSTGRES_DBNAME')
port = os.environ.get('UNDERCOVER_POSTGRES_PORT')
db_user = os.environ.get('UNDERCOVER_POSTGRES_USER')
# Falsy when any of the five settings is unset or empty; otherwise this holds
# the last operand of the `and` chain (the password string).
db_available = host and db_name and port and db_user and os.environ.get('UNDERCOVER_POSTGRES_PASSWORD')

if not db_available:
    sys.stderr.write('Database login not configured: DB access is disabled.\n')
    sys.stderr.write(' To enable, ensure UNDERCOVER_POSTGRES_{HOST,DBNAME,PORT,USER,PASSWORD} are set.\n')

    def connect() -> object:
        """Return a ``MockConnection`` stand-in because no database is configured."""
        return MockConnection()
else:
    def connect() -> psycopg.Connection:
        """Open a new PostgreSQL connection using the UNDERCOVER_POSTGRES_* settings."""
        password = os.environ.get('UNDERCOVER_POSTGRES_PASSWORD')
        return psycopg.connect(host=host, dbname=db_name, port=port, user=db_user, password=password)
def login(user_email: str, password: str) -> bool:
    """Check a plain-text password against the stored bcrypt hash for an account.

    :param user_email: Email used to look the account up.
    :param password: Candidate password in plain text.
    :return: ``True`` if the account exists and the password matches, else ``False``.
    """
    candidate = password.encode('utf-8')
    record = __get_user_with_hash(user_email)
    if not record:
        return False
    stored_hash = record.password_hash.encode('utf-8')
    return bcrypt.checkpw(candidate, stored_hash)
def __gen_pw_hash(password: str) -> str:
    """Hash a plain-text password with bcrypt and a freshly generated salt.

    :param password: Password to hash; encoded as UTF-8 first.
    :return: The bcrypt hash decoded to ``str`` for storage in ``users.password``.
    """
    hashed = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    return hashed.decode('utf-8')
def add_user(username: str, password: str) -> None:
    """Insert a new account into ``users``, storing only a bcrypt hash of the password.

    :param username: Written to the ``email`` column.
    :param password: Plain-text password; it is hashed before storage.
    """
    hashed = __gen_pw_hash(password)
    insert_sql = "INSERT INTO users(email, password) VALUES (%s, %s)"
    with connect() as con:
        cursor = con.cursor()
        cursor.execute(insert_sql, (username, hashed))
        con.commit()
def delete_user(username: str) -> None:
    """Delete the ``users`` row(s) whose ``email`` equals *username*.

    NOTE(review): this uses ``=`` while lookups in ``__get_user_with_hash`` use
    case-insensitive ``ILIKE``; confirm the difference in matching is intended.

    :param username: Email of the account to remove.
    """
    delete_sql = "DELETE FROM users WHERE email = %s"
    with connect() as con:
        cursor = con.cursor()
        cursor.execute(delete_sql, (username,))
        con.commit()
def add_letter(user_id: int, letter_title: str, letter_content: str) -> None:
    """Store a new letter for a user in ``letter_data``.

    :param user_id: Owner of the letter.
    :param letter_title: Saved to ``letter_name``.
    :param letter_content: Saved to ``letter_data``.
    """
    insert_sql = "INSERT INTO letter_data(user_id, letter_name, letter_data) VALUES (%s, %s, %s)"
    values = (user_id, letter_title, letter_content)
    with connect() as con:
        cursor = con.cursor()
        cursor.execute(insert_sql, values)
        con.commit()
def edit_letter(letter_id: int, letter_title: str, letter_content: str) -> None:
    """Overwrite the title and body of an existing letter.

    NOTE(review): the row is selected by ``id`` alone, with no owner filter, so
    callers must verify that the current user owns *letter_id*.

    :param letter_id: Primary key of the letter to change.
    :param letter_title: New ``letter_name`` value.
    :param letter_content: New ``letter_data`` value.
    """
    update_sql = "UPDATE letter_data SET letter_name = %s, letter_data = %s WHERE id = %s"
    values = (letter_title, letter_content, letter_id)
    with connect() as con:
        cursor = con.cursor()
        cursor.execute(update_sql, values)
        con.commit()
def get_user_letters(user_id: int) -> list[Letter]:
    """Fetch every letter owned by a user.

    :param user_id: Owner whose letters to load (sent to the query as a string).
    :return: One ``Letter`` per matching ``letter_data`` row, in the order returned by the database.
    """
    with connect() as con:
        cursor = con.cursor()
        cursor.execute("SELECT id, letter_name, letter_data FROM letter_data WHERE user_id = %s", (str(user_id),))
        rows = cursor.fetchall()
        return [Letter(letter_id, name, body) for letter_id, name, body in rows]
def get_user(email: str) -> Optional[User]:
    """Look up an account by email, without exposing its password hash.

    :param email: Email address to search for.
    :return: The matching ``User``, or ``None`` if there is none.
    """
    found = __get_user_with_hash(email)
    return found.user if found else None
def __get_user_with_hash(email: str) -> Optional[UserWithHash]:
    """Look up an account by email (case-insensitively) together with its password hash.

    The email is matched literally, ignoring case. ``\\``, ``%`` and ``_`` are
    escaped so user input cannot act as ILIKE wildcards; otherwise an input
    such as ``%`` would match an arbitrary account.

    :param email: Email address to search for.
    :return: ``UserWithHash`` (user plus stored password hash) for the first
        matching row, or ``None``. The returned ``User.email`` is the *email*
        argument as given, not the stored value.
    """
    # Escape the backslash first so the escapes added for % and _ are not doubled.
    # Backslash is PostgreSQL's default LIKE/ILIKE escape character.
    pattern = email.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
    with connect() as con:
        cur = con.cursor()
        cur.execute("SELECT id, password, tier FROM users WHERE users.email ILIKE %s", (pattern,))
        row = cur.fetchone()
        if row:
            user_id, password, tier = row
            return UserWithHash(User(user_id, email, tier), password)
        return None
# Length of the password-reset window, stored as a *negative* offset
# (15 minutes back from "now").
RESET_TIME = -timedelta(minutes=15)
def initiate_password_reset(email: str) -> Optional[UUID]:
    """Create a password-reset token for the account registered under *email*.

    Inserts a row into ``resets`` keyed by a fresh UUID4 and stamped with the
    database's ``NOW()``. After the insert is committed, a background timer is
    started that deletes the row once the reset window (``abs(RESET_TIME)``)
    has elapsed.

    :param email: Address to look up via ``get_user``.
    :return: The new reset id, or ``None`` if no account matches *email*.
    """
    user = get_user(email)
    if not user:
        return None
    reset_id = uuid4()
    with connect() as con:
        cur = con.cursor()
        cur.execute(
            "INSERT INTO resets(user_id, id, reset_time) VALUES (%s, %s, NOW())",
            (user.id, reset_id)
        )
        con.commit()
    # RESET_TIME is a negative offset. A negative Timer interval fires
    # immediately, which deleted the reset row as soon as it was created, so
    # use its magnitude. The timer starts only after commit, so the cleanup
    # cannot run against a row that is not yet visible.
    cleanup = threading.Timer(abs(RESET_TIME.total_seconds()), delete_reset_row, [reset_id])
    # Daemon thread: a pending cleanup must not keep the process alive for up
    # to the whole window at shutdown. Row expiry is also checked when the
    # reset is completed.
    cleanup.daemon = True
    cleanup.start()
    return reset_id
def delete_reset_row(reset_id: UUID) -> None:
    """Remove a password-reset token from ``resets``.

    :param reset_id: Id of the reset row to delete (no error if it is already gone).
    """
    delete_sql = "DELETE FROM resets WHERE id = %s"
    with connect() as con:
        cursor = con.cursor()
        cursor.execute(delete_sql, (reset_id,))
        con.commit()
def complete_reset(reset_id: str, new_password: str) -> bool:
    """Redeem a password-reset token and set a new password for its account.

    The token is accepted only if its ``resets`` row exists and its
    ``reset_time`` is newer than ``NOW() - abs(RESET_TIME)``. A redeemed token
    is deleted in the same statement that validates it, so it can be used at
    most once.

    :param reset_id: Token previously returned by ``initiate_password_reset``.
    :param new_password: New plain-text password; it is stored as a bcrypt hash.
    :return: ``True`` if the password was changed, ``False`` if the token is
        unknown or expired.
    """
    with connect() as con:
        cur = con.cursor()
        # Validate and consume the token atomically. The expiry check uses the
        # database clock, the same clock that stamped reset_time via NOW(). This
        # avoids naive/aware datetime mismatches and session-timezone skew
        # against datetime.utcnow(). A NULL reset_time never satisfies '>'.
        cur.execute(
            "DELETE FROM resets WHERE id = %s AND reset_time > NOW() - %s RETURNING user_id",
            (reset_id, abs(RESET_TIME)),
        )
        row = cur.fetchone()
        if not row:
            return False
        (user_id,) = row
        password_hash = __gen_pw_hash(new_password)
        cur.execute("UPDATE users SET password = %s WHERE id = %s", (password_hash, user_id))
        con.commit()
        return True