From 226f940cdcacf1c9ff9be047ac7af1bb326c5223 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Mon, 5 Feb 2024 14:11:10 -0500 Subject: [PATCH 1/7] split views.py --- relay/misc.py | 68 -------------- relay/views/__init__.py | 4 + relay/{views.py => views/activitypub.py} | 110 ++--------------------- relay/views/base.py | 93 +++++++++++++++++++ relay/views/frontend.py | 62 +++++++++++++ relay/views/misc.py | 58 ++++++++++++ 6 files changed, 224 insertions(+), 171 deletions(-) create mode 100644 relay/views/__init__.py rename relay/{views.py => views/activitypub.py} (65%) create mode 100644 relay/views/base.py create mode 100644 relay/views/frontend.py create mode 100644 relay/views/misc.py diff --git a/relay/misc.py b/relay/misc.py index e71845d..296082b 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -5,12 +5,8 @@ import os import socket import typing -from aiohttp.abc import AbstractView -from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.web import Response as AiohttpResponse -from aiohttp.web_exceptions import HTTPMethodNotAllowed from aputils.message import Message as ApMessage -from functools import cached_property from uuid import uuid4 if typing.TYPE_CHECKING: @@ -232,67 +228,3 @@ class Response(AiohttpResponse): @location.setter def location(self, value: str) -> None: self.headers['Location'] = value - - -class View(AbstractView): - def __await__(self) -> Generator[Response]: - if self.request.method not in METHODS: - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) - - if not (handler := self.handlers.get(self.request.method)): - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None - - return self._run_handler(handler).__await__() - - - async def _run_handler(self, handler: Awaitable) -> Response: - with self.database.config.connection_class(self.database) as conn: - # todo: remove on next tinysql release - conn.open() - - return await handler(self.request, conn, **self.request.match_info) - - - @cached_property - def allowed_methods(self) -> tuple[str]: - return tuple(self.handlers.keys()) - - - @cached_property - def handlers(self) -> dict[str, Coroutine]: - data = {} - - for method in METHODS: - try: - data[method] = getattr(self, method.lower()) - - except AttributeError: - continue - - return data - - - # app components - @property - def app(self) -> Application: - return self.request.app - - - @property - def cache(self) -> Cache: - return self.app.cache - - - @property - def client(self) -> HttpClient: - return self.app.client - - - @property - def config(self) -> Config: - return self.app.config - - - @property - def database(self) -> Database: - return self.app.database diff --git a/relay/views/__init__.py b/relay/views/__init__.py new file mode 100644 index 0000000..85c2bd3 --- /dev/null +++ b/relay/views/__init__.py @@ -0,0 +1,4 @@ +from __future__ import annotations + +from . import activitypub, frontend, misc +from .base import VIEWS diff --git a/relay/views.py b/relay/views/activitypub.py similarity index 65% rename from relay/views.py rename to relay/views/activitypub.py index cb648a2..be51047 100644 --- a/relay/views.py +++ b/relay/views/activitypub.py @@ -1,95 +1,27 @@ from __future__ import annotations -import subprocess import traceback import typing from aputils.errors import SignatureFailureError from aputils.misc import Digest, HttpDate, Signature -from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo -from pathlib import Path +from aputils.objects import Webfinger -from . import __version__ -from . import logger as logging -from .database.connection import Connection -from .misc import Message, Response, View -from .processors import run_processor +from .base import View, register_route + +from .. import logger as logging +from ..misc import Message, Response +from ..processors import run_processor if typing.TYPE_CHECKING: from aiohttp.web import Request from aputils.signer import Signer - from collections.abc import Callable from tinysql import Row - - -VIEWS = [] -VERSION = __version__ -HOME_TEMPLATE = """ - - ActivityPub Relay at {host} - - - -

This is an Activity Relay for fediverse instances.

-

{note}

-

- You may subscribe to this relay with the address: - https://{host}/actor -

-

- To host your own relay, you may download the code at this address: - - https://git.pleroma.social/pleroma/relay - -

-

List of {count} registered instances:
{targets}

- -""" - - -if Path(__file__).parent.parent.joinpath('.git').exists(): - try: - commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') - VERSION = f'{__version__} {commit_label}' - - except Exception: - pass - - -def register_route(*paths: str) -> Callable: - def wrapper(view: View) -> View: - for path in paths: - VIEWS.append([path, view]) - - return View - return wrapper + from ..database.connection import Connection # pylint: disable=unused-argument -@register_route('/') -class HomeView(View): - async def get(self, request: Request, conn: Connection) -> Response: - config = conn.get_config_all() - inboxes = conn.execute('SELECT * FROM inboxes').all() - - text = HOME_TEMPLATE.format( - host = self.config.domain, - note = config['note'], - count = len(inboxes), - targets = '
'.join(inbox['domain'] for inbox in inboxes) - ) - - return Response.new(text, ctype='html') - - - @register_route('/actor', '/inbox') class ActorView(View): def __init__(self, request: Request): @@ -247,31 +179,3 @@ class WebfingerView(View): ) return Response.new(data, ctype = 'json') - - -@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') -class NodeinfoView(View): - # pylint: disable=no-self-use - async def get(self, request: Request, conn: Connection, niversion: str) -> Response: - inboxes = conn.execute('SELECT * FROM inboxes').all() - - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not conn.get_config('whitelist-enabled'), - 'users': 1, - 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} - } - - if niversion == '2.1': - data['repo'] = 'https://git.pleroma.social/pleroma/relay' - - return Response.new(Nodeinfo.new(**data), ctype = 'json') - - -@register_route('/.well-known/nodeinfo') -class WellknownNodeinfoView(View): - async def get(self, request: Request, conn: Connection) -> Response: - data = WellKnownNodeinfo.new_template(self.config.domain) - return Response.new(data, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py new file mode 100644 index 0000000..95b6562 --- /dev/null +++ b/relay/views/base.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import typing + +from aiohttp.abc import AbstractView +from aiohttp.hdrs import METH_ALL as METHODS +from aiohttp.web import HTTPMethodNotAllowed +from functools import cached_property + +if typing.TYPE_CHECKING: + from collections.abc import Callable, Coroutine, Generator + from tinysql import Database + from ..application import Application + from ..cache import Cache + from ..config import Config + from ..http_client import HttpClient + from ..misc import Response + + +VIEWS = [] + + +def register_route(*paths: str) -> Callable: + def wrapper(view: View) -> View: + for path in paths: + VIEWS.append([path, view]) + + return View + return wrapper + + +class View(AbstractView): + def __await__(self) -> Generator[Response]: + if self.request.method not in METHODS: + raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + + if not (handler := self.handlers.get(self.request.method)): + raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + + return self._run_handler(handler).__await__() + + + async def _run_handler(self, handler: Coroutine) -> Response: + with self.database.config.connection_class(self.database) as conn: + # todo: remove on next tinysql release + conn.open() + + return await handler(self.request, conn, **self.request.match_info) + + + @cached_property + def allowed_methods(self) -> tuple[str]: + return tuple(self.handlers.keys()) + + + @cached_property + def handlers(self) -> dict[str, Coroutine]: + data = {} + + for method in METHODS: + try: + data[method] = getattr(self, method.lower()) + + except AttributeError: + continue + + return data + + + # app components + @property + def app(self) -> Application: + return self.request.app + + + @property + def cache(self) -> Cache: + return self.app.cache + + + @property + def client(self) -> HttpClient: + return self.app.client + + + @property + def config(self) -> Config: + return self.app.config + + + @property + def database(self) -> Database: + return self.app.database diff --git a/relay/views/frontend.py b/relay/views/frontend.py new file mode 100644 index 0000000..987b9b0 --- /dev/null +++ b/relay/views/frontend.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing + +from .base import View, register_route + +from .. import __version__ +from ..misc import Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from aputils.signer import Signer + from collections.abc import Callable + from tinysql import Row + from ..database.connection import Connection + + +HOME_TEMPLATE = """ + + ActivityPub Relay at {host} + + + +

