From 93797b639e3c489b9e9f136bd3ef133652ef7481 Mon Sep 17 00:00:00 2001 From: Izalia Mae <2908-izalia@users.noreply.git.pleroma.social> Date: Wed, 21 Feb 2024 00:00:21 +0000 Subject: [PATCH] Replace tinysql --- relay/application.py | 10 ++-- relay/cache.py | 32 +++++++------ relay/data/statements.sql | 2 +- relay/database/__init__.py | 30 ++++++------ relay/database/connection.py | 89 ++++++++++++++++++------------------ relay/database/schema.py | 10 ++-- relay/http_client.py | 2 +- relay/manage.py | 65 +++++++++++++------------- relay/processors.py | 10 ++-- relay/views/activitypub.py | 2 +- relay/views/api.py | 50 +++++++++----------- relay/views/frontend.py | 4 +- relay/views/misc.py | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 15 files changed, 154 insertions(+), 158 deletions(-) diff --git a/relay/application.py b/relay/application.py index dd6311b..cafef59 100644 --- a/relay/application.py +++ b/relay/application.py @@ -190,7 +190,7 @@ class Application(web.Application): self['proc'] = None self.cache.close() - self.database.close() + self.database.disconnect() class CacheCleanupThread(Thread): @@ -202,14 +202,10 @@ class CacheCleanupThread(Thread): def run(self) -> None: - cache = get_cache(self.app) - while self.running.is_set(): time.sleep(3600) logging.verbose("Removing old cache items") - cache.delete_old(14) - - cache.close() + self.app.cache.delete_old(14) def start(self) -> None: @@ -244,7 +240,7 @@ async def handle_access_log(request: web.Request, response: web.Response) -> Non async def handle_cleanup(app: Application) -> None: await app.client.close() app.cache.close() - app.database.close() + app.database.disconnect() async def main_gunicorn(): diff --git a/relay/cache.py b/relay/cache.py index faee1e5..3f258eb 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -168,8 +168,8 @@ class SqlCache(Cache): 'key': key } - with self._db.connection() as conn: - with conn.exec_statement('get-cache-item', params) as cur: + with self._db.session(False) as conn: + with conn.run('get-cache-item', params) as cur: if not (row := cur.one()): raise KeyError(f'{namespace}:{key}') @@ -178,14 +178,14 @@ class SqlCache(Cache): def get_keys(self, namespace: str) -> Iterator[str]: - with self._db.connection() as conn: - for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}): + with self._db.session(False) as conn: + for row in conn.run('get-cache-keys', {'namespace': namespace}): yield row['key'] def get_namespaces(self) -> Iterator[str]: - with self._db.connection() as conn: - for row in conn.exec_statement('get-cache-namespaces', None): + with self._db.session(False) as conn: + for row in conn.run('get-cache-namespaces', None): yield row['namespace'] @@ -198,8 +198,8 @@ class SqlCache(Cache): 'date': datetime.now(tz = timezone.utc) } - with self._db.connection() as conn: - with conn.exec_statement('set-cache-item', params) as conn: + with self._db.session(True) as conn: + with conn.run('set-cache-item', params) as conn: row = conn.one() row.pop('id', None) return Item.from_data(*tuple(row.values())) @@ -211,8 +211,8 @@ class SqlCache(Cache): 'key': key } - with self._db.connection() as conn: - with conn.exec_statement('del-cache-item', params): + with self._db.session(True) as conn: + with conn.run('del-cache-item', params): pass @@ -220,25 +220,27 @@ class SqlCache(Cache): limit = datetime.now(tz = timezone.utc) - timedelta(days = days) params = {"limit": limit.timestamp()} - with self._db.connection() as conn: + with self._db.session(True) as conn: with conn.execute("DELETE FROM cache WHERE updated < :limit", params): pass def clear(self) -> None: - with self._db.connection() as conn: + with self._db.session(True) as conn: with conn.execute("DELETE FROM cache"): pass def setup(self) -> None: - with self._db.connection() as conn: - with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None): + self._db.connect() + + with self._db.session(True) as conn: + with conn.run(f'create-cache-table-{self._db.backend_type.value}', None): pass def close(self) -> None: - self._db.close() + self._db.disconnect() self._db = None diff --git a/relay/data/statements.sql b/relay/data/statements.sql index 2ddef35..bc06d25 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS cache ( UNIQUE(namespace, key) ); --- name: create-cache-table-postgres +-- name: create-cache-table-postgresql CREATE TABLE IF NOT EXISTS cache ( id SERIAL PRIMARY KEY, namespace TEXT NOT NULL, diff --git a/relay/database/__init__.py b/relay/database/__init__.py index facea97..c7e9a1f 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -1,49 +1,51 @@ from __future__ import annotations -import tinysql +import bsql import typing from .config import get_default_value -from .connection import Connection -from .schema import VERSIONS, migrate_0 +from .connection import RELAY_SOFTWARE, Connection +from .schema import TABLES, VERSIONS, migrate_0 from .. import logger as logging try: from importlib.resources import files as pkgfiles -except ImportError: +except ImportError: # pylint: disable=duplicate-code from importlib_resources import files as pkgfiles if typing.TYPE_CHECKING: from .config import Config -def get_database(config: Config, migrate: bool = True) -> tinysql.Database: +def get_database(config: Config, migrate: bool = True) -> bsql.Database: + options = { + "connection_class": Connection, + "pool_size": 5, + "tables": TABLES + } + if config.db_type == "sqlite": - db = tinysql.Database.sqlite( - config.sqlite_path, - connection_class = Connection, - min_connections = 2, - max_connections = 10 - ) + db = bsql.Database.sqlite(config.sqlite_path, **options) elif config.db_type == "postgres": - db = tinysql.Database.postgres( + db = bsql.Database.postgresql( config.pg_name, config.pg_host, config.pg_port, config.pg_user, config.pg_pass, - connection_class = Connection + **options ) db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql")) + db.connect() if not migrate: return db - with db.connection() as conn: + with db.session(True) as conn: if 'config' not in conn.get_tables(): logging.info("Creating database tables") migrate_0(conn) diff --git a/relay/database/connection.py b/relay/database/connection.py index 718861d..200e17e 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -1,9 +1,9 @@ from __future__ import annotations -import tinysql import typing from argon2 import PasswordHasher +from bsql import Connection as SqlConnection, Update from datetime import datetime, timezone from urllib.parse import urlparse from uuid import uuid4 @@ -15,7 +15,7 @@ from ..misc import get_app if typing.TYPE_CHECKING: from collections.abc import Iterator - from tinysql import Cursor, Row + from bsql import Row from typing import Any from .application import Application from ..misc import Message @@ -29,7 +29,7 @@ RELAY_SOFTWARE = [ ] -class Connection(tinysql.Connection): +class Connection(SqlConnection): hasher = PasswordHasher( encoding = 'utf-8' ) @@ -50,15 +50,11 @@ class Connection(tinysql.Connection): yield inbox['inbox'] - def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor: - return self.execute(self.database.prepared_statements[name], params) - - def get_config(self, key: str) -> Any: if key not in CONFIG_DEFAULTS: raise KeyError(key) - with self.exec_statement('get-config', {'key': key}) as cur: + with self.run('get-config', {'key': key}) as cur: if not (row := cur.one()): return get_default_value(key) @@ -69,7 +65,7 @@ class Connection(tinysql.Connection): def get_config_all(self) -> dict[str, Any]: - with self.exec_statement('get-config-all') as cur: + with self.run('get-config-all', None) as cur: db_config = {row['key']: row['value'] for row in cur} config = {} @@ -105,12 +101,12 @@ class Connection(tinysql.Connection): 'type': get_default_type(key) } - with self.exec_statement('put-config', params): + with self.run('put-config', params): return value def get_inbox(self, value: str) -> Row: - with self.exec_statement('get-inbox', {'value': value}) as cur: + with self.run('get-inbox', {'value': value}) as cur: return cur.one() @@ -130,7 +126,7 @@ class Connection(tinysql.Connection): 'created': datetime.now(tz = timezone.utc) } - with self.exec_statement('put-inbox', params) as cur: + with self.run('put-inbox', params) as cur: return cur.one() @@ -154,27 +150,28 @@ class Connection(tinysql.Connection): if software: data['software'] = software - statement = tinysql.Update('inboxes', data, inbox = inbox) + statement = Update('inboxes', data) + statement.set_where("inbox", inbox) with self.query(statement): return self.get_inbox(inbox) def del_inbox(self, value: str) -> bool: - with self.exec_statement('del-inbox', {'value': value}) as cur: - if cur.modified_row_count > 1: + with self.run('del-inbox', {'value': value}) as cur: + if cur.row_count > 1: raise ValueError('More than one row was modified') - return cur.modified_row_count == 1 + return cur.row_count == 1 def get_user(self, value: str) -> Row: - with self.exec_statement('get-user', {'value': value}) as cur: + with self.run('get-user', {'value': value}) as cur: return cur.one() def get_user_by_token(self, code: str) -> Row: - with self.exec_statement('get-user-by-token', {'code': code}) as cur: + with self.run('get-user-by-token', {'code': code}) as cur: return cur.one() @@ -186,22 +183,22 @@ class Connection(tinysql.Connection): 'created': datetime.now(tz = timezone.utc) } - with self.exec_statement('put-user', data) as cur: + with self.run('put-user', data) as cur: return cur.one() def del_user(self, username: str) -> None: user = self.get_user(username) - with self.exec_statement('del-user', {'value': user['username']}): + with self.run('del-user', {'value': user['username']}): pass - with self.exec_statement('del-token-user', {'username': user['username']}): + with self.run('del-token-user', {'username': user['username']}): pass def get_token(self, code: str) -> Row: - with self.exec_statement('get-token', {'code': code}) as cur: + with self.run('get-token', {'code': code}) as cur: return cur.one() @@ -212,12 +209,12 @@ class Connection(tinysql.Connection): 'created': datetime.now(tz = timezone.utc) } - with self.exec_statement('put-token', data) as cur: + with self.run('put-token', data) as cur: return cur.one() def del_token(self, code: str) -> None: - with self.exec_statement('del-token', {'code': code}): + with self.run('del-token', {'code': code}): pass @@ -225,7 +222,7 @@ class Connection(tinysql.Connection): if domain.startswith('http'): domain = urlparse(domain).netloc - with self.exec_statement('get-domain-ban', {'domain': domain}) as cur: + with self.run('get-domain-ban', {'domain': domain}) as cur: return cur.one() @@ -241,14 +238,14 @@ class Connection(tinysql.Connection): 'created': datetime.now(tz = timezone.utc) } - with self.exec_statement('put-domain-ban', params) as cur: + with self.run('put-domain-ban', params) as cur: return cur.one() def update_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> tinysql.Row: + note: str | None = None) -> Row: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -261,25 +258,26 @@ class Connection(tinysql.Connection): if note: params['note'] = note - statement = tinysql.Update('domain_bans', params, domain = domain) + statement = Update('domain_bans', params) + statement.set_where("domain", domain) with self.query(statement) as cur: - if cur.modified_row_count > 1: + if cur.row_count > 1: raise ValueError('More than one row was modified') return self.get_domain_ban(domain) def del_domain_ban(self, domain: str) -> bool: - with self.exec_statement('del-domain-ban', {'domain': domain}) as cur: - if cur.modified_row_count > 1: + with self.run('del-domain-ban', {'domain': domain}) as cur: + if cur.row_count > 1: raise ValueError('More than one row was modified') - return cur.modified_row_count == 1 + return cur.row_count == 1 def get_software_ban(self, name: str) -> Row: - with self.exec_statement('get-software-ban', {'name': name}) as cur: + with self.run('get-software-ban', {'name': name}) as cur: return cur.one() @@ -295,14 +293,14 @@ class Connection(tinysql.Connection): 'created': datetime.now(tz = timezone.utc) } - with self.exec_statement('put-software-ban', params) as cur: + with self.run('put-software-ban', params) as cur: return cur.one() def update_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> tinysql.Row: + note: str | None = None) -> Row: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -315,25 +313,26 @@ class Connection(tinysql.Connection): if note: params['note'] = note - statement = tinysql.Update('software_bans', params, name = name) + statement = Update('software_bans', params) + statement.set_where("name", name) with self.query(statement) as cur: - if cur.modified_row_count > 1: + if cur.row_count > 1: raise ValueError('More than one row was modified') return self.get_software_ban(name) def del_software_ban(self, name: str) -> bool: - with self.exec_statement('del-software-ban', {'name': name}) as cur: - if cur.modified_row_count > 1: + with self.run('del-software-ban', {'name': name}) as cur: + if cur.row_count > 1: raise ValueError('More than one row was modified') - return cur.modified_row_count == 1 + return cur.row_count == 1 def get_domain_whitelist(self, domain: str) -> Row: - with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur: + with self.run('get-domain-whitelist', {'domain': domain}) as cur: return cur.one() @@ -343,13 +342,13 @@ class Connection(tinysql.Connection): 'created': datetime.now(tz = timezone.utc) } - with self.exec_statement('put-domain-whitelist', params) as cur: + with self.run('put-domain-whitelist', params) as cur: return cur.one() def del_domain_whitelist(self, domain: str) -> bool: - with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur: - if cur.modified_row_count > 1: + with self.run('del-domain-whitelist', {'domain': domain}) as cur: + if cur.row_count > 1: raise ValueError('More than one row was modified') - return cur.modified_row_count == 1 + return cur.row_count == 1 diff --git a/relay/database/schema.py b/relay/database/schema.py index d4c51a4..e3a0303 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing -from tinysql import Column, Connection, Table +from bsql import Column, Connection, Table, Tables from .config import get_default_value @@ -11,7 +11,7 @@ if typing.TYPE_CHECKING: VERSIONS: dict[int, Callable] = {} -TABLES: list[Table] = [ +TABLES: Tables = Tables( Table( 'config', Column('key', 'text', primary_key = True, unique = True, nullable = False), @@ -59,7 +59,7 @@ TABLES: list[Table] = [ Column('user', 'text', nullable = False), Column('created', 'timestmap', nullable = False) ) -] +) def migration(func: Callable) -> Callable: @@ -69,10 +69,10 @@ def migration(func: Callable) -> Callable: def migrate_0(conn: Connection) -> None: - conn.create_tables(TABLES) + conn.create_tables() conn.put_config('schema-version', get_default_value('schema-version')) @migration def migrate_20240206(conn: Connection) -> None: - conn.create_tables(TABLES) + conn.create_tables() diff --git a/relay/http_client.py b/relay/http_client.py index 6bed4ae..7e7bbd9 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -111,7 +111,7 @@ class HttpClient: return loads(item.value) except KeyError: - logging.verbose('Failed to fetch cached data for url: %s', url) + logging.verbose('No cached data for url: %s', url) headers = {} diff --git a/relay/manage.py b/relay/manage.py index b767f00..3acd1c2 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -3,6 +3,7 @@ from __future__ import annotations import Crypto import asyncio import click +import os import platform import subprocess import sys @@ -19,8 +20,7 @@ from . import http_client as http from . import logger as logging from .application import Application from .compat import RelayConfig, RelayDatabase -from .database import get_database -from .database.connection import RELAY_SOFTWARE +from .database import RELAY_SOFTWARE, get_database from .misc import IS_DOCKER, Message if typing.TYPE_CHECKING: @@ -199,7 +199,7 @@ def cli_setup(ctx: click.Context) -> None: 'private-key': Signer.new('n/a').export() } - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for key, value in config.items(): conn.put_config(key, value) @@ -239,9 +239,12 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None: if getattr(sys, 'frozen', False): subprocess.run([sys.executable, 'run-gunicorn'], check = False) - return - ctx.obj.run(dev) + else: + ctx.obj.run(dev) + + # todo: figure out why the relay doesn't quit properly without this + os._exit(0) @cli.command('run-gunicorn') @@ -279,7 +282,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: ctx.obj.config.save() with get_database(ctx.obj.config) as db: - with db.connection() as conn: + with db.session(True) as conn: conn.put_config('private-key', database['private-key']) conn.put_config('note', config['note']) conn.put_config('whitelist-enabled', config['whitelist_enabled']) @@ -365,7 +368,7 @@ def cli_config_list(ctx: click.Context) -> None: click.echo('Relay Config:') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for key, value in conn.get_config_all().items(): if key not in CONFIG_IGNORE: key = f'{key}:'.ljust(20) @@ -379,7 +382,7 @@ def cli_config_list(ctx: click.Context) -> None: def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: 'Set a config value' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: new_value = conn.put_config(key, value) print(f'{key}: {repr(new_value)}') @@ -397,7 +400,7 @@ def cli_user_list(ctx: click.Context) -> None: click.echo('Users:') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for user in conn.execute('SELECT * FROM users'): click.echo(f'- {user["username"]}') @@ -409,7 +412,7 @@ def cli_user_list(ctx: click.Context) -> None: def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: 'Create a new local user' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if conn.get_user(username): click.echo(f'User already exists: {username}') return @@ -436,7 +439,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: def cli_user_delete(ctx: click.Context, username: str) -> None: 'Delete a local user' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not conn.get_user(username): click.echo(f'User does not exist: {username}') return @@ -454,7 +457,7 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None: click.echo(f'Tokens for "{username}":') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): click.echo(f'- {token["code"]}') @@ -465,7 +468,7 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None: def cli_user_create_token(ctx: click.Context, username: str) -> None: 'Create a new API token for a user' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not (user := conn.get_user(username)): click.echo(f'User does not exist: {username}') return @@ -481,7 +484,7 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None: def cli_user_delete_token(ctx: click.Context, code: str) -> None: 'Delete an API token' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not conn.get_token(code): click.echo('Token does not exist') return @@ -503,7 +506,7 @@ def cli_inbox_list(ctx: click.Context) -> None: click.echo('Connected to the following instances or relays:') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for inbox in conn.execute('SELECT * FROM inboxes'): click.echo(f'- {inbox["inbox"]}') @@ -514,7 +517,7 @@ def cli_inbox_list(ctx: click.Context) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None: 'Follow an actor (Relay must be running)' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return @@ -549,7 +552,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: inbox_data: Row = None - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return @@ -617,7 +620,7 @@ def cli_inbox_add( except KeyError: pass - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if conn.get_domain_ban(domain): click.echo(f'Refusing to add banned inbox: {inbox}') return @@ -637,7 +640,7 @@ def cli_inbox_add( def cli_inbox_remove(ctx: click.Context, inbox: str) -> None: 'Remove an inbox from the database' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not conn.del_inbox(inbox): click.echo(f'Inbox not in database: {inbox}') return @@ -657,7 +660,7 @@ def cli_instance_list(ctx: click.Context) -> None: click.echo('Banned domains:') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for instance in conn.execute('SELECT * FROM domain_bans'): if instance['reason']: click.echo(f'- {instance["domain"]} ({instance["reason"]})') @@ -674,7 +677,7 @@ def cli_instance_list(ctx: click.Context) -> None: def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None: 'Ban an instance and remove the associated inbox if it exists' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if conn.get_domain_ban(domain): click.echo(f'Domain already banned: {domain}') return @@ -690,7 +693,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> def cli_instance_unban(ctx: click.Context, domain: str) -> None: 'Unban an instance' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not conn.del_domain_ban(domain): click.echo(f'Instance wasn\'t banned: {domain}') return @@ -709,7 +712,7 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) if not (reason or note): ctx.fail('Must pass --reason or --note') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not (row := conn.update_domain_ban(domain, reason, note)): click.echo(f'Failed to update domain ban: {domain}') return @@ -735,7 +738,7 @@ def cli_software_list(ctx: click.Context) -> None: click.echo('Banned software:') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for software in conn.execute('SELECT * FROM software_bans'): if software['reason']: click.echo(f'- {software["name"]} ({software["reason"]})') @@ -761,7 +764,7 @@ def cli_software_ban(ctx: click.Context, fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to ban relays' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if name == 'RELAYS': for software in RELAY_SOFTWARE: if conn.get_software_ban(software): @@ -804,7 +807,7 @@ def cli_software_ban(ctx: click.Context, def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to unban relays' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if name == 'RELAYS': for software in RELAY_SOFTWARE: if not conn.del_software_ban(software): @@ -838,7 +841,7 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) - if not (reason or note): ctx.fail('Must pass --reason or --note') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not (row := conn.update_software_ban(name, reason, note)): click.echo(f'Failed to update software ban: {name}') return @@ -864,7 +867,7 @@ def cli_whitelist_list(ctx: click.Context) -> None: click.echo('Current whitelisted domains:') - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for domain in conn.execute('SELECT * FROM whitelist'): click.echo(f'- {domain["domain"]}') @@ -875,7 +878,7 @@ def cli_whitelist_list(ctx: click.Context) -> None: def cli_whitelist_add(ctx: click.Context, domain: str) -> None: 'Add a domain to the whitelist' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if conn.get_domain_whitelist(domain): click.echo(f'Instance already in the whitelist: {domain}') return @@ -890,7 +893,7 @@ def cli_whitelist_add(ctx: click.Context, domain: str) -> None: def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: 'Remove an instance from the whitelist' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: if not conn.del_domain_whitelist(domain): click.echo(f'Domain not in the whitelist: {domain}') return @@ -907,7 +910,7 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: def cli_whitelist_import(ctx: click.Context) -> None: 'Add all current inboxes to the whitelist' - with ctx.obj.database.connection() as conn: + with ctx.obj.database.session() as conn: for inbox in conn.execute('SELECT * FROM inboxes').all(): if conn.get_domain_whitelist(inbox['domain']): click.echo(f'Domain already in whitelist: {inbox["domain"]}') diff --git a/relay/processors.py b/relay/processors.py index d8b32fe..824a975 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing from . import logger as logging -from .database.connection import Connection +from .database import Connection from .misc import Message if typing.TYPE_CHECKING: @@ -43,8 +43,8 @@ async def handle_relay(view: ActorView, conn: Connection) -> None: async def handle_forward(view: ActorView, conn: Connection) -> None: try: - view.cache.get('handle-relay', view.message.object_id) - logging.verbose('already forwarded %s', view.message.object_id) + view.cache.get('handle-relay', view.message.id) + logging.verbose('already forwarded %s', view.message.id) return except KeyError: @@ -56,7 +56,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: for inbox in conn.distill_inboxes(view.message): view.app.push_message(inbox, message, view.instance) - view.cache.set('handle-relay', view.message.object_id, message.id, 'str') + view.cache.set('handle-relay', view.message.id, message.id, 'str') async def handle_follow(view: ActorView, conn: Connection) -> None: @@ -184,7 +184,7 @@ async def run_processor(view: ActorView) -> None: return - with view.database.connection(False) as conn: + with view.database.session() as conn: if view.instance: if not view.instance['software']: if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index e95f02e..08cbf2f 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -46,7 +46,7 @@ class ActorView(View): if response := await self.get_post_data(): return response - with self.database.connection(False) as conn: + with self.database.session() as conn: self.instance = conn.get_inbox(self.actor.shared_inbox) config = conn.get_config_all() diff --git a/relay/views/api.py b/relay/views/api.py index 84774c7..56a5a25 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -45,7 +45,7 @@ async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Respo try: request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() - with request.app.database.connection() as conn: + with request.app.database.session() as conn: request['user'] = conn.get_user_by_token(request['token']) except (KeyError, ValueError): @@ -92,7 +92,7 @@ class Login(View): async def delete(self, request: Request) -> Response: - with self.database.connection(True) as conn: + with self.database.session() as conn: conn.del_token(request['token']) return Response.new({'message': 'Token revoked'}, ctype = 'json') @@ -101,7 +101,7 @@ class Login(View): @register_route('/api/v1/relay') class RelayInfo(View): async def get(self, request: Request) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: config = conn.get_config_all() inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] @@ -123,7 +123,7 @@ class RelayInfo(View): @register_route('/api/v1/config') class Config(View): async def get(self, request: Request) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: data = conn.get_config_all() data['log-level'] = data['log-level'].name @@ -142,7 +142,7 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - with self.database.connection(True) as conn: + with self.database.session() as conn: conn.put_config(data['key'], data['value']) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -157,7 +157,7 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - with self.database.connection(True) as conn: + with self.database.session() as conn: conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -168,15 +168,9 @@ class Inbox(View): async def get(self, request: Request) -> Response: data = [] - with self.database.connection(False) as conn: + with self.database.session() as conn: for inbox in conn.execute('SELECT * FROM inboxes'): - try: - created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) - - except TypeError: - created = datetime.fromisoformat(inbox['created']) - - inbox['created'] = created.isoformat() + inbox['created'] = inbox['created'].isoformat() data.append(inbox) return Response.new(data, ctype = 'json') @@ -190,7 +184,7 @@ class Inbox(View): data['domain'] = urlparse(data["actor"]).netloc - with self.database.connection(True) as conn: + with self.database.session() as conn: if conn.get_inbox(data['domain']): return Response.new_error(404, 'Instance already in database', 'json') @@ -214,7 +208,7 @@ class Inbox(View): async def patch(self, request: Request) -> Response: - with self.database.connection(True) as conn: + with self.database.session() as conn: data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) if isinstance(data, Response): @@ -229,7 +223,7 @@ class Inbox(View): async def delete(self, request: Request, domain: str) -> Response: - with self.database.connection(True) as conn: + with self.database.session() as conn: data = await self.get_api_data(['domain'], []) if isinstance(data, Response): @@ -246,7 +240,7 @@ class Inbox(View): @register_route('/api/v1/domain_ban') class DomainBan(View): async def get(self, request: Request) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: bans = conn.execute('SELECT * FROM domain_bans').all() return Response.new(bans, ctype = 'json') @@ -258,7 +252,7 @@ class DomainBan(View): if isinstance(data, Response): return data - with self.database.connection(True) as conn: + with self.database.session() as conn: if conn.get_domain_ban(data['domain']): return Response.new_error(400, 'Domain already banned', 'json') @@ -268,7 +262,7 @@ class DomainBan(View): async def patch(self, request: Request) -> Response: - with self.database.connection(True) as conn: + with self.database.session() as conn: data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): @@ -286,7 +280,7 @@ class DomainBan(View): async def delete(self, request: Request) -> Response: - with self.database.connection(True) as conn: + with self.database.session() as conn: data = await self.get_api_data(['domain'], []) if isinstance(data, Response): @@ -303,7 +297,7 @@ class DomainBan(View): @register_route('/api/v1/software_ban') class SoftwareBan(View): async def get(self, request: Request) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: bans = conn.execute('SELECT * FROM software_bans').all() return Response.new(bans, ctype = 'json') @@ -315,7 +309,7 @@ class SoftwareBan(View): if isinstance(data, Response): return data - with self.database.connection(True) as conn: + with self.database.session() as conn: if conn.get_software_ban(data['name']): return Response.new_error(400, 'Domain already banned', 'json') @@ -330,7 +324,7 @@ class SoftwareBan(View): if isinstance(data, Response): return data - with self.database.connection(True) as conn: + with self.database.session() as conn: if not conn.get_software_ban(data['name']): return Response.new_error(404, 'Software not banned', 'json') @@ -348,7 +342,7 @@ class SoftwareBan(View): if isinstance(data, Response): return data - with self.database.connection(True) as conn: + with self.database.session() as conn: if not conn.get_software_ban(data['name']): return Response.new_error(404, 'Software not banned', 'json') @@ -360,7 +354,7 @@ class SoftwareBan(View): @register_route('/api/v1/whitelist') class Whitelist(View): async def get(self, request: Request) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: items = conn.execute('SELECT * FROM whitelist').all() return Response.new(items, ctype = 'json') @@ -372,7 +366,7 @@ class Whitelist(View): if isinstance(data, Response): return data - with self.database.connection(True) as conn: + with self.database.session() as conn: if conn.get_domain_whitelist(data['domain']): return Response.new_error(400, 'Domain already added to whitelist', 'json') @@ -387,7 +381,7 @@ class Whitelist(View): if isinstance(data, Response): return data - with self.database.connection(False) as conn: + with self.database.session() as conn: if not conn.get_domain_whitelist(data['domain']): return Response.new_error(404, 'Domain not in whitelist', 'json') diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 00af6fb..fb6028f 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -44,9 +44,9 @@ HOME_TEMPLATE = """ @register_route('/') class HomeView(View): async def get(self, request: Request) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: config = conn.get_config_all() - inboxes = conn.execute('SELECT * FROM inboxes').all() + inboxes = tuple(conn.execute('SELECT * FROM inboxes').all()) text = HOME_TEMPLATE.format( host = self.config.domain, diff --git a/relay/views/misc.py b/relay/views/misc.py index 3a84436..bede27d 100644 --- a/relay/views/misc.py +++ b/relay/views/misc.py @@ -33,7 +33,7 @@ if Path(__file__).parent.parent.joinpath('.git').exists(): class NodeinfoView(View): # pylint: disable=no-self-use async def get(self, request: Request, niversion: str) -> Response: - with self.database.connection(False) as conn: + with self.database.session() as conn: inboxes = conn.execute('SELECT * FROM inboxes').all() data = { diff --git a/requirements.txt b/requirements.txt index 46a6a9a..65b70b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,6 @@ gunicorn==21.1.0 hiredis==2.3.2 pyyaml>=6.0 redis==5.0.1 -tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/main.tar.gz +barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.1.tar.gz importlib_resources==6.1.1;python_version<'3.9' diff --git a/setup.cfg b/setup.cfg index 72ac7dc..685f357 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,4 +45,4 @@ console_scripts = [flake8] select = F401 per-file-ignores = - relay/views/__init__.py: F401 + __init__.py: F401