From 2fcaea85aead0c2cdf19ba20270c461538706567 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sun, 4 Feb 2024 04:53:39 -0500 Subject: [PATCH] create a new database connection for each request --- relay/manage.py | 2 +- relay/misc.py | 13 +++- relay/processors.py | 143 ++++++++++++++++++++++---------------------- relay/views.py | 92 ++++++++++++++-------------- 4 files changed, 128 insertions(+), 122 deletions(-) diff --git a/relay/manage.py b/relay/manage.py index 6872026..df5b4cb 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -53,7 +53,7 @@ SOFTWARE = ( def check_alphanumeric(text: str) -> str: if not text.isalnum(): - raise click.BadParameter(f'String not alphanumeric') + raise click.BadParameter('String not alphanumeric') return text diff --git a/relay/misc.py b/relay/misc.py index 831db38..94a7182 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -15,7 +15,8 @@ from uuid import uuid4 if typing.TYPE_CHECKING: from collections.abc import Coroutine, Generator - from typing import Any + from tinysql import Connection + from typing import Any, Awaitable from .application import Application from .cache import Cache from .config import Config @@ -234,6 +235,9 @@ class Response(AiohttpResponse): class View(AbstractView): + conn: Connection + + def __await__(self) -> Generator[Response]: if (self.request.method) not in METHODS: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) @@ -241,7 +245,12 @@ class View(AbstractView): if not (handler := self.handlers.get(self.request.method)): raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None - return handler(self.request, **self.request.match_info).__await__() + return self._run_handler(handler).__await__() + + + async def _run_handler(self, handler: Awaitable) -> Response: + with self.database.config.connection_class(self.database) as conn: + return await handler(self.request,conn, **self.request.match_info) @cached_property diff --git a/relay/processors.py b/relay/processors.py index 36ece44..d9780d1 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -1,9 +1,9 @@ from __future__ import annotations -import tinysql import typing from . import logger as logging +from .database.connection import Connection from .misc import Message if typing.TYPE_CHECKING: @@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool: return False -async def handle_relay(view: ActorView) -> None: +async def handle_relay(view: ActorView, conn: Connection) -> None: try: view.cache.get('handle-relay', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id) @@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None: pass message = Message.new_announce(view.config.domain, view.message.object_id) - view.cache.set('handle-relay', view.message.object_id, message.id, 'str') logging.debug('>> relay: %s', message) - with view.database.connection() as conn: - for inbox in conn.distill_inboxes(view.message): - view.app.push_message(inbox, message, view.instance) + for inbox in conn.distill_inboxes(view.message): + view.app.push_message(inbox, message, view.instance) + + view.cache.set('handle-relay', view.message.object_id, message.id, 'str') -async def handle_forward(view: ActorView) -> None: +async def handle_forward(view: ActorView, conn: Connection) -> None: try: view.cache.get('handle-relay', view.message.object_id) logging.verbose('already forwarded %s', view.message.object_id) @@ -51,58 +51,58 @@ async def handle_forward(view: ActorView) -> None: pass message = Message.new_announce(view.config.domain, view.message) - view.cache.set('handle-relay', view.message.object_id, message.id, 'str') logging.debug('>> forward: %s', message) - with view.database.connection() as conn: - for inbox in conn.distill_inboxes(view.message): - view.app.push_message(inbox, message, view.instance) + for inbox in conn.distill_inboxes(view.message): + view.app.push_message(inbox, message, view.instance) + + view.cache.set('handle-relay', view.message.object_id, message.id, 'str') -async def handle_follow(view: ActorView) -> None: +async def handle_follow(view: ActorView, conn: Connection) -> None: nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) software = nodeinfo.sw_name if nodeinfo else None - with view.database.connection() as conn: - # reject if software used by actor is banned - if conn.get_software_ban(software): - view.app.push_message( - view.actor.shared_inbox, - Message.new_response( - host = view.config.domain, - actor = view.actor.id, - followid = view.message.id, - accept = False - ) + # reject if software used by actor is banned + if conn.get_software_ban(software): + view.app.push_message( + view.actor.shared_inbox, + Message.new_response( + host = view.config.domain, + actor = view.actor.id, + followid = view.message.id, + accept = False ) + ) - logging.verbose( - 'Rejected follow from actor for using specific software: actor=%s, software=%s', - view.actor.id, - software + logging.verbose( + 'Rejected follow from actor for using specific software: actor=%s, software=%s', + view.actor.id, + software + ) + + return + + ## reject if the actor is not an instance actor + if person_check(view.actor, software): + view.app.push_message( + view.actor.shared_inbox, + Message.new_response( + host = view.config.domain, + actor = view.actor.id, + followid = view.message.id, + accept = False ) + ) - return + logging.verbose('Non-application actor tried to follow: %s', view.actor.id) + return - ## reject if the actor is not an instance actor - if person_check(view.actor, software): - view.app.push_message( - view.actor.shared_inbox, - Message.new_response( - host = view.config.domain, - actor = view.actor.id, - followid = view.message.id, - accept = False - ) - ) + if conn.get_inbox(view.actor.shared_inbox): + view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id) - logging.verbose('Non-application actor tried to follow: %s', view.actor.id) - return - - if conn.get_inbox(view.actor.shared_inbox): - view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id) - - else: + else: + with conn.transaction(): view.instance = conn.put_inbox( view.actor.domain, view.actor.shared_inbox, @@ -111,37 +111,37 @@ async def handle_follow(view: ActorView) -> None: software ) + view.app.push_message( + view.actor.shared_inbox, + Message.new_response( + host = view.config.domain, + actor = view.actor.id, + followid = view.message.id, + accept = True + ), + view.instance + ) + + # Are Akkoma and Pleroma the only two that expect a follow back? + # Ignoring only Mastodon for now + if software != 'mastodon': view.app.push_message( view.actor.shared_inbox, - Message.new_response( + Message.new_follow( host = view.config.domain, - actor = view.actor.id, - followid = view.message.id, - accept = True + actor = view.actor.id ), view.instance ) - # Are Akkoma and Pleroma the only two that expect a follow back? - # Ignoring only Mastodon for now - if software != 'mastodon': - view.app.push_message( - view.actor.shared_inbox, - Message.new_follow( - host = view.config.domain, - actor = view.actor.id - ), - view.instance - ) - -async def handle_undo(view: ActorView) -> None: +async def handle_undo(view: ActorView, conn: Connection) -> None: ## If the object is not a Follow, forward it if view.message.object['type'] != 'Follow': - await handle_forward(view) + await handle_forward(view, conn) return - with view.database.connection() as conn: + with conn.transaction(): if not conn.del_inbox(view.actor.id): logging.verbose( 'Failed to delete "%s" with follow ID "%s"', @@ -170,7 +170,7 @@ processors = { } -async def run_processor(view: ActorView) -> None: +async def run_processor(view: ActorView, conn: Connection) -> None: if view.message.type not in processors: logging.verbose( 'Message type "%s" from actor cannot be handled: %s', @@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None: return if view.instance: - if not view.instance['software']: - if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): - with view.database.connection() as conn: + with conn.transaction(): + if not view.instance['software']: + if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): view.instance = conn.update_inbox( view.instance['inbox'], software = nodeinfo.sw_name ) - if not view.instance['actor']: - with view.database.connection() as conn: + if not view.instance['actor']: view.instance = conn.update_inbox( view.instance['inbox'], actor = view.actor.id ) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) - await processors[view.message.type](view) + await processors[view.message.type](view, conn) diff --git a/relay/views.py b/relay/views.py index 64c1d57..cb648a2 100644 --- a/relay/views.py +++ b/relay/views.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import subprocess import traceback import typing @@ -12,6 +11,7 @@ from pathlib import Path from . import __version__ from . import logger as logging +from .database.connection import Connection from .misc import Message, Response, View from .processors import run_processor @@ -75,17 +75,16 @@ def register_route(*paths: str) -> Callable: @register_route('/') class HomeView(View): - async def get(self, request: Request) -> Response: - with self.database.connection() as conn: - config = conn.get_config_all() - inboxes = conn.execute('SELECT * FROM inboxes').all() + async def get(self, request: Request, conn: Connection) -> Response: + config = conn.get_config_all() + inboxes = conn.execute('SELECT * FROM inboxes').all() - text = HOME_TEMPLATE.format( - host = self.config.domain, - note = config['note'], - count = len(inboxes), - targets = '
'.join(inbox['domain'] for inbox in inboxes) - ) + text = HOME_TEMPLATE.format( + host = self.config.domain, + note = config['note'], + count = len(inboxes), + targets = '
'.join(inbox['domain'] for inbox in inboxes) + ) return Response.new(text, ctype='html') @@ -103,7 +102,7 @@ class ActorView(View): self.signer: Signer = None - async def get(self, request: Request) -> Response: + async def get(self, request: Request, conn: Connection) -> Response: data = Message.new_actor( host = self.config.domain, pubkey = self.app.signer.pubkey @@ -112,37 +111,36 @@ class ActorView(View): return Response.new(data, ctype='activity') - async def post(self, request: Request) -> Response: + async def post(self, request: Request, conn: Connection) -> Response: if response := await self.get_post_data(): return response - with self.database.connection() as conn: - self.instance = conn.get_inbox(self.actor.shared_inbox) - config = conn.get_config_all() + self.instance = conn.get_inbox(self.actor.shared_inbox) + config = conn.get_config_all() - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): - logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if the actor isn't whitelisted while the whiltelist is enabled + if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): + logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if actor is banned - if conn.get_domain_ban(self.actor.domain): - logging.verbose('Ignored request from banned actor: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if actor is banned + if conn.get_domain_ban(self.actor.domain): + logging.verbose('Ignored request from banned actor: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if activity type isn't 'Follow' and the actor isn't following - if self.message.type != 'Follow' and not self.instance: - logging.verbose( - 'Rejected actor for trying to post while not following: %s', - self.actor.id - ) + ## reject if activity type isn't 'Follow' and the actor isn't following + if self.message.type != 'Follow' and not self.instance: + logging.verbose( + 'Rejected actor for trying to post while not following: %s', + self.actor.id + ) - return Response.new_error(401, 'access denied', 'json') + return Response.new_error(401, 'access denied', 'json') - logging.debug('>> payload %s', self.message.to_json(4)) + logging.debug('>> payload %s', self.message.to_json(4)) - asyncio.ensure_future(run_processor(self)) - return Response.new(status = 202) + await run_processor(self, conn) + return Response.new(status = 202) async def get_post_data(self) -> Response | None: @@ -232,7 +230,7 @@ class ActorView(View): @register_route('/.well-known/webfinger') class WebfingerView(View): - async def get(self, request: Request) -> Response: + async def get(self, request: Request, conn: Connection) -> Response: try: subject = request.query['resource'] @@ -253,18 +251,18 @@ class WebfingerView(View): @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') class NodeinfoView(View): - async def get(self, request: Request, niversion: str) -> Response: - with self.database.connection() as conn: - inboxes = conn.execute('SELECT * FROM inboxes').all() + # pylint: disable=no-self-use + async def get(self, request: Request, conn: Connection, niversion: str) -> Response: + inboxes = conn.execute('SELECT * FROM inboxes').all() - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not conn.get_config('whitelist-enabled'), - 'users': 1, - 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} - } + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not conn.get_config('whitelist-enabled'), + 'users': 1, + 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} + } if niversion == '2.1': data['repo'] = 'https://git.pleroma.social/pleroma/relay' @@ -274,6 +272,6 @@ class NodeinfoView(View): @register_route('/.well-known/nodeinfo') class WellknownNodeinfoView(View): - async def get(self, request: Request) -> Response: + async def get(self, request: Request, conn: Connection) -> Response: data = WellKnownNodeinfo.new_template(self.config.domain) return Response.new(data, ctype = 'json')