UnderCover/undercover/db.py

191 lines
5.3 KiB
Python

import os
import sys
import threading
import types
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional
from uuid import uuid4, UUID

import bcrypt
import psycopg
@dataclass
class Letter:
    """A user's saved letter, as read from the ``letter_data`` table."""

    id: int  # primary key of the letter row
    title: str  # value of the ``letter_name`` column
    contents: str  # value of the ``letter_data`` column
@dataclass
class User:
    """Public view of an account: its id and email, never the password hash."""

    id: int  # primary key of the ``users`` row
    email: str  # email address the account was looked up by
@dataclass
class UserWithHash:
    """Account record including its bcrypt password hash (internal use only).

    The hash is excluded from ``repr()`` so that printing or logging the record
    does not leak credential material; it still takes part in ``==``.
    """

    id: int  # primary key of the ``users`` row
    email: str  # email address the record was looked up with
    password_hash: str = field(repr=False)  # bcrypt hash, stored as text
# Connection settings come from the environment; each is None when unset.
host = os.environ.get('UNDERCOVER_POSTGRES_HOST')
db_name = os.environ.get('UNDERCOVER_POSTGRES_DBNAME')
port = os.environ.get('UNDERCOVER_POSTGRES_PORT')
db_user = os.environ.get('UNDERCOVER_POSTGRES_USER')
# Truthy only when all five settings are non-empty. This is the value of the last
# operand evaluated (e.g. the password string, or None), not a strict bool.
db_available = host and db_name and port and db_user and os.environ.get('UNDERCOVER_POSTGRES_PASSWORD')

if not db_available:
    # Not configured: warn at import time and hand out a no-op stand-in connection.
    sys.stderr.write('Database login not configured: DB access is disabled.\n')
    sys.stderr.write(' To enable, ensure UNDERCOVER_POSTGRES_{HOST,DBNAME,PORT,USER,PASSWORD} are set.\n')

    def connect():
        """Return a ``MockConnection`` because database access is disabled."""
        return MockConnection()
else:
    def connect():
        """Open a new psycopg connection from the environment-provided settings.

        The password is re-read from the environment on every call.
        """
        return psycopg.connect(
            host=host,
            dbname=db_name,
            port=port,
            user=db_user,
            password=os.environ.get('UNDERCOVER_POSTGRES_PASSWORD'),
        )
class MockConnection:
    """Stand-in for a database connection, used when no database is configured.

    Every query is accepted and ignored: ``execute`` returns an empty tuple,
    ``fetchone`` returns None and ``fetchall`` returns an empty list, so callers
    see "no rows" and writes are silently dropped.
    """

    # A single stateless cursor object, shared by every instance.
    mock_cursor = types.SimpleNamespace(
        execute=lambda *a: (),
        fetchone=lambda *a: None,
        fetchall=lambda *a: [],
    )

    def __enter__(self, *a):
        return self

    def __exit__(self, *a):
        # Falsy return: exceptions raised inside the ``with`` body propagate.
        return None

    def cursor(self):
        return self.mock_cursor

    def commit(self, *a):
        return None
def login(user_email: str, password: str):
    """Check ``password`` against the stored bcrypt hash for ``user_email``.

    :return: True when the account exists and the password matches, else False.
    """
    candidate = password.encode('utf-8')
    account = __get_user(user_email)
    if not account:
        return False
    stored = account.password_hash.encode('utf-8')
    return bcrypt.checkpw(candidate, stored)
def __gen_pw_hash(password: str):
    """Hash ``password`` with a freshly generated bcrypt salt; return the hash as text."""
    hashed = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    return hashed.decode('utf-8')
def add_user(username: str, password: str):
    """Create an account; ``username`` is written to the ``email`` column.

    Only the bcrypt hash of ``password`` is stored, never the plaintext.
    """
    hashed = __gen_pw_hash(password)
    with connect() as conn:
        conn.cursor().execute(
            "INSERT INTO users(email, password) VALUES (%s, %s)",
            (username, hashed),
        )
        conn.commit()
def delete_user(username: str):
    """Delete the account whose ``email`` column equals ``username``.

    NOTE(review): this is an exact ``=`` match, whereas account lookup elsewhere in
    this module is case-insensitive — confirm that difference is intended.
    """
    params = (username,)
    with connect() as conn:
        conn.cursor().execute("DELETE FROM users WHERE email = %s", params)
        conn.commit()
def add_letter(user_id: int, letter_title: str, letter_content: str):
    """Insert a new letter owned by ``user_id`` into ``letter_data``."""
    values = (user_id, letter_title, letter_content)
    with connect() as conn:
        conn.cursor().execute(
            "INSERT INTO letter_data(user_id, letter_name, letter_data) VALUES (%s, %s, %s)",
            values,
        )
        conn.commit()
