From bdc7d41d7a61001cc7ac1d26dce7240582ea7a7f Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 25 Jun 2024 16:02:49 -0400 Subject: [PATCH] update barkshark-sql to 0.2.0-rc1 and create row classes --- pyproject.toml | 6 +- relay/__init__.py | 2 +- relay/application.py | 73 +-------------- relay/cache.py | 12 ++- relay/database/__init__.py | 2 + relay/database/connection.py | 172 +++++++++++++++++++++++------------ relay/database/schema.py | 131 +++++++++++++++----------- relay/http_client.py | 6 +- relay/manage.py | 131 +++++++++++++------------- relay/processors.py | 16 ++-- relay/views/activitypub.py | 14 +-- relay/views/api.py | 130 ++++++++++++++++---------- relay/views/frontend.py | 2 +- relay/workers.py | 6 +- 14 files changed, 374 insertions(+), 329 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f06de0..a3c9410 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.1.3-1", + "barkshark-lib >= 0.2.0-rc1", "barkshark-sql == 0.1.4-1", "click >= 8.1.2", "hiredis == 2.3.2", @@ -104,7 +104,3 @@ implicit_reexport = true [[tool.mypy.overrides]] module = "blib" implicit_reexport = true - -[[tool.mypy.overrides]] -module = "bsql" -implicit_reexport = true diff --git a/relay/__init__.py b/relay/__init__.py index 73e3bb4..80eb7f9 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1 +1 @@ -__version__ = '0.3.2' +__version__ = '0.3.3' diff --git a/relay/application.py b/relay/application.py index a3d9925..d852f29 100644 --- a/relay/application.py +++ b/relay/application.py @@ -4,30 +4,26 @@ import asyncio import multiprocessing import signal import time -import traceback from aiohttp import web -from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.web import StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer -from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from bsql import Database, Row +from bsql import Database from collections.abc import Awaitable, Callable from datetime import datetime, timedelta from mimetypes import guess_type from pathlib import Path -from queue import Empty from threading import Event, Thread from typing import Any -from urllib.parse import urlparse from . import logger as logging, workers from .cache import Cache, get_cache from .config import Config from .database import Connection, get_database +from .database.schema import Instance from .http_client import HttpClient -from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource +from .misc import Message, Response, check_open_port, get_resource from .template import Template from .views import VIEWS from .views.api import handle_api_path @@ -142,7 +138,7 @@ class Application(web.Application): return timedelta(seconds=uptime.seconds) - def push_message(self, inbox: str, message: Message, instance: Row) -> None: + def push_message(self, inbox: str, message: Message, instance: Instance) -> None: self['workers'].push_message(inbox, message, instance) @@ -286,67 +282,6 @@ class CacheCleanupThread(Thread): self.running.clear() -class PushWorker(multiprocessing.Process): - def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None: - if Application.DEFAULT is None: - raise RuntimeError('Application not setup yet') - - multiprocessing.Process.__init__(self) - - self.queue = queue - self.shutdown = multiprocessing.Event() - self.path = Application.DEFAULT.config.path - - - def stop(self) -> None: - self.shutdown.set() - - - def run(self) -> None: - asyncio.run(self.handle_queue()) - - - async def handle_queue(self) -> None: - if IS_WINDOWS: - app = Application(self.path) - client = app.client - - client.open() - app.database.connect() - app.cache.setup() - - else: - client = HttpClient() - client.open() - - while not self.shutdown.is_set(): - try: - inbox, message, instance = self.queue.get(block=True, timeout=0.1) - asyncio.create_task(client.post(inbox, message, instance)) - - except Empty: - await asyncio.sleep(0) - - except ClientSSLError as e: - logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e)) - - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.error( - 'Failed to connect to %s for message push: %s', - urlparse(inbox).netloc, str(e) - ) - - # make sure an exception doesn't bring down the worker - except Exception: - traceback.print_exc() - - if IS_WINDOWS: - app.database.disconnect() - app.cache.close() - - await client.close() - - @web.middleware async def handle_response_headers( request: web.Request, diff --git a/relay/cache.py b/relay/cache.py index e9f261b..da87cc5 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -4,7 +4,7 @@ import json import os from abc import ABC, abstractmethod -from bsql import Database +from bsql import Database, Row from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass from datetime import datetime, timedelta, timezone @@ -172,7 +172,7 @@ class SqlCache(Cache): with self._db.session(False) as conn: with conn.run('get-cache-item', params) as cur: - if not (row := cur.one()): + if not (row := cur.one(Row)): raise KeyError(f'{namespace}:{key}') row.pop('id', None) @@ -211,9 +211,11 @@ class SqlCache(Cache): with self._db.session(True) as conn: with conn.run('set-cache-item', params) as cur: - row = cur.one() - row.pop('id', None) # type: ignore[union-attr] - return Item.from_data(*tuple(row.values())) # type: ignore[union-attr] + if (row := cur.one(Row)) is None: + raise RuntimeError("Cache item not set") + + row.pop('id', None) + return Item.from_data(*tuple(row.values())) def delete(self, namespace: str, key: str) -> None: diff --git a/relay/database/__init__.py b/relay/database/__init__.py index becd456..545f822 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -16,6 +16,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]: 'tables': TABLES } + db: Database[Connection] + if config.db_type == 'sqlite': db = Database.sqlite(config.sqlite_path, **options) diff --git a/relay/database/connection.py b/relay/database/connection.py index 864ad27..006a907 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -2,12 +2,13 @@ from __future__ import annotations from argon2 import PasswordHasher from bsql import Connection as SqlConnection, Row, Update -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from uuid import uuid4 +from . import schema from .config import ( THEMES, ConfigData @@ -37,14 +38,14 @@ class Connection(SqlConnection): return get_app() - def distill_inboxes(self, message: Message) -> Iterator[Row]: + def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]: src_domains = { message.domain, urlparse(message.object_id).netloc } for instance in self.get_inboxes(): - if instance['domain'] not in src_domains: + if instance.domain not in src_domains: yield instance @@ -52,7 +53,7 @@ class Connection(SqlConnection): key = key.replace('_', '-') with self.run('get-config', {'key': key}) as cur: - if not (row := cur.one()): + if (row := cur.one(Row)) is None: return ConfigData.DEFAULT(key) data = ConfigData() @@ -61,8 +62,8 @@ class Connection(SqlConnection): def get_config_all(self) -> ConfigData: - with self.run('get-config-all', None) as cur: - return ConfigData.from_rows(tuple(cur.all())) + rows = tuple(self.run('get-config-all', None).all(schema.Row)) + return ConfigData.from_rows(rows) def put_config(self, key: str, value: Any) -> Any: @@ -99,14 +100,13 @@ class Connection(SqlConnection): return data.get(key) - def get_inbox(self, value: str) -> Row: + def get_inbox(self, value: str) -> schema.Instance | None: with self.run('get-inbox', {'value': value}) as cur: - return cur.one() # type: ignore + return cur.one(schema.Instance) - def get_inboxes(self) -> Sequence[Row]: - with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: - return tuple(cur.all()) + def get_inboxes(self) -> Iterator[schema.Instance]: + return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance) def put_inbox(self, @@ -115,7 +115,7 @@ class Connection(SqlConnection): actor: str | None = None, followid: str | None = None, software: str | None = None, - accepted: bool = True) -> Row: + accepted: bool = True) -> schema.Instance: params: dict[str, Any] = { 'inbox': inbox, @@ -125,7 +125,7 @@ class Connection(SqlConnection): 'accepted': accepted } - if not self.get_inbox(domain): + if self.get_inbox(domain) is None: if not inbox: raise ValueError("Missing inbox") @@ -133,14 +133,20 @@ class Connection(SqlConnection): params['created'] = datetime.now(tz = timezone.utc) with self.run('put-inbox', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to insert instance: {domain}") + + return row for key, value in tuple(params.items()): if value is None: del params[key] with self.update('inboxes', params, domain = domain) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to update instance: {domain}") + + return row def del_inbox(self, value: str) -> bool: @@ -151,24 +157,23 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_request(self, domain: str) -> Row: + def get_request(self, domain: str) -> schema.Instance | None: with self.run('get-request', {'domain': domain}) as cur: - if not (row := cur.one()): - raise KeyError(domain) - - return row + return cur.one(schema.Instance) - def get_requests(self) -> Sequence[Row]: - with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur: - return tuple(cur.all()) + def get_requests(self) -> Iterator[schema.Instance]: + return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance) - def put_request_response(self, domain: str, accepted: bool) -> Row: - instance = self.get_request(domain) + def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: + if (instance := self.get_request(domain)) is None: + raise KeyError(domain) if not accepted: - self.del_inbox(domain) + if not self.del_inbox(domain): + raise RuntimeError(f'Failed to delete request: {domain}') + return instance params = { @@ -177,21 +182,28 @@ class Connection(SqlConnection): } with self.run('put-inbox-accept', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to insert response for domain: {domain}") + + return row - def get_user(self, value: str) -> Row: + def get_user(self, value: str) -> schema.User | None: with self.run('get-user', {'value': value}) as cur: - return cur.one() # type: ignore + return cur.one(schema.User) - def get_user_by_token(self, code: str) -> Row: + def get_user_by_token(self, code: str) -> schema.User | None: with self.run('get-user-by-token', {'code': code}) as cur: - return cur.one() # type: ignore + return cur.one(schema.User) - def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: - if self.get_user(username): + def get_users(self) -> Iterator[schema.User]: + return self.execute("SELECT * FROM users").all(schema.User) + + + def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User: + if self.get_user(username) is not None: data: dict[str, str | datetime | None] = {} if password: @@ -204,7 +216,10 @@ class Connection(SqlConnection): stmt.set_where("username", username) with self.query(stmt) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.User)) is None: + raise RuntimeError(f"Failed to update user: {username}") + + return row if password is None: raise ValueError('Password cannot be empty') @@ -217,25 +232,36 @@ class Connection(SqlConnection): } with self.run('put-user', data) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.User)) is None: + raise RuntimeError(f"Failed to insert user: {username}") + + return row def del_user(self, username: str) -> None: - user = self.get_user(username) + if (user := self.get_user(username)) is None: + raise KeyError(username) - with self.run('del-user', {'value': user['username']}): + with self.run('del-user', {'value': user.username}): pass - with self.run('del-token-user', {'username': user['username']}): + with self.run('del-token-user', {'username': user.username}): pass - def get_token(self, code: str) -> Row: + def get_token(self, code: str) -> schema.Token | None: with self.run('get-token', {'code': code}) as cur: - return cur.one() # type: ignore + return cur.one(schema.Token) - def put_token(self, username: str) -> Row: + def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: + if username is not None: + return self.select('tokens').all(schema.Token) + + return self.select('tokens', username = username).all(schema.Token) + + + def put_token(self, username: str) -> schema.Token: data = { 'code': uuid4().hex, 'user': username, @@ -243,7 +269,10 @@ class Connection(SqlConnection): } with self.run('put-token', data) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Token)) is None: + raise RuntimeError(f"Failed to insert token for user: {username}") + + return row def del_token(self, code: str) -> None: @@ -251,18 +280,22 @@ class Connection(SqlConnection): pass - def get_domain_ban(self, domain: str) -> Row: + def get_domain_ban(self, domain: str) -> schema.DomainBan | None: if domain.startswith('http'): domain = urlparse(domain).netloc with self.run('get-domain-ban', {'domain': domain}) as cur: - return cur.one() # type: ignore + return cur.one(schema.DomainBan) + + + def get_domain_bans(self) -> Iterator[schema.DomainBan]: + return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan) def put_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.DomainBan: params = { 'domain': domain, @@ -272,13 +305,16 @@ class Connection(SqlConnection): } with self.run('put-domain-ban', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.DomainBan)) is None: + raise RuntimeError(f"Failed to insert domain ban: {domain}") + + return row def update_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.DomainBan: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -298,7 +334,10 @@ class Connection(SqlConnection): if cur.row_count > 1: raise ValueError('More than one row was modified') - return self.get_domain_ban(domain) + if (row := cur.one(schema.DomainBan)) is None: + raise RuntimeError(f"Failed to update domain ban: {domain}") + + return row def del_domain_ban(self, domain: str) -> bool: @@ -309,15 +348,19 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_software_ban(self, name: str) -> Row: + def get_software_ban(self, name: str) -> schema.SoftwareBan | None: with self.run('get-software-ban', {'name': name}) as cur: - return cur.one() # type: ignore + return cur.one(schema.SoftwareBan) + + + def get_software_bans(self) -> Iterator[schema.SoftwareBan,]: + return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan) def put_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.SoftwareBan: params = { 'name': name, @@ -327,13 +370,16 @@ class Connection(SqlConnection): } with self.run('put-software-ban', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.SoftwareBan)) is None: + raise RuntimeError(f'Failed to insert software ban: {name}') + + return row def update_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.SoftwareBan: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -353,7 +399,10 @@ class Connection(SqlConnection): if cur.row_count > 1: raise ValueError('More than one row was modified') - return self.get_software_ban(name) + if (row := cur.one(schema.SoftwareBan)) is None: + raise RuntimeError(f'Failed to update software ban: {name}') + + return row def del_software_ban(self, name: str) -> bool: @@ -364,19 +413,26 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_domain_whitelist(self, domain: str) -> Row: + def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: with self.run('get-domain-whitelist', {'domain': domain}) as cur: - return cur.one() # type: ignore + return cur.one() - def put_domain_whitelist(self, domain: str) -> Row: + def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]: + return self.execute("SELECT * FROM whitelist").all(schema.Whitelist) + + + def put_domain_whitelist(self, domain: str) -> schema.Whitelist: params = { 'domain': domain, 'created': datetime.now(tz = timezone.utc) } with self.run('put-domain-whitelist', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Whitelist)) is None: + raise RuntimeError(f'Failed to insert whitelisted domain: {domain}') + + return row def del_domain_whitelist(self, domain: str) -> bool: diff --git a/relay/database/schema.py b/relay/database/schema.py index 409ee57..1fd7003 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -1,61 +1,88 @@ -from bsql import Column, Table, Tables +from __future__ import annotations + +import typing + +from bsql import Column, Row, Tables from collections.abc import Callable +from datetime import datetime from .config import ConfigData -from .connection import Connection + +if typing.TYPE_CHECKING: + from .connection import Connection VERSIONS: dict[int, Callable[[Connection], None]] = {} -TABLES: Tables = Tables( - Table( - 'config', - Column('key', 'text', primary_key = True, unique = True, nullable = False), - Column('value', 'text'), - Column('type', 'text', default = 'str') - ), - Table( - 'inboxes', - Column('domain', 'text', primary_key = True, unique = True, nullable = False), - Column('actor', 'text', unique = True), - Column('inbox', 'text', unique = True, nullable = False), - Column('followid', 'text'), - Column('software', 'text'), - Column('accepted', 'boolean'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'whitelist', - Column('domain', 'text', primary_key = True, unique = True, nullable = True), - Column('created', 'timestamp') - ), - Table( - 'domain_bans', - Column('domain', 'text', primary_key = True, unique = True, nullable = True), - Column('reason', 'text'), - Column('note', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'software_bans', - Column('name', 'text', primary_key = True, unique = True, nullable = True), - Column('reason', 'text'), - Column('note', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'users', - Column('username', 'text', primary_key = True, unique = True, nullable = False), - Column('hash', 'text', nullable = False), - Column('handle', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'tokens', - Column('code', 'text', primary_key = True, unique = True, nullable = False), - Column('user', 'text', nullable = False), - Column('created', 'timestmap', nullable = False) - ) -) +TABLES = Tables() + + +@TABLES.add_row +class Config(Row): + key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) + value: Column[str] = Column('value', 'text') + type: Column[str] = Column('type', 'text', default = 'str') + + +@TABLES.add_row +class Instance(Row): + table_name: str = 'inboxes' + + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = False) + actor: Column[str] = Column('actor', 'text', unique = True) + inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) + followid: Column[str] = Column('followid', 'text') + software: Column[str] = Column('software', 'text') + accepted: Column[datetime] = Column('accepted', 'boolean') + created: Column[datetime] = Column('created', 'timestamp', nullable = False) + + +@TABLES.add_row +class Whitelist(Row): + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = True) + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class DomainBan(Row): + table_name: str = 'domain_bans' + + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column('reason', 'text') + note: Column[str] = Column('note', 'text') + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class SoftwareBan(Row): + table_name: str = 'software_bans' + + name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column('reason', 'text') + note: Column[str] = Column('note', 'text') + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class User(Row): + table_name: str = 'users' + + username: Column[str] = Column( + 'username', 'text', primary_key = True, unique = True, nullable = False) + hash: Column[str] = Column('hash', 'text', nullable = False) + handle: Column[str] = Column('handle', 'text') + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class Token(Row): + table_name: str = 'tokens' + + code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False) + user: Column[str] = Column('user', 'text', nullable = False) + created: Column[datetime] = Column('created', 'timestamp') def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: diff --git a/relay/http_client.py b/relay/http_client.py index 610b8a9..05a6565 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -5,11 +5,11 @@ import json from aiohttp import ClientSession, ClientTimeout, TCPConnector from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from blib import JsonBase -from bsql import Row from typing import TYPE_CHECKING, Any, TypeVar, overload from . import __version__, logger as logging from .cache import Cache +from .database.schema import Instance from .misc import MIMETYPES, Message, get_app if TYPE_CHECKING: @@ -184,12 +184,12 @@ class HttpClient: return None - async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: + async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: if not self._session: raise RuntimeError('Client not open') # akkoma and pleroma do not support HS2019 and other software still needs to be tested - if instance and instance['software'] in SUPPORTS_HS2019: + if instance is not None and instance.software in SUPPORTS_HS2019: algorithm = AlgorithmType.HS2019 else: diff --git a/relay/manage.py b/relay/manage.py index cb2b099..81f546e 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -6,7 +6,6 @@ import click import json import os -from bsql import Row from pathlib import Path from shutil import copyfile from typing import Any @@ -17,7 +16,7 @@ from . import http_client as http from . import logger as logging from .application import Application from .compat import RelayConfig, RelayDatabase -from .database import RELAY_SOFTWARE, get_database +from .database import RELAY_SOFTWARE, get_database, schema from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message @@ -367,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None: click.echo('Users:') with ctx.obj.database.session() as conn: - for user in conn.execute('SELECT * FROM users'): - click.echo(f'- {user["username"]}') + for row in conn.get_users(): + click.echo(f'- {row.username}') @cli_user.command('create') @@ -379,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: 'Create a new local user' with ctx.obj.database.session() as conn: - if conn.get_user(username): + if conn.get_user(username) is not None: click.echo(f'User already exists: {username}') return @@ -406,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None: 'Delete a local user' with ctx.obj.database.session() as conn: - if not conn.get_user(username): + if conn.get_user(username) is None: click.echo(f'User does not exist: {username}') return @@ -424,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None: click.echo(f'Tokens for "{username}":') with ctx.obj.database.session() as conn: - for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): - click.echo(f'- {token["code"]}') + for row in conn.get_tokens(username): + click.echo(f'- {row.code}') @cli_user.command('create-token') @@ -435,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None: 'Create a new API token for a user' with ctx.obj.database.session() as conn: - if not (user := conn.get_user(username)): + if (user := conn.get_user(username)) is None: click.echo(f'User does not exist: {username}') return - token = conn.put_token(user['username']) + token = conn.put_token(user.username) - click.echo(f'New token for "{username}": {token["code"]}') + click.echo(f'New token for "{username}": {token.code}') @cli_user.command('delete-token') @@ -451,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None: 'Delete an API token' with ctx.obj.database.session() as conn: - if not conn.get_token(code): + if conn.get_token(code) is None: click.echo('Token does not exist') return @@ -473,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None: click.echo('Connected to the following instances or relays:') with ctx.obj.database.session() as conn: - for inbox in conn.get_inboxes(): - click.echo(f'- {inbox["inbox"]}') + for row in conn.get_inboxes(): + click.echo(f'- {row.inbox}') @cli_inbox.command('follow') @@ -483,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None: 'Follow an actor (Relay must be running)' + instance: schema.Instance | None = None + with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return - if (inbox_data := conn.get_inbox(actor)): - inbox = inbox_data['inbox'] + if (instance := conn.get_inbox(actor)) is not None: + inbox = instance.inbox else: if not actor.startswith('http'): actor = f'https://{actor}/actor' - if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))): + if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None: click.echo(f'Failed to fetch actor: {actor}') return @@ -506,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: actor = actor ) - asyncio.run(http.post(inbox, message, inbox_data)) + asyncio.run(http.post(inbox, message, instance)) click.echo(f'Sent follow message to actor: {actor}') @@ -516,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: 'Unfollow an actor (Relay must be running)' - inbox_data: Row | None = None + instance: schema.Instance | None = None with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return - if (inbox_data := conn.get_inbox(actor)): - inbox = inbox_data['inbox'] + if (instance := conn.get_inbox(actor)): + inbox = instance.inbox message = Message.new_unfollow( host = ctx.obj.config.domain, actor = actor, - follow = inbox_data['followid'] + follow = instance.followid ) else: @@ -552,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: } ) - asyncio.run(http.post(inbox, message, inbox_data)) + asyncio.run(http.post(inbox, message, instance)) click.echo(f'Sent unfollow message to: {actor}') @@ -632,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None: click.echo('Follow requests:') with ctx.obj.database.session() as conn: - for instance in conn.get_requests(): - date = instance['created'].strftime('%Y-%m-%d') - click.echo(f'- [{date}] {instance["domain"]}') + for row in conn.get_requests(): + date = row.created.strftime('%Y-%m-%d') + click.echo(f'- [{date}] {row.domain}') @cli_request.command('accept') @@ -653,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None: message = Message.new_response( host = ctx.obj.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = True ) - asyncio.run(http.post(instance['inbox'], message, instance)) + asyncio.run(http.post(instance.inbox, message, instance)) - if instance['software'] != 'mastodon': + if instance.software != 'mastodon': message = Message.new_follow( host = ctx.obj.config.domain, - actor = instance['actor'] + actor = instance.actor ) - asyncio.run(http.post(instance['inbox'], message, instance)) + asyncio.run(http.post(instance.inbox, message, instance)) @cli_request.command('deny') @@ -685,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None: response = Message.new_response( host = ctx.obj.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = False ) - asyncio.run(http.post(instance['inbox'], response, instance)) + asyncio.run(http.post(instance.inbox, response, instance)) @cli.group('instance') @@ -706,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None: click.echo('Banned domains:') with ctx.obj.database.session() as conn: - for instance in conn.execute('SELECT * FROM domain_bans'): - if instance['reason']: - click.echo(f'- {instance["domain"]} ({instance["reason"]})') + for row in conn.get_domain_bans(): + if row.reason is not None: + click.echo(f'- {row.domain} ({row.reason})') else: - click.echo(f'- {instance["domain"]}') + click.echo(f'- {row.domain}') @cli_instance.command('ban') @@ -723,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> 'Ban an instance and remove the associated inbox if it exists' with ctx.obj.database.session() as conn: - if conn.get_domain_ban(domain): + if conn.get_domain_ban(domain) is not None: click.echo(f'Domain already banned: {domain}') return @@ -739,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None: 'Unban an instance' with ctx.obj.database.session() as conn: - if not conn.del_domain_ban(domain): + if conn.del_domain_ban(domain) is None: click.echo(f'Instance wasn\'t banned: {domain}') return @@ -764,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) click.echo(f'Updated domain ban: {domain}') - if row['reason']: - click.echo(f'- {row["domain"]} ({row["reason"]})') + if row.reason: + click.echo(f'- {row.domain} ({row.reason})') else: - click.echo(f'- {row["domain"]}') + click.echo(f'- {row.domain}') @cli.group('software') @@ -784,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None: click.echo('Banned software:') with ctx.obj.database.session() as conn: - for software in conn.execute('SELECT * FROM software_bans'): - if software['reason']: - click.echo(f'- {software["name"]} ({software["reason"]})') + for row in conn.get_software_bans(): + if row.reason: + click.echo(f'- {row.name} ({row.reason})') else: - click.echo(f'- {software["name"]}') + click.echo(f'- {row.name}') @cli_software.command('ban') @@ -811,12 +812,12 @@ def cli_software_ban(ctx: click.Context, with ctx.obj.database.session() as conn: if name == 'RELAYS': - for software in RELAY_SOFTWARE: - if conn.get_software_ban(software): - click.echo(f'Relay already banned: {software}') + for item in RELAY_SOFTWARE: + if conn.get_software_ban(item): + click.echo(f'Relay already banned: {item}') continue - conn.put_software_ban(software, reason or 'relay', note) + conn.put_software_ban(item, reason or 'relay', note) click.echo('Banned all relay software') return @@ -893,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) - click.echo(f'Updated software ban: {name}') - if row['reason']: - click.echo(f'- {row["name"]} ({row["reason"]})') + if row.reason: + click.echo(f'- {row.name} ({row.reason})') else: - click.echo(f'- {row["name"]}') + click.echo(f'- {row.name}') @cli.group('whitelist') @@ -913,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None: click.echo('Current whitelisted domains:') with ctx.obj.database.session() as conn: - for domain in conn.execute('SELECT * FROM whitelist'): - click.echo(f'- {domain["domain"]}') + for row in conn.get_domain_whitelist(): + click.echo(f'- {row.domain}') @cli_whitelist.command('add') @@ -953,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: @cli_whitelist.command('import') @click.pass_context def cli_whitelist_import(ctx: click.Context) -> None: - 'Add all current inboxes to the whitelist' + 'Add all current instances to the whitelist' with ctx.obj.database.session() as conn: - for inbox in conn.execute('SELECT * FROM inboxes').all(): - if conn.get_domain_whitelist(inbox['domain']): - click.echo(f'Domain already in whitelist: {inbox["domain"]}') + for row in conn.get_inboxes(): + if conn.get_domain_whitelist(row.domain) is not None: + click.echo(f'Domain already in whitelist: {row.domain}') continue - conn.put_domain_whitelist(inbox['domain']) + conn.put_domain_whitelist(row.domain) click.echo('Imported whitelist from inboxes') def main() -> None: - cli(prog_name='relay') - - -if __name__ == '__main__': - click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.') + cli(prog_name='activityrelay') diff --git a/relay/processors.py b/relay/processors.py index cd742ec..4e4d96f 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None: logging.debug('>> relay: %s', message) for instance in conn.distill_inboxes(view.message): - view.app.push_message(instance["inbox"], message, instance) + view.app.push_message(instance.inbox, message, instance) view.cache.set('handle-relay', view.message.object_id, message.id, 'str') @@ -52,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: logging.debug('>> forward: %s', message) for instance in conn.distill_inboxes(view.message): - view.app.push_message(instance["inbox"], view.message, instance) + view.app.push_message(instance.inbox, view.message, instance) view.cache.set('handle-relay', view.message.id, message.id, 'str') @@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None: return # prevent past unfollows from removing an instance - if view.instance['followid'] and view.instance['followid'] != view.message.object_id: + if view.instance.followid and view.instance.followid != view.message.object_id: return with conn.transaction(): @@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None: with view.database.session() as conn: if view.instance: - if not view.instance['software']: - if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): + if not view.instance.software: + if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)): with conn.transaction(): view.instance = conn.put_inbox( - domain = view.instance['domain'], + domain = view.instance.domain, software = nodeinfo.sw_name ) - if not view.instance['actor']: + if not view.instance.actor: with conn.transaction(): view.instance = conn.put_inbox( - domain = view.instance['domain'], + domain = view.instance.domain, actor = view.actor.id ) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index f568d17..74b01c6 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -1,26 +1,22 @@ -from __future__ import annotations - import aputils import traceback -import typing + +from aiohttp.web import Request from .base import View, register_route from .. import logger as logging +from ..database import schema from ..misc import Message, Response from ..processors import run_processor -if typing.TYPE_CHECKING: - from aiohttp.web import Request - from bsql import Row - @register_route('/actor', '/inbox') class ActorView(View): signature: aputils.Signature message: Message actor: Message - instancce: Row + instance: schema.Instance signer: aputils.Signer @@ -47,7 +43,7 @@ class ActorView(View): return response with self.database.session() as conn: - self.instance = conn.get_inbox(self.actor.shared_inbox) + self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] # reject if actor is banned if conn.get_domain_ban(self.actor.domain): diff --git a/relay/views/api.py b/relay/views/api.py index 074dc04..3bdc822 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -90,10 +90,10 @@ class Login(View): token = conn.put_token(data['username']) - resp = Response.new({'token': token['code']}, ctype = 'json') + resp = Response.new({'token': token.code}, ctype = 'json') resp.set_cookie( 'user-token', - token['code'], + token.code, max_age = 60 * 60 * 24 * 365, domain = self.config.domain, path = '/', @@ -117,7 +117,7 @@ class RelayInfo(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: config = conn.get_config_all() - inboxes = [row['domain'] for row in conn.get_inboxes()] + inboxes = [row.domain for row in conn.get_inboxes()] data = { 'domain': self.config.domain, @@ -188,7 +188,7 @@ class Config(View): class Inbox(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - data = conn.get_inboxes() + data = tuple(conn.get_inboxes()) return Response.new(data, ctype = 'json') @@ -202,7 +202,7 @@ class Inbox(View): data['domain'] = urlparse(data["actor"]).netloc with self.database.session() as conn: - if conn.get_inbox(data['domain']): + if conn.get_inbox(data['domain']) is not None: return Response.new_error(404, 'Instance already in database', 'json') data['domain'] = data['domain'].encode('idna').decode() @@ -225,7 +225,12 @@ class Inbox(View): except Exception: pass - row = conn.put_inbox(**data) # type: ignore[arg-type] + row = conn.put_inbox( + data['domain'], + actor = data.get('actor'), + software = data.get('software'), + followid = data.get('followid') + ) return Response.new(row, ctype = 'json') @@ -239,10 +244,15 @@ class Inbox(View): data['domain'] = data['domain'].encode('idna').decode() - if not (instance := conn.get_inbox(data['domain'])): + if (instance := conn.get_inbox(data['domain'])) is None: return Response.new_error(404, 'Instance with domain not found', 'json') - instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type] + instance = conn.put_inbox( + instance.domain, + actor = data.get('actor'), + software = data.get('software'), + followid = data.get('followid') + ) return Response.new(instance, ctype = 'json') @@ -268,7 +278,7 @@ class Inbox(View): class RequestView(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - instances = conn.get_requests() + instances = tuple(conn.get_requests()) return Response.new(instances, ctype = 'json') @@ -291,20 +301,20 @@ class RequestView(View): message = Message.new_response( host = self.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = data['accept'] ) - self.app.push_message(instance['inbox'], message, instance) + self.app.push_message(instance.inbox, message, instance) - if data['accept'] and instance['software'] != 'mastodon': + if data['accept'] and instance.software != 'mastodon': message = Message.new_follow( host = self.config.domain, - actor = instance['actor'] + actor = instance.actor ) - self.app.push_message(instance['inbox'], message, instance) + self.app.push_message(instance.inbox, message, instance) resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} return Response.new(resp_message, ctype = 'json') @@ -314,7 +324,7 @@ class RequestView(View): class DomainBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = tuple(conn.execute('SELECT * FROM domain_bans').all()) + bans = tuple(conn.get_domain_bans()) return Response.new(bans, ctype = 'json') @@ -328,10 +338,14 @@ class DomainBan(View): data['domain'] = data['domain'].encode('idna').decode() with self.database.session() as conn: - if conn.get_domain_ban(data['domain']): + if conn.get_domain_ban(data['domain']) is not None: return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_domain_ban(**data) + ban = conn.put_domain_ban( + domain = data['domain'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -343,15 +357,19 @@ class DomainBan(View): if isinstance(data, Response): return data - data['domain'] = data['domain'].encode('idna').decode() - - if not conn.get_domain_ban(data['domain']): - return Response.new_error(404, 'Domain not banned', 'json') - if not any([data.get('note'), data.get('reason')]): return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - ban = conn.update_domain_ban(**data) + data['domain'] = data['domain'].encode('idna').decode() + + if conn.get_domain_ban(data['domain']) is None: + return Response.new_error(404, 'Domain not banned', 'json') + + ban = conn.update_domain_ban( + domain = data['domain'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -365,7 +383,7 @@ class DomainBan(View): data['domain'] = data['domain'].encode('idna').decode() - if not conn.get_domain_ban(data['domain']): + if conn.get_domain_ban(data['domain']) is None: return Response.new_error(404, 'Domain not banned', 'json') conn.del_domain_ban(data['domain']) @@ -377,7 +395,7 @@ class DomainBan(View): class SoftwareBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = tuple(conn.execute('SELECT * FROM software_bans').all()) + bans = tuple(conn.get_software_bans()) return Response.new(bans, ctype = 'json') @@ -389,10 +407,14 @@ class SoftwareBan(View): return data with self.database.session() as conn: - if conn.get_software_ban(data['name']): + if conn.get_software_ban(data['name']) is not None: return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_software_ban(**data) + ban = conn.put_software_ban( + name = data['name'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -403,14 +425,18 @@ class SoftwareBan(View): if isinstance(data, Response): return data + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + with self.database.session() as conn: - if not conn.get_software_ban(data['name']): + if conn.get_software_ban(data['name']) is None: return Response.new_error(404, 'Software not banned', 'json') - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - - ban = conn.update_software_ban(**data) + ban = conn.update_software_ban( + name = data['name'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -422,7 +448,7 @@ class SoftwareBan(View): return data with self.database.session() as conn: - if not conn.get_software_ban(data['name']): + if conn.get_software_ban(data['name']) is None: return Response.new_error(404, 'Software not banned', 'json') conn.del_software_ban(data['name']) @@ -436,7 +462,7 @@ class User(View): with self.database.session() as conn: items = [] - for row in conn.execute('SELECT * FROM users'): + for row in conn.get_users(): del row['hash'] items.append(row) @@ -450,12 +476,16 @@ class User(View): return data with self.database.session() as conn: - if conn.get_user(data['username']): + if conn.get_user(data['username']) is not None: return Response.new_error(404, 'User already exists', 'json') - user = conn.put_user(**data) - del user['hash'] + user = conn.put_user( + username = data['username'], + password = data['password'], + handle = data.get('handle') + ) + del user['hash'] return Response.new(user, ctype = 'json') @@ -466,9 +496,13 @@ class User(View): return data with self.database.session(True) as conn: - user = conn.put_user(**data) - del user['hash'] + user = conn.put_user( + username = data['username'], + password = data['password'], + handle = data.get('handle') + ) + del user['hash'] return Response.new(user, ctype = 'json') @@ -479,7 +513,7 @@ class User(View): return data with self.database.session(True) as conn: - if not conn.get_user(data['username']): + if conn.get_user(data['username']) is None: return Response.new_error(404, 'User does not exist', 'json') conn.del_user(data['username']) @@ -491,7 +525,7 @@ class User(View): class Whitelist(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - items = tuple(conn.execute('SELECT * FROM whitelist').all()) + items = tuple(conn.get_domains_whitelist()) return Response.new(items, ctype = 'json') @@ -502,13 +536,13 @@ class Whitelist(View): if isinstance(data, Response): return data - data['domain'] = data['domain'].encode('idna').decode() + domain = data['domain'].encode('idna').decode() with self.database.session() as conn: - if conn.get_domain_whitelist(data['domain']): + if conn.get_domain_whitelist(domain) is not None: return Response.new_error(400, 'Domain already added to whitelist', 'json') - item = conn.put_domain_whitelist(**data) + item = conn.put_domain_whitelist(domain) return Response.new(item, ctype = 'json') @@ -519,12 +553,12 @@ class Whitelist(View): if isinstance(data, Response): return data - data['domain'] = data['domain'].encode('idna').decode() + domain = data['domain'].encode('idna').decode() with self.database.session() as conn: - if not conn.get_domain_whitelist(data['domain']): + if conn.get_domain_whitelist(domain) is None: return Response.new_error(404, 'Domain not in whitelist', 'json') - conn.del_domain_whitelist(data['domain']) + conn.del_domain_whitelist(domain) return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 5ec16fc..cf6b338 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -202,7 +202,7 @@ class AdminConfig(View): 'message': message, 'desc': { "name": "Name of the relay to be displayed in the header of the pages and in " + - "the actor endpoint.", + "the actor endpoint.", # noqa: E131 "note": "Description of the relay to be displayed on the front page and as the " + "bio in the actor endpoint.", "theme": "Color theme to use on the web pages.", diff --git a/relay/workers.py b/relay/workers.py index 8d88ad7..4b57409 100644 --- a/relay/workers.py +++ b/relay/workers.py @@ -4,7 +4,6 @@ import typing from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from bsql import Row from dataclasses import dataclass from multiprocessing import Event, Process, Queue, Value from multiprocessing.synchronize import Event as EventType @@ -13,6 +12,7 @@ from queue import Empty, Queue as QueueType from urllib.parse import urlparse from . import application, logger as logging +from .database.schema import Instance from .http_client import HttpClient from .misc import IS_WINDOWS, Message, get_app @@ -29,7 +29,7 @@ class QueueItem: class PostItem(QueueItem): inbox: str message: Message - instance: Row | None + instance: Instance | None @property def domain(self) -> str: @@ -122,7 +122,7 @@ class PushWorkers(list[PushWorker]): self.queue.put(item) - def push_message(self, inbox: str, message: Message, instance: Row) -> None: + def push_message(self, inbox: str, message: Message, instance: Instance) -> None: self.queue.put(PostItem(inbox, message, instance))