from __future__ import annotations import secrets from argon2 import PasswordHasher from blib import Date, convert_to_boolean from bsql import BackendType, Connection as SqlConnection, Row, Update from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from . import schema from .config import ( THEMES, ConfigData ) from .. import logger as logging from ..misc import Message, get_app if TYPE_CHECKING: from ..application import Application class Connection(SqlConnection): hasher = PasswordHasher( encoding = "utf-8" ) @property def app(self) -> Application: return get_app() def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]: src_domains = { message.domain, urlparse(message.object_id).netloc } for instance in self.get_inboxes(): if instance.domain not in src_domains: yield instance def drop_tables(self) -> None: with self.cursor() as cur: for table in self.get_tables(): query = f"DROP TABLE IF EXISTS {table}" if self.database.backend.backend_type == BackendType.POSTGRESQL: query += " CASCADE" cur.execute(query) def fix_timestamps(self) -> None: for app in self.select("apps").all(schema.App): data = {"created": app.created.timestamp(), "accessed": app.accessed.timestamp()} self.update("apps", data, client_id = app.client_id) for item in self.select("cache"): data = {"updated": Date.parse(item["updated"]).timestamp()} self.update("cache", data, id = item["id"]) for dban in self.select("domain_bans").all(schema.DomainBan): data = {"created": dban.created.timestamp()} self.update("domain_bans", data, domain = dban.domain) for instance in self.select("inboxes").all(schema.Instance): data = {"created": instance.created.timestamp()} self.update("inboxes", data, domain = instance.domain) for sban in self.select("software_bans").all(schema.SoftwareBan): data = {"created": sban.created.timestamp()} self.update("software_bans", data, name = sban.name) for user in self.select("users").all(schema.User): data = {"created": user.created.timestamp()} self.update("users", data, username = user.username) for wlist in self.select("whitelist").all(schema.Whitelist): data = {"created": wlist.created.timestamp()} self.update("whitelist", data, domain = wlist.domain) def get_config(self, key: str) -> Any: key = key.replace("_", "-") with self.run("get-config", {"key": key}) as cur: if (row := cur.one(Row)) is None: return ConfigData.DEFAULT(key) data = ConfigData() data.set(row["key"], row["value"]) return data.get(key) def get_config_all(self) -> ConfigData: rows = tuple(self.run("get-config-all", None).all(schema.Row)) return ConfigData.from_rows(rows) def put_config(self, key: str, value: Any) -> Any: field = ConfigData.FIELD(key) match field.name: case "private_key": self.app.signer = value case "log_level": value = logging.LogLevel.parse(value) logging.set_level(value) self.app["workers"].set_log_level(value) case "approval_required": value = convert_to_boolean(value) case "whitelist_enabled": value = convert_to_boolean(value) case "theme": if value not in THEMES: raise ValueError(f"\"{value}\" is not a valid theme") data = ConfigData() data.set(key, value) params = { "key": key, "value": data.get(key, serialize = True), "type": "LogLevel" if field.type == "logging.LogLevel" else field.type } with self.run("put-config", params): pass return data.get(key) def get_inbox(self, value: str) -> schema.Instance | None: with self.run("get-inbox", {"value": value}) as cur: return cur.one(schema.Instance) def get_inboxes(self) -> Iterator[schema.Instance]: return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance) # todo: check if software is different than stored row def put_inbox(self, # noqa: E301 domain: str, inbox: str | None = None, actor: str | None = None, followid: str | None = None, software: str | None = None, accepted: bool = True) -> schema.Instance: params: dict[str, Any] = { "inbox": inbox, "actor": actor, "followid": followid, "software": software, "accepted": accepted } if self.get_inbox(domain) is None: if not inbox: raise ValueError("Missing inbox") params["domain"] = domain params["created"] = datetime.now(tz = timezone.utc) with self.run("put-inbox", params) as cur: if (row := cur.one(schema.Instance)) is None: raise RuntimeError(f"Failed to insert instance: {domain}") return row for key, value in tuple(params.items()): if value is None: del params[key] with self.update("inboxes", params, domain = domain) as cur: if (row := cur.one(schema.Instance)) is None: raise RuntimeError(f"Failed to update instance: {domain}") return row def del_inbox(self, value: str) -> bool: with self.run("del-inbox", {"value": value}) as cur: if cur.row_count > 1: raise ValueError("More than one row was modified") return cur.row_count == 1 def get_request(self, domain: str) -> schema.Instance | None: with self.run("get-request", {"domain": domain}) as cur: return cur.one(schema.Instance) def get_requests(self) -> Iterator[schema.Instance]: return self.execute("SELECT * FROM inboxes WHERE accepted = false").all(schema.Instance) def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: if (instance := self.get_request(domain)) is None: raise KeyError(domain) if not accepted: if not self.del_inbox(domain): raise RuntimeError(f"Failed to delete request: {domain}") return instance params = { "domain": domain, "accepted": accepted } with self.run("put-inbox-accept", params) as cur: if (row := cur.one(schema.Instance)) is None: raise RuntimeError(f"Failed to insert response for domain: {domain}") return row def get_user(self, value: str) -> schema.User | None: with self.run("get-user", {"value": value}) as cur: return cur.one(schema.User) def get_user_by_token(self, token: str) -> schema.User | None: with self.run("get-user-by-token", {"token": token}) as cur: return cur.one(schema.User) def get_users(self) -> Iterator[schema.User]: return self.execute("SELECT * FROM users").all(schema.User) def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User: if self.get_user(username) is not None: data: dict[str, str | datetime | None] = {} if password: data["hash"] = self.hasher.hash(password) if handle: data["handle"] = handle stmt = Update("users", data) stmt.set_where("username", username) with self.query(stmt) as cur: if (row := cur.one(schema.User)) is None: raise RuntimeError(f"Failed to update user: {username}") return row if password is None: raise ValueError("Password cannot be empty") data = { "username": username, "hash": self.hasher.hash(password), "handle": handle, "created": datetime.now(tz = timezone.utc) } with self.run("put-user", data) as cur: if (row := cur.one(schema.User)) is None: raise RuntimeError(f"Failed to insert user: {username}") return row def del_user(self, username: str) -> None: if (user := self.get_user(username)) is None: raise KeyError(username) with self.run("del-token-user", {"username": user.username}): pass with self.run("del-user", {"username": user.username}): pass def get_app(self, client_id: str, client_secret: str, token: str | None = None) -> schema.App | None: params = { "id": client_id, "secret": client_secret } if token is not None: command = "get-app-with-token" params["token"] = token else: command = "get-app" with self.run(command, params) as cur: return cur.one(schema.App) def get_app_by_token(self, token: str) -> schema.App | None: with self.run("get-app-by-token", {"token": token}) as cur: return cur.one(schema.App) def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App: params = { "name": name, "redirect_uri": redirect_uri, "website": website, "client_id": secrets.token_hex(20), "client_secret": secrets.token_hex(20), "created": Date.new_utc(), "accessed": Date.new_utc() } with self.insert("apps", params) as cur: if (row := cur.one(schema.App)) is None: raise RuntimeError(f"Failed to insert app: {name}") return row def put_app_login(self, user: schema.User) -> schema.App: params = { "name": "Web", "redirect_uri": "urn:ietf:wg:oauth:2.0:oob", "website": None, "user": user.username, "client_id": secrets.token_hex(20), "client_secret": secrets.token_hex(20), "auth_code": None, "token": secrets.token_hex(20), "created": Date.new_utc(), "accessed": Date.new_utc() } with self.insert("apps", params) as cur: if (row := cur.one(schema.App)) is None: raise RuntimeError(f"Failed to create app for \"{user.username}\"") return row def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App: data: dict[str, str | None] = {} if user is not None: data["user"] = user.username if set_auth: data["auth_code"] = secrets.token_hex(20) else: data["token"] = secrets.token_hex(20) data["auth_code"] = None params = { "client_id": app.client_id, "client_secret": app.client_secret } with self.update("apps", data, **params) as cur: # type: ignore[arg-type] if (row := cur.one(schema.App)) is None: raise RuntimeError("Failed to update row") return row def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool: params = { "id": client_id, "secret": client_secret } if token is not None: command = "del-app-with-token" params["token"] = token else: command = "del-app" with self.run(command, params) as cur: if cur.row_count > 1: raise RuntimeError("More than 1 row was deleted") return cur.row_count == 0 def get_domain_ban(self, domain: str) -> schema.DomainBan | None: if domain.startswith("http"): domain = urlparse(domain).netloc with self.run("get-domain-ban", {"domain": domain}) as cur: return cur.one(schema.DomainBan) def get_domain_bans(self) -> Iterator[schema.DomainBan]: return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan) def put_domain_ban(self, domain: str, reason: str | None = None, note: str | None = None) -> schema.DomainBan: params = { "domain": domain, "reason": reason, "note": note, "created": datetime.now(tz = timezone.utc) } with self.run("put-domain-ban", params) as cur: if (row := cur.one(schema.DomainBan)) is None: raise RuntimeError(f"Failed to insert domain ban: {domain}") return row def update_domain_ban(self, domain: str, reason: str | None = None, note: str | None = None) -> schema.DomainBan: if not (reason or note): raise ValueError("\"reason\" and/or \"note\" must be specified") params = {} if reason is not None: params["reason"] = reason if note is not None: params["note"] = note statement = Update("domain_bans", params) statement.set_where("domain", domain) with self.query(statement) as cur: if cur.row_count > 1: raise ValueError("More than one row was modified") if (row := cur.one(schema.DomainBan)) is None: raise RuntimeError(f"Failed to update domain ban: {domain}") return row def del_domain_ban(self, domain: str) -> bool: with self.run("del-domain-ban", {"domain": domain}) as cur: if cur.row_count > 1: raise ValueError("More than one row was modified") return cur.row_count == 1 def get_software_ban(self, name: str) -> schema.SoftwareBan | None: with self.run("get-software-ban", {"name": name}) as cur: return cur.one(schema.SoftwareBan) def get_software_bans(self) -> Iterator[schema.SoftwareBan,]: return self.execute("SELECT * FROM software_bans").all(schema.SoftwareBan) def put_software_ban(self, name: str, reason: str | None = None, note: str | None = None) -> schema.SoftwareBan: params = { "name": name, "reason": reason, "note": note, "created": datetime.now(tz = timezone.utc) } with self.run("put-software-ban", params) as cur: if (row := cur.one(schema.SoftwareBan)) is None: raise RuntimeError(f"Failed to insert software ban: {name}") return row def update_software_ban(self, name: str, reason: str | None = None, note: str | None = None) -> schema.SoftwareBan: if not (reason or note): raise ValueError("\"reason\" and/or \"note\" must be specified") params = {} if reason is not None: params["reason"] = reason if note is not None: params["note"] = note statement = Update("software_bans", params) statement.set_where("name", name) with self.query(statement) as cur: if cur.row_count > 1: raise ValueError("More than one row was modified") if (row := cur.one(schema.SoftwareBan)) is None: raise RuntimeError(f"Failed to update software ban: {name}") return row def del_software_ban(self, name: str) -> bool: with self.run("del-software-ban", {"name": name}) as cur: if cur.row_count > 1: raise ValueError("More than one row was modified") return cur.row_count == 1 def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: with self.run("get-domain-whitelist", {"domain": domain}) as cur: return cur.one() def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]: return self.execute("SELECT * FROM whitelist").all(schema.Whitelist) def put_domain_whitelist(self, domain: str) -> schema.Whitelist: params = { "domain": domain, "created": datetime.now(tz = timezone.utc) } with self.run("put-domain-whitelist", params) as cur: if (row := cur.one(schema.Whitelist)) is None: raise RuntimeError(f"Failed to insert whitelisted domain: {domain}") return row def del_domain_whitelist(self, domain: str) -> bool: with self.run("del-domain-whitelist", {"domain": domain}) as cur: if cur.row_count > 1: raise ValueError("More than one row was modified") return cur.row_count == 1