from __future__ import annotations import tinysql import typing from datetime import datetime, timezone from urllib.parse import urlparse from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize from .. import logger as logging from ..misc import get_app if typing.TYPE_CHECKING: from collections.abc import Iterator from tinysql import Cursor, Row from typing import Any from .application import Application from ..misc import Message RELAY_SOFTWARE = [ 'activityrelay', # https://git.pleroma.social/pleroma/relay 'activity-relay', # https://github.com/yukimochi/Activity-Relay 'aoderelay', # https://git.asonix.dog/asonix/relay 'feditools-relay' # https://git.ptzo.gdn/feditools/relay ] class Connection(tinysql.Connection): @property def app(self) -> Application: return get_app() def distill_inboxes(self, message: Message) -> Iterator[str]: src_domains = { message.domain, urlparse(message.object_id).netloc } for inbox in self.execute('SELECT * FROM inboxes'): if inbox['domain'] not in src_domains: yield inbox['inbox'] def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor: return self.execute(self.database.prepared_statements[name], params) def get_config(self, key: str) -> Any: if key not in CONFIG_DEFAULTS: raise KeyError(key) with self.exec_statement('get-config', {'key': key}) as cur: if not (row := cur.one()): return get_default_value(key) if row['value']: return deserialize(row['key'], row['value']) return None def get_config_all(self) -> dict[str, Any]: with self.exec_statement('get-config-all') as cur: db_config = {row['key']: row['value'] for row in cur} config = {} for key, data in CONFIG_DEFAULTS.items(): try: config[key] = deserialize(key, db_config[key]) except KeyError: if key == 'schema-version': config[key] = 0 else: config[key] = data[1] return config def put_config(self, key: str, value: Any) -> Any: if key not in CONFIG_DEFAULTS: raise KeyError(key) if key == 'private-key': self.app.signer = value elif key == 'log-level': value = logging.LogLevel.parse(value) logging.set_level(value) params = { 'key': key, 'value': serialize(key, value) if value is not None else None, 'type': get_default_type(key) } with self.exec_statement('put-config', params): return value def get_inbox(self, value: str) -> Row: with self.exec_statement('get-inbox', {'value': value}) as cur: return cur.one() def put_inbox(self, domain: str, inbox: str, actor: str | None = None, followid: str | None = None, software: str | None = None) -> Row: params = { 'domain': domain, 'inbox': inbox, 'actor': actor, 'followid': followid, 'software': software, 'created': datetime.now(tz = timezone.utc) } with self.exec_statement('put-inbox', params) as cur: return cur.one() def update_inbox(self, inbox: str, actor: str | None = None, followid: str | None = None, software: str | None = None) -> Row: if not (actor or followid or software): raise ValueError('Missing "actor", "followid", and/or "software"') data = {} if actor: data['actor'] = actor if followid: data['followid'] = followid if software: data['software'] = software statement = tinysql.Update('inboxes', data, inbox = inbox) with self.query(statement): return self.get_inbox(inbox) def del_inbox(self, value: str) -> bool: with self.exec_statement('del-inbox', {'value': value}) as cur: if cur.modified_row_count > 1: raise ValueError('More than one row was modified') return cur.modified_row_count == 1 def get_domain_ban(self, domain: str) -> Row: if domain.startswith('http'): domain = urlparse(domain).netloc with self.exec_statement('get-domain-ban', {'domain': domain}) as cur: return cur.one() def put_domain_ban(self, domain: str, reason: str | None = None, note: str | None = None) -> Row: params = { 'domain': domain, 'reason': reason, 'note': note, 'created': datetime.now(tz = timezone.utc) } with self.exec_statement('put-domain-ban', params) as cur: return cur.one() def update_domain_ban(self, domain: str, reason: str | None = None, note: str | None = None) -> tinysql.Row: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') params = {} if reason: params['reason'] = reason if note: params['note'] = note statement = tinysql.Update('domain_bans', params, domain = domain) with self.query(statement) as cur: if cur.modified_row_count > 1: raise ValueError('More than one row was modified') return self.get_domain_ban(domain) def del_domain_ban(self, domain: str) -> bool: with self.exec_statement('del-domain-ban', {'domain': domain}) as cur: if cur.modified_row_count > 1: raise ValueError('More than one row was modified') return cur.modified_row_count == 1 def get_software_ban(self, name: str) -> Row: with self.exec_statement('get-software-ban', {'name': name}) as cur: return cur.one() def put_software_ban(self, name: str, reason: str | None = None, note: str | None = None) -> Row: params = { 'name': name, 'reason': reason, 'note': note, 'created': datetime.now(tz = timezone.utc) } with self.exec_statement('put-software-ban', params) as cur: return cur.one() def update_software_ban(self, name: str, reason: str | None = None, note: str | None = None) -> tinysql.Row: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') params = {} if reason: params['reason'] = reason if note: params['note'] = note statement = tinysql.Update('software_bans', params, name = name) with self.query(statement) as cur: if cur.modified_row_count > 1: raise ValueError('More than one row was modified') return self.get_software_ban(name) def del_software_ban(self, name: str) -> bool: with self.exec_statement('del-software-ban', {'name': name}) as cur: if cur.modified_row_count > 1: raise ValueError('More than one row was modified') return cur.modified_row_count == 1 def get_domain_whitelist(self, domain: str) -> Row: with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur: return cur.one() def put_domain_whitelist(self, domain: str) -> Row: params = { 'domain': domain, 'created': datetime.now(tz = timezone.utc) } with self.exec_statement('put-domain-whitelist', params) as cur: return cur.one() def del_domain_whitelist(self, domain: str) -> bool: with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur: if cur.modified_row_count > 1: raise ValueError('More than one row was modified') return cur.modified_row_count == 1