diff --git a/relay/manage.py b/relay/manage.py
index 6872026..df5b4cb 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -53,7 +53,7 @@ SOFTWARE = (
def check_alphanumeric(text: str) -> str:
if not text.isalnum():
- raise click.BadParameter(f'String not alphanumeric')
+ raise click.BadParameter('String not alphanumeric')
return text
diff --git a/relay/misc.py b/relay/misc.py
index 831db38..94a7182 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -15,7 +15,8 @@ from uuid import uuid4
if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
- from typing import Any
+ from tinysql import Connection
+ from typing import Any, Awaitable
from .application import Application
from .cache import Cache
from .config import Config
@@ -234,6 +235,9 @@ class Response(AiohttpResponse):
class View(AbstractView):
+ conn: Connection
+
+
def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@@ -241,7 +245,12 @@ class View(AbstractView):
if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
- return handler(self.request, **self.request.match_info).__await__()
+ return self._run_handler(handler).__await__()
+
+
+ async def _run_handler(self, handler: Awaitable) -> Response:
+ with self.database.config.connection_class(self.database) as conn:
+ return await handler(self.request,conn, **self.request.match_info)
@cached_property
diff --git a/relay/processors.py b/relay/processors.py
index 36ece44..d9780d1 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -1,9 +1,9 @@
from __future__ import annotations
-import tinysql
import typing
from . import logger as logging
+from .database.connection import Connection
from .misc import Message
if typing.TYPE_CHECKING:
@@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool:
return False
-async def handle_relay(view: ActorView) -> None:
+async def handle_relay(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id)
@@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None:
pass
message = Message.new_announce(view.config.domain, view.message.object_id)
- view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> relay: %s', message)
- with view.database.connection() as conn:
- for inbox in conn.distill_inboxes(view.message):
- view.app.push_message(inbox, message, view.instance)
+ for inbox in conn.distill_inboxes(view.message):
+ view.app.push_message(inbox, message, view.instance)
+
+ view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
-async def handle_forward(view: ActorView) -> None:
+async def handle_forward(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id)
@@ -51,58 +51,58 @@ async def handle_forward(view: ActorView) -> None:
pass
message = Message.new_announce(view.config.domain, view.message)
- view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> forward: %s', message)
- with view.database.connection() as conn:
- for inbox in conn.distill_inboxes(view.message):
- view.app.push_message(inbox, message, view.instance)
+ for inbox in conn.distill_inboxes(view.message):
+ view.app.push_message(inbox, message, view.instance)
+
+ view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
-async def handle_follow(view: ActorView) -> None:
+async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
- with view.database.connection() as conn:
- # reject if software used by actor is banned
- if conn.get_software_ban(software):
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_response(
- host = view.config.domain,
- actor = view.actor.id,
- followid = view.message.id,
- accept = False
- )
+ # reject if software used by actor is banned
+ if conn.get_software_ban(software):
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
)
+ )
- logging.verbose(
- 'Rejected follow from actor for using specific software: actor=%s, software=%s',
- view.actor.id,
- software
+ logging.verbose(
+ 'Rejected follow from actor for using specific software: actor=%s, software=%s',
+ view.actor.id,
+ software
+ )
+
+ return
+
+ ## reject if the actor is not an instance actor
+ if person_check(view.actor, software):
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
)
+ )
- return
+ logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
+ return
- ## reject if the actor is not an instance actor
- if person_check(view.actor, software):
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_response(
- host = view.config.domain,
- actor = view.actor.id,
- followid = view.message.id,
- accept = False
- )
- )
+ if conn.get_inbox(view.actor.shared_inbox):
+ view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
- logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
- return
-
- if conn.get_inbox(view.actor.shared_inbox):
- view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
-
- else:
+ else:
+ with conn.transaction():
view.instance = conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
@@ -111,37 +111,37 @@ async def handle_follow(view: ActorView) -> None:
software
)
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = True
+ ),
+ view.instance
+ )
+
+ # Are Akkoma and Pleroma the only two that expect a follow back?
+ # Ignoring only Mastodon for now
+ if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
- Message.new_response(
+ Message.new_follow(
host = view.config.domain,
- actor = view.actor.id,
- followid = view.message.id,
- accept = True
+ actor = view.actor.id
),
view.instance
)
- # Are Akkoma and Pleroma the only two that expect a follow back?
- # Ignoring only Mastodon for now
- if software != 'mastodon':
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_follow(
- host = view.config.domain,
- actor = view.actor.id
- ),
- view.instance
- )
-
-async def handle_undo(view: ActorView) -> None:
+async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
- await handle_forward(view)
+ await handle_forward(view, conn)
return
- with view.database.connection() as conn:
+ with conn.transaction():
if not conn.del_inbox(view.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
@@ -170,7 +170,7 @@ processors = {
}
-async def run_processor(view: ActorView) -> None:
+async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return
if view.instance:
- if not view.instance['software']:
- if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
- with view.database.connection() as conn:
+ with conn.transaction():
+ if not view.instance['software']:
+ if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
- if not view.instance['actor']:
- with view.database.connection() as conn:
+ if not view.instance['actor']:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
- await processors[view.message.type](view)
+ await processors[view.message.type](view, conn)
diff --git a/relay/views.py b/relay/views.py
index 64c1d57..cb648a2 100644
--- a/relay/views.py
+++ b/relay/views.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import asyncio
import subprocess
import traceback
import typing
@@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__
from . import logger as logging
+from .database.connection import Connection
from .misc import Message, Response, View
from .processors import run_processor
@@ -75,17 +75,16 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
- async def get(self, request: Request) -> Response:
- with self.database.connection() as conn:
- config = conn.get_config_all()
- inboxes = conn.execute('SELECT * FROM inboxes').all()
+ async def get(self, request: Request, conn: Connection) -> Response:
+ config = conn.get_config_all()
+ inboxes = conn.execute('SELECT * FROM inboxes').all()
- text = HOME_TEMPLATE.format(
- host = self.config.domain,
- note = config['note'],
- count = len(inboxes),
- targets = '
'.join(inbox['domain'] for inbox in inboxes)
- )
+ text = HOME_TEMPLATE.format(
+ host = self.config.domain,
+ note = config['note'],
+ count = len(inboxes),
+ targets = '
'.join(inbox['domain'] for inbox in inboxes)
+ )
return Response.new(text, ctype='html')
@@ -103,7 +102,7 @@ class ActorView(View):
self.signer: Signer = None
- async def get(self, request: Request) -> Response:
+ async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor(
host = self.config.domain,
pubkey = self.app.signer.pubkey
@@ -112,37 +111,36 @@ class ActorView(View):
return Response.new(data, ctype='activity')
- async def post(self, request: Request) -> Response:
+ async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data():
return response
- with self.database.connection() as conn:
- self.instance = conn.get_inbox(self.actor.shared_inbox)
- config = conn.get_config_all()
+ self.instance = conn.get_inbox(self.actor.shared_inbox)
+ config = conn.get_config_all()
- ## reject if the actor isn't whitelisted while the whiltelist is enabled
- if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
- logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
+ ## reject if the actor isn't whitelisted while the whiltelist is enabled
+ if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
+ logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
+ return Response.new_error(403, 'access denied', 'json')
- ## reject if actor is banned
- if conn.get_domain_ban(self.actor.domain):
- logging.verbose('Ignored request from banned actor: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
+ ## reject if actor is banned
+ if conn.get_domain_ban(self.actor.domain):
+ logging.verbose('Ignored request from banned actor: %s', self.actor.id)
+ return Response.new_error(403, 'access denied', 'json')
- ## reject if activity type isn't 'Follow' and the actor isn't following
- if self.message.type != 'Follow' and not self.instance:
- logging.verbose(
- 'Rejected actor for trying to post while not following: %s',
- self.actor.id
- )
+ ## reject if activity type isn't 'Follow' and the actor isn't following
+ if self.message.type != 'Follow' and not self.instance:
+ logging.verbose(
+ 'Rejected actor for trying to post while not following: %s',
+ self.actor.id
+ )
- return Response.new_error(401, 'access denied', 'json')
+ return Response.new_error(401, 'access denied', 'json')
- logging.debug('>> payload %s', self.message.to_json(4))
+ logging.debug('>> payload %s', self.message.to_json(4))
- asyncio.ensure_future(run_processor(self))
- return Response.new(status = 202)
+ await run_processor(self, conn)
+ return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
@@ -232,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger')
class WebfingerView(View):
- async def get(self, request: Request) -> Response:
+ async def get(self, request: Request, conn: Connection) -> Response:
try:
subject = request.query['resource']
@@ -253,18 +251,18 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
- async def get(self, request: Request, niversion: str) -> Response:
- with self.database.connection() as conn:
- inboxes = conn.execute('SELECT * FROM inboxes').all()
+ # pylint: disable=no-self-use
+ async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
+ inboxes = conn.execute('SELECT * FROM inboxes').all()
- data = {
- 'name': 'activityrelay',
- 'version': VERSION,
- 'protocols': ['activitypub'],
- 'open_regs': not conn.get_config('whitelist-enabled'),
- 'users': 1,
- 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
- }
+ data = {
+ 'name': 'activityrelay',
+ 'version': VERSION,
+ 'protocols': ['activitypub'],
+ 'open_regs': not conn.get_config('whitelist-enabled'),
+ 'users': 1,
+ 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
+ }
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@@ -274,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
- async def get(self, request: Request) -> Response:
+ async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')