def edit_letter(letter_id: int, letter_title: str, letter_content: str):
    """Overwrite the title and body of the letter whose id is ``letter_id``.

    NOTE(review): the UPDATE is keyed only on the letter id, not on its owner;
    presumably callers verify ownership first — confirm.
    """
    values = (letter_title, letter_content, letter_id)
    with connect() as conn:
        conn.cursor().execute(
            "UPDATE letter_data SET letter_name = %s, letter_data = %s WHERE id = %s",
            values,
        )
        conn.commit()
def get_user_letters(user_id: int) -> List[Letter]:
    """Return every letter owned by ``user_id`` (an empty list when there are none).

    :param user_id: id of the owning user, as stored in ``letter_data.user_id``.
    :return: ``Letter`` objects built from (id, letter_name, letter_data) rows.
    """
    with connect() as con:
        cur = con.cursor()
        # Bind the id as given (as add_letter does) rather than str(user_id), so the
        # parameter keeps the type the ``int`` annotation promises.
        cur.execute(
            "SELECT id, letter_name, letter_data FROM letter_data WHERE user_id = %s",
            (user_id,),
        )
        return [
            Letter(letter_id, title, contents)
            for letter_id, title, contents in cur.fetchall()
        ]
def get_user(email: str) -> Optional[User]:
    """Return the account for ``email`` with the password hash stripped, or None."""
    record = __get_user(email)
    if not record:
        return None
    return User(record.id, record.email)
def __get_user(email: str) -> Optional[UserWithHash]:
    """
    Fetch the account for ``email`` (compared case-insensitively), including its hash.

    :param email: address to look up. It is matched literally, so ``%``, ``_`` and
        ``\\`` are ordinary characters rather than pattern syntax.
    :return: ``UserWithHash`` holding the stored password hash and the ``email``
        argument as given, or None when no account matches.
    """
    with connect() as con:
        cur = con.cursor()
        # ILIKE interpreted the caller-supplied email as a pattern, so e.g. '%' matched
        # an arbitrary account and 'a_b@x' could match 'axb@x'. Comparing lower-cased
        # values for equality keeps the case-insensitive lookup without wildcards.
        cur.execute(
            "SELECT id, password FROM users WHERE lower(users.email) = lower(%s)",
            (email,),
        )
        row = cur.fetchone()
    if not row:
        return None
    user_id, password_hash = row
    return UserWithHash(user_id, email, password_hash)
# Lifetime of a password-reset token: 15 minutes. Deliberately negative, so that
# adding it to "now" yields the oldest reset_time that is still accepted.
RESET_TIME = timedelta(minutes=-15)
def initiate_password_reset(email: str) -> Optional[UUID]:
    """Create a password-reset token for the account registered under ``email``.

    A row is inserted into ``resets``, and a background timer deletes it again once
    the reset window has elapsed.

    :return: the new reset id, or None when no account matches ``email``.
    """
    user = get_user(email)
    if not user:
        return None
    reset_id = uuid4()
    with connect() as con:
        cur = con.cursor()
        cur.execute(
            "INSERT INTO resets(user_id, id, reset_time) VALUES (%s, %s, NOW())",
            (user.id, reset_id),
        )
        con.commit()
    # RESET_TIME is negative, and a Timer given a negative interval fires at once
    # (racing the INSERT above). Use its magnitude as the delay, and schedule the
    # cleanup only after the row has been committed.
    cleanup = threading.Timer(abs(RESET_TIME.total_seconds()), delete_reset_row, [reset_id])
    # Daemon thread, so a pending cleanup does not delay interpreter shutdown. A row
    # left behind that way is still rejected by complete_reset's age check.
    cleanup.daemon = True
    cleanup.start()
    return reset_id
def delete_reset_row(reset_id: UUID):
    """Remove the reset token ``reset_id`` from ``resets`` (harmless if already gone)."""
    params = (reset_id,)
    with connect() as conn:
        conn.cursor().execute("DELETE FROM resets WHERE id = %s", params)
        conn.commit()
def complete_reset(reset_id: str, new_password: str):
    """Consume a reset token and set the account's password to ``new_password``.

    :param reset_id: id of a row in ``resets`` (as returned by initiate_password_reset).
    :param new_password: plaintext password; only its bcrypt hash is stored.
    :return: True when the token existed and was still inside the reset window,
        False otherwise (unknown, already used, or expired token).
    """
    with connect() as con:
        cur = con.cursor()
        # Check the token's age against the database clock, the same clock whose NOW()
        # wrote reset_time. Comparing with Python's naive datetime.utcnow() is wrong
        # when the DB timezone is not UTC, and raises TypeError for aware timestamps.
        # RESET_TIME is negative, so subtract its (positive) negation; that keeps the
        # interval free of a day component. DELETE ... RETURNING consumes the token
        # atomically, so it is single-use even under concurrent requests.
        cur.execute(
            "DELETE FROM resets WHERE id = %s AND reset_time > NOW() - %s RETURNING user_id",
            (reset_id, -RESET_TIME),
        )
        row = cur.fetchone()
        if not row:
            return False
        (user_id,) = row
        password_hash = __gen_pw_hash(new_password)
        cur.execute("UPDATE users SET password = %s WHERE id = %s", (password_hash, user_id))
        con.commit()
        return True