create a new database connection for each request

This commit is contained in:
Izalia Mae 2024-02-04 04:53:39 -05:00
parent e6f30ddf64
commit 2fcaea85ae
4 changed files with 128 additions and 122 deletions

View file

@ -53,7 +53,7 @@ SOFTWARE = (
def check_alphanumeric(text: str) -> str: def check_alphanumeric(text: str) -> str:
if not text.isalnum(): if not text.isalnum():
raise click.BadParameter(f'String not alphanumeric') raise click.BadParameter('String not alphanumeric')
return text return text

View file

@ -15,7 +15,8 @@ from uuid import uuid4
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator from collections.abc import Coroutine, Generator
from typing import Any from tinysql import Connection
from typing import Any, Awaitable
from .application import Application from .application import Application
from .cache import Cache from .cache import Cache
from .config import Config from .config import Config
@ -234,6 +235,9 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
conn: Connection
def __await__(self) -> Generator[Response]: def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS: if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@ -241,7 +245,12 @@ class View(AbstractView):
if not (handler := self.handlers.get(self.request.method)): if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__() return self._run_handler(handler).__await__()
async def _run_handler(self, handler: Awaitable) -> Response:
with self.database.config.connection_class(self.database) as conn:
return await handler(self.request,conn, **self.request.match_info)
@cached_property @cached_property

View file

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import tinysql
import typing import typing
from . import logger as logging from . import logger as logging
from .database.connection import Connection
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool:
return False return False
async def handle_relay(view: ActorView) -> None: async def handle_relay(view: ActorView, conn: Connection) -> None:
try: try:
view.cache.get('handle-relay', view.message.object_id) view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id)
@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None:
pass pass
message = Message.new_announce(view.config.domain, view.message.object_id) message = Message.new_announce(view.config.domain, view.message.object_id)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
with view.database.connection() as conn: for inbox in conn.distill_inboxes(view.message):
for inbox in conn.distill_inboxes(view.message): view.app.push_message(inbox, message, view.instance)
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_forward(view: ActorView) -> None: async def handle_forward(view: ActorView, conn: Connection) -> None:
try: try:
view.cache.get('handle-relay', view.message.object_id) view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id) logging.verbose('already forwarded %s', view.message.object_id)
@ -51,58 +51,58 @@ async def handle_forward(view: ActorView) -> None:
pass pass
message = Message.new_announce(view.config.domain, view.message) message = Message.new_announce(view.config.domain, view.message)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
with view.database.connection() as conn: for inbox in conn.distill_inboxes(view.message):
for inbox in conn.distill_inboxes(view.message): view.app.push_message(inbox, message, view.instance)
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_follow(view: ActorView) -> None: async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn: # reject if software used by actor is banned
# reject if software used by actor is banned if conn.get_software_ban(software):
if conn.get_software_ban(software): view.app.push_message(
view.app.push_message( view.actor.shared_inbox,
view.actor.shared_inbox, Message.new_response(
Message.new_response( host = view.config.domain,
host = view.config.domain, actor = view.actor.id,
actor = view.actor.id, followid = view.message.id,
followid = view.message.id, accept = False
accept = False
)
) )
)
logging.verbose( logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s', 'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id, view.actor.id,
software software
)
return
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
) )
)
return logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
## reject if the actor is not an instance actor if conn.get_inbox(view.actor.shared_inbox):
if person_check(view.actor, software): view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id) else:
return with conn.transaction():
if conn.get_inbox(view.actor.shared_inbox):
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
else:
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
view.actor.domain, view.actor.domain,
view.actor.shared_inbox, view.actor.shared_inbox,
@ -111,37 +111,37 @@ async def handle_follow(view: ActorView) -> None:
software software
) )
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = True
),
view.instance
)
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_response( Message.new_follow(
host = view.config.domain, host = view.config.domain,
actor = view.actor.id, actor = view.actor.id
followid = view.message.id,
accept = True
), ),
view.instance view.instance
) )
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.domain,
actor = view.actor.id
),
view.instance
)
async def handle_undo(view: ActorView, conn: Connection) -> None:
async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it ## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow': if view.message.object['type'] != 'Follow':
await handle_forward(view) await handle_forward(view, conn)
return return
with view.database.connection() as conn: with conn.transaction():
if not conn.del_inbox(view.actor.id): if not conn.del_inbox(view.actor.id):
logging.verbose( logging.verbose(
'Failed to delete "%s" with follow ID "%s"', 'Failed to delete "%s" with follow ID "%s"',
@ -170,7 +170,7 @@ processors = {
} }
async def run_processor(view: ActorView) -> None: async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors: if view.message.type not in processors:
logging.verbose( logging.verbose(
'Message type "%s" from actor cannot be handled: %s', 'Message type "%s" from actor cannot be handled: %s',
@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return return
if view.instance: if view.instance:
if not view.instance['software']: with conn.transaction():
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): if not view.instance['software']:
with view.database.connection() as conn: if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
view.instance = conn.update_inbox( view.instance = conn.update_inbox(
view.instance['inbox'], view.instance['inbox'],
software = nodeinfo.sw_name software = nodeinfo.sw_name
) )
if not view.instance['actor']: if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox( view.instance = conn.update_inbox(
view.instance['inbox'], view.instance['inbox'],
actor = view.actor.id actor = view.actor.id
) )
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view) await processors[view.message.type](view, conn)

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import subprocess import subprocess
import traceback import traceback
import typing import typing
@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__ from . import __version__
from . import logger as logging from . import logger as logging
from .database.connection import Connection
from .misc import Message, Response, View from .misc import Message, Response, View
from .processors import run_processor from .processors import run_processor
@ -75,17 +75,16 @@ def register_route(*paths: str) -> Callable:
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
with self.database.connection() as conn: config = conn.get_config_all()
config = conn.get_config_all() inboxes = conn.execute('SELECT * FROM inboxes').all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
text = HOME_TEMPLATE.format( text = HOME_TEMPLATE.format(
host = self.config.domain, host = self.config.domain,
note = config['note'], note = config['note'],
count = len(inboxes), count = len(inboxes),
targets = '<br>'.join(inbox['domain'] for inbox in inboxes) targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
) )
return Response.new(text, ctype='html') return Response.new(text, ctype='html')
@ -103,7 +102,7 @@ class ActorView(View):
self.signer: Signer = None self.signer: Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor( data = Message.new_actor(
host = self.config.domain, host = self.config.domain,
pubkey = self.app.signer.pubkey pubkey = self.app.signer.pubkey
@ -112,37 +111,36 @@ class ActorView(View):
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data(): if response := await self.get_post_data():
return response return response
with self.database.connection() as conn: self.instance = conn.get_inbox(self.actor.shared_inbox)
self.instance = conn.get_inbox(self.actor.shared_inbox) config = conn.get_config_all()
config = conn.get_config_all()
## reject if the actor isn't whitelisted while the whiltelist is enabled ## reject if the actor isn't whitelisted while the whiltelist is enabled
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned ## reject if actor is banned
if conn.get_domain_ban(self.actor.domain): if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following ## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance: if self.message.type != 'Follow' and not self.instance:
logging.verbose( logging.verbose(
'Rejected actor for trying to post while not following: %s', 'Rejected actor for trying to post while not following: %s',
self.actor.id self.actor.id
) )
return Response.new_error(401, 'access denied', 'json') return Response.new_error(401, 'access denied', 'json')
logging.debug('>> payload %s', self.message.to_json(4)) logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self)) await run_processor(self, conn)
return Response.new(status = 202) return Response.new(status = 202)
async def get_post_data(self) -> Response | None: async def get_post_data(self) -> Response | None:
@ -232,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
try: try:
subject = request.query['resource'] subject = request.query['resource']
@ -253,18 +251,18 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View): class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response: # pylint: disable=no-self-use
with self.database.connection() as conn: async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()
data = { data = {
'name': 'activityrelay', 'name': 'activityrelay',
'version': VERSION, 'version': VERSION,
'protocols': ['activitypub'], 'protocols': ['activitypub'],
'open_regs': not conn.get_config('whitelist-enabled'), 'open_regs': not conn.get_config('whitelist-enabled'),
'users': 1, 'users': 1,
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
} }
if niversion == '2.1': if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay' data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@ -274,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View): class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain) data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')