create a new database connection for each request

This commit is contained in:
Izalia Mae 2024-02-04 04:53:39 -05:00
parent e6f30ddf64
commit 2fcaea85ae
4 changed files with 128 additions and 122 deletions

View file

@ -53,7 +53,7 @@ SOFTWARE = (
def check_alphanumeric(text: str) -> str: def check_alphanumeric(text: str) -> str:
if not text.isalnum(): if not text.isalnum():
raise click.BadParameter(f'String not alphanumeric') raise click.BadParameter('String not alphanumeric')
return text return text

View file

@ -15,7 +15,8 @@ from uuid import uuid4
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator from collections.abc import Coroutine, Generator
from typing import Any from tinysql import Connection
from typing import Any, Awaitable
from .application import Application from .application import Application
from .cache import Cache from .cache import Cache
from .config import Config from .config import Config
@ -234,6 +235,9 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
conn: Connection
def __await__(self) -> Generator[Response]: def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS: if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@ -241,7 +245,12 @@ class View(AbstractView):
if not (handler := self.handlers.get(self.request.method)): if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__() return self._run_handler(handler).__await__()
async def _run_handler(self, handler: Awaitable) -> Response:
with self.database.config.connection_class(self.database) as conn:
return await handler(self.request,conn, **self.request.match_info)
@cached_property @cached_property

View file

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import tinysql
import typing import typing
from . import logger as logging from . import logger as logging
from .database.connection import Connection
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool:
return False return False
async def handle_relay(view: ActorView) -> None: async def handle_relay(view: ActorView, conn: Connection) -> None:
try: try:
view.cache.get('handle-relay', view.message.object_id) view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id)
@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None:
pass pass
message = Message.new_announce(view.config.domain, view.message.object_id) message = Message.new_announce(view.config.domain, view.message.object_id)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message): for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance) view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_forward(view: ActorView) -> None:
async def handle_forward(view: ActorView, conn: Connection) -> None:
try: try:
view.cache.get('handle-relay', view.message.object_id) view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id) logging.verbose('already forwarded %s', view.message.object_id)
@ -51,19 +51,18 @@ async def handle_forward(view: ActorView) -> None:
pass pass
message = Message.new_announce(view.config.domain, view.message) message = Message.new_announce(view.config.domain, view.message)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message): for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance) view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_follow(view: ActorView) -> None:
async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn:
# reject if software used by actor is banned # reject if software used by actor is banned
if conn.get_software_ban(software): if conn.get_software_ban(software):
view.app.push_message( view.app.push_message(
@ -103,6 +102,7 @@ async def handle_follow(view: ActorView) -> None:
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id) view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
else: else:
with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
view.actor.domain, view.actor.domain,
view.actor.shared_inbox, view.actor.shared_inbox,
@ -135,13 +135,13 @@ async def handle_follow(view: ActorView) -> None:
) )
async def handle_undo(view: ActorView) -> None: async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it ## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow': if view.message.object['type'] != 'Follow':
await handle_forward(view) await handle_forward(view, conn)
return return
with view.database.connection() as conn: with conn.transaction():
if not conn.del_inbox(view.actor.id): if not conn.del_inbox(view.actor.id):
logging.verbose( logging.verbose(
'Failed to delete "%s" with follow ID "%s"', 'Failed to delete "%s" with follow ID "%s"',
@ -170,7 +170,7 @@ processors = {
} }
async def run_processor(view: ActorView) -> None: async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors: if view.message.type not in processors:
logging.verbose( logging.verbose(
'Message type "%s" from actor cannot be handled: %s', 'Message type "%s" from actor cannot be handled: %s',
@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return return
if view.instance: if view.instance:
with conn.transaction():
if not view.instance['software']: if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox( view.instance = conn.update_inbox(
view.instance['inbox'], view.instance['inbox'],
software = nodeinfo.sw_name software = nodeinfo.sw_name
) )
if not view.instance['actor']: if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox( view.instance = conn.update_inbox(
view.instance['inbox'], view.instance['inbox'],
actor = view.actor.id actor = view.actor.id
) )
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view) await processors[view.message.type](view, conn)

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import subprocess import subprocess
import traceback import traceback
import typing import typing
@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__ from . import __version__
from . import logger as logging from . import logger as logging
from .database.connection import Connection
from .misc import Message, Response, View from .misc import Message, Response, View
from .processors import run_processor from .processors import run_processor
@ -75,8 +75,7 @@ def register_route(*paths: str) -> Callable:
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
with self.database.connection() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()
@ -103,7 +102,7 @@ class ActorView(View):
self.signer: Signer = None self.signer: Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor( data = Message.new_actor(
host = self.config.domain, host = self.config.domain,
pubkey = self.app.signer.pubkey pubkey = self.app.signer.pubkey
@ -112,11 +111,10 @@ class ActorView(View):
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data(): if response := await self.get_post_data():
return response return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all() config = conn.get_config_all()
@ -141,7 +139,7 @@ class ActorView(View):
logging.debug('>> payload %s', self.message.to_json(4)) logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self)) await run_processor(self, conn)
return Response.new(status = 202) return Response.new(status = 202)
@ -232,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
try: try:
subject = request.query['resource'] subject = request.query['resource']
@ -253,8 +251,8 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View): class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response: # pylint: disable=no-self-use
with self.database.connection() as conn: async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()
data = { data = {
@ -274,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View): class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain) data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')