This is an Activity Relay for fediverse instances.

+

{note}

+

+ You may subscribe to this relay with the address: + https://{host}/actor +

+

+ To host your own relay, you may download the code at this address: + + https://git.pleroma.social/pleroma/relay + +

+

List of {count} registered instances:
{targets}

+ +""" + + +# pylint: disable=unused-argument + +@register_route('/') +class HomeView(View): + async def get(self, request: Request, conn: Connection) -> Response: + config = conn.get_config_all() + inboxes = conn.execute('SELECT * FROM inboxes').all() + + text = HOME_TEMPLATE.format( + host = self.config.domain, + note = config['note'], + count = len(inboxes), + targets = '
'.join(inbox['domain'] for inbox in inboxes) + ) + + return Response.new(text, ctype='html') diff --git a/relay/views/misc.py b/relay/views/misc.py new file mode 100644 index 0000000..e41ae2b --- /dev/null +++ b/relay/views/misc.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import subprocess +import typing + +from aputils.objects import Nodeinfo, WellKnownNodeinfo +from pathlib import Path + +from .base import View, register_route + +from .. import __version__ +from ..misc import Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from ..database.connection import Connection + + +VERSION = __version__ + + +if Path(__file__).parent.parent.joinpath('.git').exists(): + try: + commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') + VERSION = f'{__version__} {commit_label}' + + except Exception: + pass + + +# pylint: disable=unused-argument + +@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') +class NodeinfoView(View): + # pylint: disable=no-self-use + async def get(self, request: Request, conn: Connection, niversion: str) -> Response: + inboxes = conn.execute('SELECT * FROM inboxes').all() + + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not conn.get_config('whitelist-enabled'), + 'users': 1, + 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} + } + + if niversion == '2.1': + data['repo'] = 'https://git.pleroma.social/pleroma/relay' + + return Response.new(Nodeinfo.new(**data), ctype = 'json') + + +@register_route('/.well-known/nodeinfo') +class WellknownNodeinfoView(View): + async def get(self, request: Request, conn: Connection) -> Response: + data = WellKnownNodeinfo.new_template(self.config.domain) + return Response.new(data, ctype = 'json') From d10c864a00a46ee4e0ad626ae6730da285e9e619 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 9 Feb 2024 16:25:34 -0500 Subject: [PATCH 2/7] add public api views --- relay/application.py | 9 +- relay/data/statements.sql | 63 ++++++- relay/database/__init__.py | 6 +- relay/database/config.py | 3 +- relay/database/connection.py | 59 ++++++ relay/database/schema.py | 24 ++- relay/manage.py | 110 ++++++++++++ relay/misc.py | 2 +- relay/views/__init__.py | 2 +- relay/views/api.py | 338 +++++++++++++++++++++++++++++++++++ relay/views/base.py | 33 +++- requirements.txt | 1 + 12 files changed, 627 insertions(+), 23 deletions(-) create mode 100644 relay/views/api.py diff --git a/relay/application.py b/relay/application.py index 80efed9..20695d7 100644 --- a/relay/application.py +++ b/relay/application.py @@ -20,6 +20,7 @@ from .database import get_database from .http_client import HttpClient from .misc import check_open_port from .views import VIEWS +from .views.api import handle_api_path if typing.TYPE_CHECKING: from collections.abc import Awaitable @@ -35,7 +36,11 @@ class Application(web.Application): DEFAULT: Application = None def __init__(self, cfgpath: str, gunicorn: bool = False): - web.Application.__init__(self) + web.Application.__init__(self, + middlewares = [ + handle_api_path + ] + ) Application.DEFAULT = self @@ -219,6 +224,6 @@ async def main_gunicorn(): except KeyError: logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?') - raise + raise RuntimeError from None return app diff --git a/relay/data/statements.sql b/relay/data/statements.sql index 9bcea41..0f98f64 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -1,47 +1,92 @@ -- name: get-config -SELECT * FROM config WHERE key = :key +SELECT * FROM config WHERE key = :key; -- name: get-config-all -SELECT * FROM config +SELECT * FROM config; -- name: put-config INSERT INTO config (key, value, type) VALUES (:key, :value, :type) ON CONFLICT (key) DO UPDATE SET value = :value -RETURNING * +RETURNING *; -- name: del-config DELETE FROM config -WHERE key = :key +WHERE key = :key; -- name: get-inbox -SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value +SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value; -- name: put-inbox INSERT INTO inboxes (domain, actor, inbox, followid, software, created) VALUES (:domain, :actor, :inbox, :followid, :software, :created) ON CONFLICT (domain) DO UPDATE SET followid = :followid -RETURNING * +RETURNING *; -- name: del-inbox DELETE FROM inboxes -WHERE domain = :value or inbox = :value or actor = :value +WHERE domain = :value or inbox = :value or actor = :value; + + +-- name: get-user +SELECT * FROM users +WHERE username = :value or handle = :value; + + +-- name: get-user-by-token +SELECT * FROM users +WHERE username = ( + SELECT user FROM tokens + WHERE code = :code +); + + +-- name: put-user +INSERT INTO users (username, hash, handle, created) +VALUES (:username, :hash, :handle, :created) +RETURNING *; + + +-- name: del-user +DELETE FROM users +WHERE username = :value or handle = :value; + + +-- name: get-token +SELECT * FROM tokens +WHERE code = :code; + + +-- name: put-token +INSERT INTO tokens (code, user, created) +VALUES (:code, :user, :created) +RETURNING *; + + +-- name: del-token +DELETE FROM tokens +WHERE code = :code; + + +-- name: del-token-user +DELETE FROM tokens +WHERE user = :username -- name: get-software-ban -SELECT * FROM software_bans WHERE name = :name +SELECT * FROM software_bans WHERE name = :name; -- name: put-software-ban INSERT INTO software_bans (name, reason, note, created) VALUES (:name, :reason, :note, :created) -RETURNING * +RETURNING *; -- name: del-software-ban diff --git a/relay/database/__init__.py b/relay/database/__init__.py index c403092..facea97 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -52,14 +52,10 @@ def get_database(config: Config, migrate: bool = True) -> tinysql.Database: if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'): logging.info("Migrating database from version '%i'", schema_ver) - for ver, func in VERSIONS: + for ver, func in VERSIONS.items(): if schema_ver < ver: - conn.begin() - func(conn) - conn.put_config('schema-version', ver) - conn.commit() if (privkey := conn.get_config('private-key')): conn.app.signer = privkey diff --git a/relay/database/config.py b/relay/database/config.py index 8fe3b4c..b69f13e 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -11,8 +11,9 @@ if typing.TYPE_CHECKING: CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { - 'schema-version': ('int', 20240119), + 'schema-version': ('int', 20240206), 'log-level': ('loglevel', logging.LogLevel.INFO), + 'name': ('str', 'ActivityRelay'), 'note': ('str', 'Make a note about your instance here.'), 'private-key': ('str', None), 'whitelist-enabled': ('bool', False) diff --git a/relay/database/connection.py b/relay/database/connection.py index ac5a364..718861d 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -3,8 +3,10 @@ from __future__ import annotations import tinysql import typing +from argon2 import PasswordHasher from datetime import datetime, timezone from urllib.parse import urlparse +from uuid import uuid4 from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize @@ -28,6 +30,10 @@ RELAY_SOFTWARE = [ class Connection(tinysql.Connection): + hasher = PasswordHasher( + encoding = 'utf-8' + ) + @property def app(self) -> Application: return get_app() @@ -162,6 +168,59 @@ class Connection(tinysql.Connection): return cur.modified_row_count == 1 + def get_user(self, value: str) -> Row: + with self.exec_statement('get-user', {'value': value}) as cur: + return cur.one() + + + def get_user_by_token(self, code: str) -> Row: + with self.exec_statement('get-user-by-token', {'code': code}) as cur: + return cur.one() + + + def put_user(self, username: str, password: str, handle: str | None = None) -> Row: + data = { + 'username': username, + 'hash': self.hasher.hash(password), + 'handle': handle, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-user', data) as cur: + return cur.one() + + + def del_user(self, username: str) -> None: + user = self.get_user(username) + + with self.exec_statement('del-user', {'value': user['username']}): + pass + + with self.exec_statement('del-token-user', {'username': user['username']}): + pass + + + def get_token(self, code: str) -> Row: + with self.exec_statement('get-token', {'code': code}) as cur: + return cur.one() + + + def put_token(self, username: str) -> Row: + data = { + 'code': uuid4().hex, + 'user': username, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-token', data) as cur: + return cur.one() + + + def del_token(self, code: str) -> None: + with self.exec_statement('del-token', {'code': code}): + pass + + def get_domain_ban(self, domain: str) -> Row: if domain.startswith('http'): domain = urlparse(domain).netloc diff --git a/relay/database/schema.py b/relay/database/schema.py index 7b6d927..d4c51a4 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -10,7 +10,7 @@ if typing.TYPE_CHECKING: from collections.abc import Callable -VERSIONS: list[Callable] = [] +VERSIONS: dict[int, Callable] = {} TABLES: list[Table] = [ Table( 'config', @@ -45,12 +45,25 @@ TABLES: list[Table] = [ Column('reason', 'text'), Column('note', 'text'), Column('created', 'timestamp', nullable = False) + ), + Table( + 'users', + Column('username', 'text', primary_key = True, unique = True, nullable = False), + Column('hash', 'text', nullable = False), + Column('handle', 'text'), + Column('created', 'timestamp', nullable = False) + ), + Table( + 'tokens', + Column('code', 'text', primary_key = True, unique = True, nullable = False), + Column('user', 'text', nullable = False), + Column('created', 'timestmap', nullable = False) ) ] -def version(func: Callable) -> Callable: - ver = int(func.replace('migrate_', '')) +def migration(func: Callable) -> Callable: + ver = int(func.__name__.replace('migrate_', '')) VERSIONS[ver] = func return func @@ -58,3 +71,8 @@ def version(func: Callable) -> Callable: def migrate_0(conn: Connection) -> None: conn.create_tables(TABLES) conn.put_config('schema-version', get_default_value('schema-version')) + + +@migration +def migrate_20240206(conn: Connection) -> None: + conn.create_tables(TABLES) diff --git a/relay/manage.py b/relay/manage.py index 2a78b6f..039a6f5 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -370,6 +370,116 @@ def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: print(f'{key}: {repr(new_value)}') +@cli.group('user') +def cli_user() -> None: + 'Manage local users' + + +@cli_user.command('list') +@click.pass_context +def cli_user_list(ctx: click.Context) -> None: + 'List all local users' + + click.echo('Users:') + + with ctx.obj.database.connection() as conn: + for user in conn.execute('SELECT * FROM users'): + click.echo(f'- {user["username"]}') + + +@cli_user.command('create') +@click.argument('username') +@click.argument('handle', required = False) +@click.pass_context +def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: + 'Create a new local user' + + with ctx.obj.database.connection() as conn: + if conn.get_user(username): + click.echo(f'User already exists: {username}') + return + + while True: + password = click.prompt('New password', hide_input = True) + + if not password: + click.echo('No password provided') + continue + + password2 = click.prompt('New password again', hide_input = True) + + if password != password2: + click.echo('Passwords do not match') + continue + + break + + conn.put_user(username, password, handle) + + click.echo(f'Created user "{username}"') + + +@cli_user.command('delete') +@click.argument('username') +@click.pass_context +def cli_user_delete(ctx: click.Context, username: str) -> None: + 'Delete a local user' + + with ctx.obj.database.connection() as conn: + if not conn.get_user(username): + click.echo(f'User does not exist: {username}') + return + + conn.del_user(username) + + click.echo(f'Deleted user "{username}"') + + +@cli_user.command('list-tokens') +@click.argument('username') +@click.pass_context +def cli_user_list_tokens(ctx: click.Context, username: str) -> None: + 'List all API tokens for a user' + + click.echo(f'Tokens for "{username}":') + + with ctx.obj.database.connection() as conn: + for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): + click.echo(f'- {token["code"]}') + + +@cli_user.command('create-token') +@click.argument('username') +@click.pass_context +def cli_user_create_token(ctx: click.Context, username: str) -> None: + 'Create a new API token for a user' + + with ctx.obj.database.connection() as conn: + if not (user := conn.get_user(username)): + click.echo(f'User does not exist: {username}') + return + + token = conn.put_token(user['username']) + + click.echo(f'New token for "{username}": {token["code"]}') + + +@cli_user.command('delete-token') +@click.argument('code') +@click.pass_context +def cli_user_delete_token(ctx: click.Context, code: str) -> None: + 'Delete an API token' + + with ctx.obj.database.connection() as conn: + if not (conn.get_token(code)): + click.echo('Token does not exist') + return + + conn.del_token(code) + + click.echo('Deleted token') + + @cli.group('inbox') def cli_inbox() -> None: 'Manage the inboxes in the database' diff --git a/relay/misc.py b/relay/misc.py index 296082b..bb78957 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -199,7 +199,7 @@ class Response(AiohttpResponse): if isinstance(body, bytes): kwargs['body'] = body - elif isinstance(body, dict) and ctype in {'json', 'activity'}: + elif isinstance(body, (dict, list, tuple, set)) and ctype in {'json', 'activity'}: kwargs['text'] = json.dumps(body) else: diff --git a/relay/views/__init__.py b/relay/views/__init__.py index 85c2bd3..6366592 100644 --- a/relay/views/__init__.py +++ b/relay/views/__init__.py @@ -1,4 +1,4 @@ from __future__ import annotations -from . import activitypub, frontend, misc +from . import activitypub, api, frontend, misc from .base import VIEWS diff --git a/relay/views/api.py b/relay/views/api.py new file mode 100644 index 0000000..f75caf4 --- /dev/null +++ b/relay/views/api.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +import typing + +from aiohttp import web +from argon2.exceptions import VerifyMismatchError +from datetime import datetime, timezone + +from .base import View, register_route + +from .. import __version__ +from ..database.config import CONFIG_DEFAULTS +from ..misc import Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from collections.abc import Coroutine + from ..database.connection import Connection + + +CONFIG_IGNORE = ( + 'schema-version', + 'private-key' +) + +CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE} + +PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( + ('GET', '/api/v1/relay'), + ('POST', '/api/v1/token') +) + + +def check_api_path(method: str, path: str) -> bool: + for m, p in PUBLIC_API_PATHS: + if m == method and p == path: + return False + + return path.startswith('/api') + + +@web.middleware +async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response: + try: + request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() + + with request.app.database.connection() as conn: + request['user'] = conn.get_user_by_token(request['token']) + + except (KeyError, ValueError): + request['token'] = None + request['user'] = None + + if check_api_path(request.method, request.path): + if not request['token']: + return Response.new_error(401, 'Missing token', 'json') + + if not request['user']: + return Response.new_error(401, 'Invalid token', 'json') + + return await handler(request) + + +# pylint: disable=no-self-use,unused-argument + +@register_route('/api/v1/token') +class Login(View): + async def get(self, request: Request, conn: Connection) -> Response: + return Response.new({'message': 'Token valid :3'}) + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['username', 'password'], []) + + if isinstance(data, Response): + return data + + if not (user := conn.get_user(data['username'])): + return Response.new_error(401, 'User not found', 'json') + + try: + conn.hasher.verify(user['hash'], data['password']) + + except VerifyMismatchError: + return Response.new_error(401, 'Invalid password', 'json') + + token = conn.put_token(data['username']) + return Response.new({'token': token['code']}, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection) -> Response: + conn.del_token(request['token']) + return Response.new({'message': 'Token revoked'}, ctype = 'json') + + +@register_route('/api/v1/relay') +class RelayInfo(View): + async def get(self, request: Request, conn: Connection) -> Response: + config = conn.get_config_all() + inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] + + data = { + 'domain': self.config.domain, + 'name': config['name'], + 'description': config['note'], + 'version': __version__, + 'whitelist_enabled': config['whitelist-enabled'], + 'email': None, + 'admin': None, + 'icon': None, + 'instances': inboxes + } + + return Response.new(data, ctype = 'json') + + +@register_route('/api/v1/config') +class Config(View): + async def get(self, request: Request, conn: Connection) -> Response: + data = conn.get_config_all() + data['log-level'] = data['log-level'].name + + for key in CONFIG_IGNORE: + del data[key] + + return Response.new(data, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['key', 'value'], []) + + if isinstance(data, Response): + return data + + if data['key'] not in CONFIG_VALID: + return Response.new_error(400, 'Invalid key', 'json') + + conn.put_config(data['key'], data['value']) + return Response.new({'message': 'Updated config'}, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['key'], []) + + if isinstance(data, Response): + return data + + if data['key'] not in CONFIG_VALID: + return Response.new_error(400, 'Invalid key', 'json') + + conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + return Response.new({'message': 'Updated config'}, ctype = 'json') + + +@register_route('/api/v1/inbox') +class Inbox(View): + async def get(self, request: Request, conn: Connection) -> Response: + data = [] + + for inbox in conn.execute('SELECT * FROM inboxes'): + try: + created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) + + except TypeError: + created = datetime.fromisoformat(inbox['created']) + + inbox['created'] = created.isoformat() + data.append(inbox) + + return Response.new(data, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['domain', 'inbox', 'actor'], ['software']) + + if isinstance(data, Response): + return data + + if conn.get_inbox(data['domain']): + return Response.new_error(404, 'Inbox already in database', 'json') + + row = conn.put_inbox(**data) + return Response.new(row.to_json(), ctype = 'json') + + +@register_route('/api/v1/inbox/{domain}') +class InboxSingle(View): + async def get(self, request: Request, conn: Connection, domain: str) -> Response: + if not (row := conn.get_inbox(domain)): + return Response.new_error(404, 'Inbox with domain not found', 'json') + + row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() + return Response.new(row, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Inbox with domain not found', 'json') + + conn.del_inbox(domain) + return Response.new({'message': 'Deleted inbox'}, ctype = 'json') + + +@register_route('/api/v1/domain_ban') +class DomainBan(View): + async def get(self, request: Request, conn: Connection) -> Response: + bans = conn.execute('SELECT * FROM domain_bans').all() + return Response.new(bans, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['domain'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + if conn.get_domain_ban(data['domain']): + return Response.new_error(400, 'Domain already banned', 'json') + + ban = conn.put_domain_ban(**data) + return Response.new(ban, ctype = 'json') + + +@register_route('/api/v1/domain_ban/{domain}') +class DomainBanSingle(View): + async def get(self, request: Request, conn: Connection, domain: str) -> Response: + if not (ban := conn.get_domain_ban(domain)): + return Response.new_error(404, 'Domain ban not found', 'json') + + return Response.new(ban, ctype = 'json') + + + async def patch(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') + + data = await self.get_data(['domain'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + ban = conn.update_domain_ban(**data) + return Response.new(ban, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') + + conn.del_domain_ban(domain) + return Response.new({'message': 'Unbanned domain'}, ctype = 'json') + + +@register_route('/api/v1/software_ban') +class SoftwareBan(View): + async def get(self, request: Request, conn: Connection) -> Response: + bans = conn.execute('SELECT * FROM software_bans').all() + return Response.new(bans, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['name'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + if conn.get_software_ban(data['name']): + return Response.new_error(400, 'Domain already banned', 'json') + + ban = conn.put_software_ban(**data) + return Response.new(ban, ctype = 'json') + + +@register_route('/api/v1/software_ban/{name}') +class SoftwareBanSingle(View): + async def get(self, request: Request, conn: Connection, name: str) -> Response: + if not (ban := conn.get_software_ban(name)): + return Response.new_error(404, 'Software ban not found', 'json') + + return Response.new(ban, ctype = 'json') + + + async def post(self, request: Request, conn: Connection, name: str) -> Response: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') + + data = await self.get_data(['name'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + ban = conn.update_software_ban(**data) + return Response.new(ban, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_software_ban(domain): + return Response.new_error(404, 'Software not banned', 'json') + + conn.del_software_ban(domain) + return Response.new({'message': 'Unbanned software'}, ctype = 'json') + + +@register_route('/api/v1/whitelist') +class Whitelist(View): + async def get(self, request: Request, conn: Connection) -> Response: + items = conn.execute('SELECT * FROM whitelist').all() + return Response.new(items, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['domain']) + + if isinstance(data, Response): + return data + + if conn.get_domain_whitelist(data['domain']): + return Response.new_error(400, 'Domain already added to whitelist', 'json') + + item = conn.put_domain_whitelist(**data) + return Response.new(item, ctype = 'json') + + +@register_route('/api/v1/domain/{domain}') +class WhitelistSingle(View): + async def get(self, request: Request, conn: Connection, domain: str) -> Response: + if not (item := conn.get_domain_whitelist(domain)): + return Response.new_error(404, 'Domain not in whitelist', 'json') + + return Response.new(item, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_domain_whitelist(domain): + return Response.new_error(404, 'Domain not in whitelist', 'json') + + conn.del_domain_whitelist(domain) + return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 95b6562..53d6fd5 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -6,6 +6,9 @@ from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.web import HTTPMethodNotAllowed from functools import cached_property +from json.decoder import JSONDecodeError + +from ..misc import Response if typing.TYPE_CHECKING: from collections.abc import Callable, Coroutine, Generator @@ -14,7 +17,6 @@ if typing.TYPE_CHECKING: from ..cache import Cache from ..config import Config from ..http_client import HttpClient - from ..misc import Response VIEWS = [] @@ -91,3 +93,32 @@ class View(AbstractView): @property def database(self) -> Database: return self.app.database + + + async def get_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: + if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: + post_data = await self.request.post() + + elif self.request.content_type == 'application/json': + try: + post_data = await self.request.json() + + except JSONDecodeError: + return Response.new_error(400, 'Invalid JSON data', 'json') + + else: + post_data = self.request.query + + data = {} + + try: + for key in required: + data[key] = post_data[key] + + except KeyError as e: + return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') + + for key in optional: + data[key] = post_data.get(key) + + return data diff --git a/requirements.txt b/requirements.txt index 5239cb8..e7c73fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiohttp>=3.9.1 aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz +argon2-cffi==23.1.0 click>=8.1.2 gunicorn==21.1.0 hiredis==2.3.2 From bd4790212e5b5e120f3a0323c34182961fc35caa Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sat, 10 Feb 2024 16:14:44 -0500 Subject: [PATCH 3/7] changes to api and update tinysql * rename `View.get_data` to `View.get_api_data` * normally aquire database connection on request * rename `/api/v1/inbox` to `/api/v1/instance` * rework `POST /api/v1/instance` * add `Connection.transaction` calls * add `PATCH /api/v1/instance/{domain}` --- relay/views/api.py | 132 ++++++++++++++++++++++++++++++++------------ relay/views/base.py | 7 +-- requirements.txt | 2 +- 3 files changed, 101 insertions(+), 40 deletions(-) diff --git a/relay/views/api.py b/relay/views/api.py index f75caf4..2fcbc6c 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -5,12 +5,14 @@ import typing from aiohttp import web from argon2.exceptions import VerifyMismatchError from datetime import datetime, timezone +from urllib.parse import urlparse from .base import View, register_route from .. import __version__ +from .. import logger as logging from ..database.config import CONFIG_DEFAULTS -from ..misc import Response +from ..misc import Message, Response if typing.TYPE_CHECKING: from aiohttp.web import Request @@ -27,14 +29,14 @@ CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE} PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( ('GET', '/api/v1/relay'), + ('GET', '/api/v1/instance'), ('POST', '/api/v1/token') ) def check_api_path(method: str, path: str) -> bool: - for m, p in PUBLIC_API_PATHS: - if m == method and p == path: - return False + if (method, path) in PUBLIC_API_PATHS: + return False return path.startswith('/api') @@ -70,7 +72,7 @@ class Login(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['username', 'password'], []) + data = await self.get_api_data(['username', 'password'], []) if isinstance(data, Response): return data @@ -84,12 +86,16 @@ class Login(View): except VerifyMismatchError: return Response.new_error(401, 'Invalid password', 'json') - token = conn.put_token(data['username']) + with conn.transaction(): + token = conn.put_token(data['username']) + return Response.new({'token': token['code']}, ctype = 'json') async def delete(self, request: Request, conn: Connection) -> Response: - conn.del_token(request['token']) + with conn.transaction(): + conn.del_token(request['token']) + return Response.new({'message': 'Token revoked'}, ctype = 'json') @@ -127,7 +133,7 @@ class Config(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['key', 'value'], []) + data = await self.get_api_data(['key', 'value'], []) if isinstance(data, Response): return data @@ -135,12 +141,14 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - conn.put_config(data['key'], data['value']) + with conn.transaction(): + conn.put_config(data['key'], data['value']) + return Response.new({'message': 'Updated config'}, ctype = 'json') async def delete(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['key'], []) + data = await self.get_api_data(['key'], []) if isinstance(data, Response): return data @@ -148,11 +156,13 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + with conn.transaction(): + conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + return Response.new({'message': 'Updated config'}, ctype = 'json') -@register_route('/api/v1/inbox') +@register_route('/api/v1/instance') class Inbox(View): async def get(self, request: Request, conn: Connection) -> Response: data = [] @@ -171,34 +181,72 @@ class Inbox(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['domain', 'inbox', 'actor'], ['software']) + data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) if isinstance(data, Response): return data + data['domain'] = urlparse(data["actor"]).netloc + if conn.get_inbox(data['domain']): - return Response.new_error(404, 'Inbox already in database', 'json') + return Response.new_error(404, 'Instance already in database', 'json') - row = conn.put_inbox(**data) - return Response.new(row.to_json(), ctype = 'json') + if not data.get('inbox'): + try: + actor_data = await self.client.get( + data['actor'], + sign_headers = True, + loads = Message.parse + ) + + data['inbox'] = actor_data.shared_inbox + + except Exception as e: + logging.error('Failed to fetch actor: %s', str(e)) + return Response.new_error(500, 'Failed to fetch actor', 'json') + + with conn.transaction(): + row = conn.put_inbox(**data) + + return Response.new(row, ctype = 'json') -@register_route('/api/v1/inbox/{domain}') +@register_route('/api/v1/instance/{domain}') class InboxSingle(View): async def get(self, request: Request, conn: Connection, domain: str) -> Response: if not (row := conn.get_inbox(domain)): - return Response.new_error(404, 'Inbox with domain not found', 'json') + return Response.new_error(404, 'Instance with domain not found', 'json') row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() return Response.new(row, ctype = 'json') + async def patch(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') + + data = await self.get_api_data([], ['actor', 'software', 'followid']) + + if isinstance(data, Response): + return data + + if not (instance := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') + + with conn.transaction(): + instance = conn.update_inbox(instance['inbox'], **data) + + return Response.new(instance, ctype = 'json') + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: if not conn.get_inbox(domain): - return Response.new_error(404, 'Inbox with domain not found', 'json') + return Response.new_error(404, 'Instance with domain not found', 'json') - conn.del_inbox(domain) - return Response.new({'message': 'Deleted inbox'}, ctype = 'json') + with conn.transaction(): + conn.del_inbox(domain) + + return Response.new({'message': 'Deleted instance'}, ctype = 'json') @register_route('/api/v1/domain_ban') @@ -209,7 +257,7 @@ class DomainBan(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['domain'], ['note', 'reason']) + data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): return data @@ -217,7 +265,9 @@ class DomainBan(View): if conn.get_domain_ban(data['domain']): return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_domain_ban(**data) + with conn.transaction(): + ban = conn.put_domain_ban(**data) + return Response.new(ban, ctype = 'json') @@ -234,12 +284,14 @@ class DomainBanSingle(View): if not conn.get_domain_ban(domain): return Response.new_error(404, 'Domain not banned', 'json') - data = await self.get_data(['domain'], ['note', 'reason']) + data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): return data - ban = conn.update_domain_ban(**data) + with conn.transaction(): + ban = conn.update_domain_ban(**data) + return Response.new(ban, ctype = 'json') @@ -247,7 +299,9 @@ class DomainBanSingle(View): if not conn.get_domain_ban(domain): return Response.new_error(404, 'Domain not banned', 'json') - conn.del_domain_ban(domain) + with conn.transaction(): + conn.del_domain_ban(domain) + return Response.new({'message': 'Unbanned domain'}, ctype = 'json') @@ -259,7 +313,7 @@ class SoftwareBan(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['name'], ['note', 'reason']) + data = await self.get_api_data(['name'], ['note', 'reason']) if isinstance(data, Response): return data @@ -267,7 +321,9 @@ class SoftwareBan(View): if conn.get_software_ban(data['name']): return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_software_ban(**data) + with conn.transaction(): + ban = conn.put_software_ban(**data) + return Response.new(ban, ctype = 'json') @@ -284,12 +340,14 @@ class SoftwareBanSingle(View): if not conn.get_software_ban(name): return Response.new_error(404, 'Software not banned', 'json') - data = await self.get_data(['name'], ['note', 'reason']) + data = await self.get_api_data(['name'], ['note', 'reason']) if isinstance(data, Response): return data - ban = conn.update_software_ban(**data) + with conn.transaction(): + ban = conn.update_software_ban(**data) + return Response.new(ban, ctype = 'json') @@ -297,7 +355,9 @@ class SoftwareBanSingle(View): if not conn.get_software_ban(domain): return Response.new_error(404, 'Software not banned', 'json') - conn.del_software_ban(domain) + with conn.transaction(): + conn.del_software_ban(domain) + return Response.new({'message': 'Unbanned software'}, ctype = 'json') @@ -309,7 +369,7 @@ class Whitelist(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['domain']) + data = await self.get_api_data(['domain']) if isinstance(data, Response): return data @@ -317,7 +377,9 @@ class Whitelist(View): if conn.get_domain_whitelist(data['domain']): return Response.new_error(400, 'Domain already added to whitelist', 'json') - item = conn.put_domain_whitelist(**data) + with conn.transaction(): + item = conn.put_domain_whitelist(**data) + return Response.new(item, ctype = 'json') @@ -334,5 +396,7 @@ class WhitelistSingle(View): if not conn.get_domain_whitelist(domain): return Response.new_error(404, 'Domain not in whitelist', 'json') - conn.del_domain_whitelist(domain) + with conn.transaction(): + conn.del_domain_whitelist(domain) + return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 53d6fd5..093fcb7 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -43,10 +43,7 @@ class View(AbstractView): async def _run_handler(self, handler: Coroutine) -> Response: - with self.database.config.connection_class(self.database) as conn: - # todo: remove on next tinysql release - conn.open() - + with self.database.connection(False) as conn: return await handler(self.request, conn, **self.request.match_info) @@ -95,7 +92,7 @@ class View(AbstractView): return self.app.database - async def get_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: + async def get_api_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: post_data = await self.request.post() diff --git a/requirements.txt b/requirements.txt index e7c73fa..16bd9dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,6 @@ gunicorn==21.1.0 hiredis==2.3.2 pyyaml>=6.0 redis==5.0.1 -tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz +tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/f8db814084dded0a46bd3a9576e09fca860f2166.tar.gz importlib_resources==6.1.1;python_version<'3.9' From 101c668173844e55b7ccdd8e24815bbec2938bad Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sun, 11 Feb 2024 15:59:59 -0500 Subject: [PATCH 4/7] fix a few api endpoints --- relay/views/api.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/relay/views/api.py b/relay/views/api.py index 2fcbc6c..56b5ed7 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -284,13 +284,16 @@ class DomainBanSingle(View): if not conn.get_domain_ban(domain): return Response.new_error(404, 'Domain not banned', 'json') - data = await self.get_api_data(['domain'], ['note', 'reason']) + data = await self.get_api_data([], ['note', 'reason']) if isinstance(data, Response): return data + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + with conn.transaction(): - ban = conn.update_domain_ban(**data) + ban = conn.update_domain_ban(domain, **data) return Response.new(ban, ctype = 'json') @@ -336,27 +339,30 @@ class SoftwareBanSingle(View): return Response.new(ban, ctype = 'json') - async def post(self, request: Request, conn: Connection, name: str) -> Response: + async def patch(self, request: Request, conn: Connection, name: str) -> Response: if not conn.get_software_ban(name): return Response.new_error(404, 'Software not banned', 'json') - data = await self.get_api_data(['name'], ['note', 'reason']) + data = await self.get_api_data([], ['note', 'reason']) if isinstance(data, Response): return data + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + with conn.transaction(): - ban = conn.update_software_ban(**data) + ban = conn.update_software_ban(name, **data) return Response.new(ban, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_software_ban(domain): + async def delete(self, request: Request, conn: Connection, name: str) -> Response: + if not conn.get_software_ban(name): return Response.new_error(404, 'Software not banned', 'json') with conn.transaction(): - conn.del_software_ban(domain) + conn.del_software_ban(name) return Response.new({'message': 'Unbanned software'}, ctype = 'json') From a644900417a089e9aab62a652c36b5f71ccc9ed9 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 14 Feb 2024 12:29:56 -0500 Subject: [PATCH 5/7] add a semi-colon to the end of all statements --- relay/data/statements.sql | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/relay/data/statements.sql b/relay/data/statements.sql index 0f98f64..2ddef35 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -76,7 +76,7 @@ WHERE code = :code; -- name: del-token-user DELETE FROM tokens -WHERE user = :username +WHERE user = :username; -- name: get-software-ban @@ -91,37 +91,37 @@ RETURNING *; -- name: del-software-ban DELETE FROM software_bans -WHERE name = :name +WHERE name = :name; -- name: get-domain-ban -SELECT * FROM domain_bans WHERE domain = :domain +SELECT * FROM domain_bans WHERE domain = :domain; -- name: put-domain-ban INSERT INTO domain_bans (domain, reason, note, created) VALUES (:domain, :reason, :note, :created) -RETURNING * +RETURNING *; -- name: del-domain-ban DELETE FROM domain_bans -WHERE domain = :domain +WHERE domain = :domain; -- name: get-domain-whitelist -SELECT * FROM whitelist WHERE domain = :domain +SELECT * FROM whitelist WHERE domain = :domain; -- name: put-domain-whitelist INSERT INTO whitelist (domain, created) VALUES (:domain, :created) -RETURNING * +RETURNING *; -- name: del-domain-whitelist DELETE FROM whitelist -WHERE domain = :domain +WHERE domain = :domain; -- cache functions -- @@ -135,7 +135,7 @@ CREATE TABLE IF NOT EXISTS cache ( type TEXT DEFAULT 'str', updated TIMESTAMP NOT NULL, UNIQUE(namespace, key) -) +); -- name: create-cache-table-postgres CREATE TABLE IF NOT EXISTS cache ( @@ -146,21 +146,21 @@ CREATE TABLE IF NOT EXISTS cache ( type TEXT DEFAULT 'str', updated TIMESTAMP NOT NULL, UNIQUE(namespace, key) -) +); -- name: get-cache-item SELECT * FROM cache -WHERE namespace = :namespace and key = :key +WHERE namespace = :namespace and key = :key; -- name: get-cache-keys SELECT key FROM cache -WHERE namespace = :namespace +WHERE namespace = :namespace; -- name: get-cache-namespaces -SELECT DISTINCT namespace FROM cache +SELECT DISTINCT namespace FROM cache; -- name: set-cache-item @@ -168,18 +168,18 @@ INSERT INTO cache (namespace, key, value, type, updated) VALUES (:namespace, :key, :value, :type, :date) ON CONFLICT (namespace, key) DO UPDATE SET value = :value, type = :type, updated = :date -RETURNING * +RETURNING *; -- name: del-cache-item DELETE FROM cache -WHERE namespace = :namespace and key = :key +WHERE namespace = :namespace and key = :key; -- name: del-cache-namespace DELETE FROM cache -WHERE namespace = :namespace +WHERE namespace = :namespace; -- name: del-cache-all -DELETE FROM cache +DELETE FROM cache; From e4bcbdeccb7f1968c78b938e5922f9f99980fc76 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 14 Feb 2024 14:17:53 -0500 Subject: [PATCH 6/7] don't get a database connection at the start of every request --- relay/processors.py | 10 +- relay/views/activitypub.py | 43 +++---- relay/views/api.py | 249 +++++++++++++++++++------------------ relay/views/base.py | 8 +- relay/views/frontend.py | 7 +- relay/views/misc.py | 24 ++-- 6 files changed, 180 insertions(+), 161 deletions(-) diff --git a/relay/processors.py b/relay/processors.py index d9780d1..7fc3423 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -170,7 +170,7 @@ processors = { } -async def run_processor(view: ActorView, conn: Connection) -> None: +async def run_processor(view: ActorView) -> None: if view.message.type not in processors: logging.verbose( 'Message type "%s" from actor cannot be handled: %s', @@ -180,8 +180,8 @@ async def run_processor(view: ActorView, conn: Connection) -> None: return - if view.instance: - with conn.transaction(): + with view.database.connection(False) as conn: + if view.instance: if not view.instance['software']: if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): view.instance = conn.update_inbox( @@ -195,5 +195,5 @@ async def run_processor(view: ActorView, conn: Connection) -> None: actor = view.actor.id ) - logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) - await processors[view.message.type](view, conn) + logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) + await processors[view.message.type](view, conn) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index be51047..70f759a 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -34,7 +34,7 @@ class ActorView(View): self.signer: Signer = None - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = Message.new_actor( host = self.config.domain, pubkey = self.app.signer.pubkey @@ -43,35 +43,36 @@ class ActorView(View): return Response.new(data, ctype='activity') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: if response := await self.get_post_data(): return response - self.instance = conn.get_inbox(self.actor.shared_inbox) - config = conn.get_config_all() + with self.database.connection(False) as conn: + self.instance = conn.get_inbox(self.actor.shared_inbox) + config = conn.get_config_all() - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): - logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if the actor isn't whitelisted while the whiltelist is enabled + if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): + logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if actor is banned - if conn.get_domain_ban(self.actor.domain): - logging.verbose('Ignored request from banned actor: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if actor is banned + if conn.get_domain_ban(self.actor.domain): + logging.verbose('Ignored request from banned actor: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if activity type isn't 'Follow' and the actor isn't following - if self.message.type != 'Follow' and not self.instance: - logging.verbose( - 'Rejected actor for trying to post while not following: %s', - self.actor.id - ) + ## reject if activity type isn't 'Follow' and the actor isn't following + if self.message.type != 'Follow' and not self.instance: + logging.verbose( + 'Rejected actor for trying to post while not following: %s', + self.actor.id + ) - return Response.new_error(401, 'access denied', 'json') + return Response.new_error(401, 'access denied', 'json') logging.debug('>> payload %s', self.message.to_json(4)) - await run_processor(self, conn) + await run_processor(self) return Response.new(status = 202) @@ -162,7 +163,7 @@ class ActorView(View): @register_route('/.well-known/webfinger') class WebfingerView(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: try: subject = request.query['resource'] diff --git a/relay/views/api.py b/relay/views/api.py index 56b5ed7..ade1d3b 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -67,33 +67,33 @@ async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Respo @register_route('/api/v1/token') class Login(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: return Response.new({'message': 'Token valid :3'}) - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['username', 'password'], []) if isinstance(data, Response): return data - if not (user := conn.get_user(data['username'])): - return Response.new_error(401, 'User not found', 'json') + with self.database.connction(True) as conn: + if not (user := conn.get_user(data['username'])): + return Response.new_error(401, 'User not found', 'json') - try: - conn.hasher.verify(user['hash'], data['password']) + try: + conn.hasher.verify(user['hash'], data['password']) - except VerifyMismatchError: - return Response.new_error(401, 'Invalid password', 'json') + except VerifyMismatchError: + return Response.new_error(401, 'Invalid password', 'json') - with conn.transaction(): token = conn.put_token(data['username']) return Response.new({'token': token['code']}, ctype = 'json') - async def delete(self, request: Request, conn: Connection) -> Response: - with conn.transaction(): + async def delete(self, request: Request) -> Response: + with self.database.connection(True) as conn: conn.del_token(request['token']) return Response.new({'message': 'Token revoked'}, ctype = 'json') @@ -101,9 +101,10 @@ class Login(View): @register_route('/api/v1/relay') class RelayInfo(View): - async def get(self, request: Request, conn: Connection) -> Response: - config = conn.get_config_all() - inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + config = conn.get_config_all() + inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] data = { 'domain': self.config.domain, @@ -122,9 +123,10 @@ class RelayInfo(View): @register_route('/api/v1/config') class Config(View): - async def get(self, request: Request, conn: Connection) -> Response: - data = conn.get_config_all() - data['log-level'] = data['log-level'].name + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + data = conn.get_config_all() + data['log-level'] = data['log-level'].name for key in CONFIG_IGNORE: del data[key] @@ -132,7 +134,7 @@ class Config(View): return Response.new(data, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['key', 'value'], []) if isinstance(data, Response): @@ -141,13 +143,13 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - with conn.transaction(): + with self.database.connection(True) as conn: conn.put_config(data['key'], data['value']) return Response.new({'message': 'Updated config'}, ctype = 'json') - async def delete(self, request: Request, conn: Connection) -> Response: + async def delete(self, request: Request) -> Response: data = await self.get_api_data(['key'], []) if isinstance(data, Response): @@ -156,7 +158,7 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - with conn.transaction(): + with self.database.connection(True) as conn: conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -164,23 +166,24 @@ class Config(View): @register_route('/api/v1/instance') class Inbox(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = [] - for inbox in conn.execute('SELECT * FROM inboxes'): - try: - created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) + with self.database.connection(False) as conn: + for inbox in conn.execute('SELECT * FROM inboxes'): + try: + created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) - except TypeError: - created = datetime.fromisoformat(inbox['created']) + except TypeError: + created = datetime.fromisoformat(inbox['created']) - inbox['created'] = created.isoformat() - data.append(inbox) + inbox['created'] = created.isoformat() + data.append(inbox) return Response.new(data, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) if isinstance(data, Response): @@ -188,24 +191,24 @@ class Inbox(View): data['domain'] = urlparse(data["actor"]).netloc - if conn.get_inbox(data['domain']): - return Response.new_error(404, 'Instance already in database', 'json') + with self.database.connection(True) as conn: + if conn.get_inbox(data['domain']): + return Response.new_error(404, 'Instance already in database', 'json') - if not data.get('inbox'): - try: - actor_data = await self.client.get( - data['actor'], - sign_headers = True, - loads = Message.parse - ) + if not data.get('inbox'): + try: + actor_data = await self.client.get( + data['actor'], + sign_headers = True, + loads = Message.parse + ) - data['inbox'] = actor_data.shared_inbox + data['inbox'] = actor_data.shared_inbox - except Exception as e: - logging.error('Failed to fetch actor: %s', str(e)) - return Response.new_error(500, 'Failed to fetch actor', 'json') + except Exception as e: + logging.error('Failed to fetch actor: %s', str(e)) + return Response.new_error(500, 'Failed to fetch actor', 'json') - with conn.transaction(): row = conn.put_inbox(**data) return Response.new(row, ctype = 'json') @@ -213,37 +216,38 @@ class Inbox(View): @register_route('/api/v1/instance/{domain}') class InboxSingle(View): - async def get(self, request: Request, conn: Connection, domain: str) -> Response: - if not (row := conn.get_inbox(domain)): - return Response.new_error(404, 'Instance with domain not found', 'json') + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (row := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() return Response.new(row, ctype = 'json') - async def patch(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_inbox(domain): - return Response.new_error(404, 'Instance with domain not found', 'json') + async def patch(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') - data = await self.get_api_data([], ['actor', 'software', 'followid']) + data = await self.get_api_data([], ['actor', 'software', 'followid']) - if isinstance(data, Response): - return data + if isinstance(data, Response): + return data - if not (instance := conn.get_inbox(domain)): - return Response.new_error(404, 'Instance with domain not found', 'json') + if not (instance := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') - with conn.transaction(): instance = conn.update_inbox(instance['inbox'], **data) return Response.new(instance, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_inbox(domain): - return Response.new_error(404, 'Instance with domain not found', 'json') + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') - with conn.transaction(): conn.del_inbox(domain) return Response.new({'message': 'Deleted instance'}, ctype = 'json') @@ -251,21 +255,23 @@ class InboxSingle(View): @register_route('/api/v1/domain_ban') class DomainBan(View): - async def get(self, request: Request, conn: Connection) -> Response: - bans = conn.execute('SELECT * FROM domain_bans').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + bans = conn.execute('SELECT * FROM domain_bans').all() + return Response.new(bans, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): return data - if conn.get_domain_ban(data['domain']): - return Response.new_error(400, 'Domain already banned', 'json') + with self.database.connection(True) as conn: + if conn.get_domain_ban(data['domain']): + return Response.new_error(400, 'Domain already banned', 'json') - with conn.transaction(): ban = conn.put_domain_ban(**data) return Response.new(ban, ctype = 'json') @@ -273,36 +279,37 @@ class DomainBan(View): @register_route('/api/v1/domain_ban/{domain}') class DomainBanSingle(View): - async def get(self, request: Request, conn: Connection, domain: str) -> Response: - if not (ban := conn.get_domain_ban(domain)): - return Response.new_error(404, 'Domain ban not found', 'json') + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (ban := conn.get_domain_ban(domain)): + return Response.new_error(404, 'Domain ban not found', 'json') return Response.new(ban, ctype = 'json') - async def patch(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_domain_ban(domain): - return Response.new_error(404, 'Domain not banned', 'json') + async def patch(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') - data = await self.get_api_data([], ['note', 'reason']) + data = await self.get_api_data([], ['note', 'reason']) - if isinstance(data, Response): - return data + if isinstance(data, Response): + return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - with conn.transaction(): ban = conn.update_domain_ban(domain, **data) return Response.new(ban, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_domain_ban(domain): - return Response.new_error(404, 'Domain not banned', 'json') + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') - with conn.transaction(): conn.del_domain_ban(domain) return Response.new({'message': 'Unbanned domain'}, ctype = 'json') @@ -310,21 +317,23 @@ class DomainBanSingle(View): @register_route('/api/v1/software_ban') class SoftwareBan(View): - async def get(self, request: Request, conn: Connection) -> Response: - bans = conn.execute('SELECT * FROM software_bans').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + bans = conn.execute('SELECT * FROM software_bans').all() + return Response.new(bans, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) if isinstance(data, Response): return data - if conn.get_software_ban(data['name']): - return Response.new_error(400, 'Domain already banned', 'json') + with self.database.connection(True) as conn: + if conn.get_software_ban(data['name']): + return Response.new_error(400, 'Domain already banned', 'json') - with conn.transaction(): ban = conn.put_software_ban(**data) return Response.new(ban, ctype = 'json') @@ -332,36 +341,37 @@ class SoftwareBan(View): @register_route('/api/v1/software_ban/{name}') class SoftwareBanSingle(View): - async def get(self, request: Request, conn: Connection, name: str) -> Response: - if not (ban := conn.get_software_ban(name)): - return Response.new_error(404, 'Software ban not found', 'json') + async def get(self, request: Request, name: str) -> Response: + with self.database.connection(False) as conn: + if not (ban := conn.get_software_ban(name)): + return Response.new_error(404, 'Software ban not found', 'json') return Response.new(ban, ctype = 'json') - async def patch(self, request: Request, conn: Connection, name: str) -> Response: - if not conn.get_software_ban(name): - return Response.new_error(404, 'Software not banned', 'json') + async def patch(self, request: Request, name: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') - data = await self.get_api_data([], ['note', 'reason']) + data = await self.get_api_data([], ['note', 'reason']) - if isinstance(data, Response): - return data + if isinstance(data, Response): + return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - with conn.transaction(): ban = conn.update_software_ban(name, **data) return Response.new(ban, ctype = 'json') - async def delete(self, request: Request, conn: Connection, name: str) -> Response: - if not conn.get_software_ban(name): - return Response.new_error(404, 'Software not banned', 'json') + async def delete(self, request: Request, name: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') - with conn.transaction(): conn.del_software_ban(name) return Response.new({'message': 'Unbanned software'}, ctype = 'json') @@ -369,21 +379,23 @@ class SoftwareBanSingle(View): @register_route('/api/v1/whitelist') class Whitelist(View): - async def get(self, request: Request, conn: Connection) -> Response: - items = conn.execute('SELECT * FROM whitelist').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + items = conn.execute('SELECT * FROM whitelist').all() + return Response.new(items, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_api_data(['domain']) + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['domain'], []) if isinstance(data, Response): return data - if conn.get_domain_whitelist(data['domain']): - return Response.new_error(400, 'Domain already added to whitelist', 'json') + with self.database.connection(True) as conn: + if conn.get_domain_whitelist(data['domain']): + return Response.new_error(400, 'Domain already added to whitelist', 'json') - with conn.transaction(): item = conn.put_domain_whitelist(**data) return Response.new(item, ctype = 'json') @@ -391,18 +403,19 @@ class Whitelist(View): @register_route('/api/v1/domain/{domain}') class WhitelistSingle(View): - async def get(self, request: Request, conn: Connection, domain: str) -> Response: - if not (item := conn.get_domain_whitelist(domain)): - return Response.new_error(404, 'Domain not in whitelist', 'json') + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (item := conn.get_domain_whitelist(domain)): + return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new(item, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_domain_whitelist(domain): - return Response.new_error(404, 'Domain not in whitelist', 'json') + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not conn.get_domain_whitelist(domain): + return Response.new_error(404, 'Domain not in whitelist', 'json') - with conn.transaction(): conn.del_domain_whitelist(domain) return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 093fcb7..ce72e4b 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -43,8 +43,7 @@ class View(AbstractView): async def _run_handler(self, handler: Coroutine) -> Response: - with self.database.connection(False) as conn: - return await handler(self.request, conn, **self.request.match_info) + return await handler(self.request, **self.request.match_info) @cached_property @@ -92,7 +91,10 @@ class View(AbstractView): return self.app.database - async def get_api_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: + async def get_api_data(self, + required: list[str], + optional: list[str]) -> dict[str, str] | Response: + if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: post_data = await self.request.post() diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 987b9b0..663edd4 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -48,9 +48,10 @@ HOME_TEMPLATE = """ @register_route('/') class HomeView(View): - async def get(self, request: Request, conn: Connection) -> Response: - config = conn.get_config_all() - inboxes = conn.execute('SELECT * FROM inboxes').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + config = conn.get_config_all() + inboxes = conn.execute('SELECT * FROM inboxes').all() text = HOME_TEMPLATE.format( host = self.config.domain, diff --git a/relay/views/misc.py b/relay/views/misc.py index e41ae2b..7c2e65c 100644 --- a/relay/views/misc.py +++ b/relay/views/misc.py @@ -33,17 +33,18 @@ if Path(__file__).parent.parent.joinpath('.git').exists(): @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') class NodeinfoView(View): # pylint: disable=no-self-use - async def get(self, request: Request, conn: Connection, niversion: str) -> Response: - inboxes = conn.execute('SELECT * FROM inboxes').all() + async def get(self, request: Request, niversion: str) -> Response: + with self.database.connection(False) as conn: + inboxes = conn.execute('SELECT * FROM inboxes').all() - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not conn.get_config('whitelist-enabled'), - 'users': 1, - 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} - } + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not conn.get_config('whitelist-enabled'), + 'users': 1, + 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} + } if niversion == '2.1': data['repo'] = 'https://git.pleroma.social/pleroma/relay' @@ -53,6 +54,7 @@ class NodeinfoView(View): @register_route('/.well-known/nodeinfo') class WellknownNodeinfoView(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = WellKnownNodeinfo.new_template(self.config.domain) + return Response.new(data, ctype = 'json') From c2b88b6dd8b563a1ca803660404d6f241519d836 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 14 Feb 2024 14:24:16 -0500 Subject: [PATCH 7/7] fix linter warnings --- relay/manage.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/relay/manage.py b/relay/manage.py index 039a6f5..e8aab1b 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -400,15 +400,11 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: return while True: - password = click.prompt('New password', hide_input = True) - - if not password: + if not (password := click.prompt('New password', hide_input = True)): click.echo('No password provided') continue - password2 = click.prompt('New password again', hide_input = True) - - if password != password2: + if password != click.prompt('New password again', hide_input = True): click.echo('Passwords do not match') continue @@ -471,7 +467,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None: 'Delete an API token' with ctx.obj.database.connection() as conn: - if not (conn.get_token(code)): + if not conn.get_token(code): click.echo('Token does not exist') return