create a new database connection for each request

This commit is contained in:
Izalia Mae 2024-02-04 04:53:39 -05:00
parent e6f30ddf64
commit 2fcaea85ae
4 changed files with 128 additions and 122 deletions

View file

@ -53,7 +53,7 @@ SOFTWARE = (
def check_alphanumeric(text: str) -> str:
if not text.isalnum():
raise click.BadParameter(f'String not alphanumeric')
raise click.BadParameter('String not alphanumeric')
return text

View file

@ -15,7 +15,8 @@ from uuid import uuid4
if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
from typing import Any
from tinysql import Connection
from typing import Any, Awaitable
from .application import Application
from .cache import Cache
from .config import Config
@ -234,6 +235,9 @@ class Response(AiohttpResponse):
class View(AbstractView):
conn: Connection
def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@ -241,7 +245,12 @@ class View(AbstractView):
if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__()
return self._run_handler(handler).__await__()
async def _run_handler(self, handler: Awaitable) -> Response:
with self.database.config.connection_class(self.database) as conn:
return await handler(self.request,conn, **self.request.match_info)
@cached_property

View file

@ -1,9 +1,9 @@
from __future__ import annotations
import tinysql
import typing
from . import logger as logging
from .database.connection import Connection
from .misc import Message
if typing.TYPE_CHECKING:
@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool:
return False
async def handle_relay(view: ActorView) -> None:
async def handle_relay(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id)
@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None:
pass
message = Message.new_announce(view.config.domain, view.message.object_id)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> relay: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_forward(view: ActorView) -> None:
async def handle_forward(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id)
@ -51,19 +51,18 @@ async def handle_forward(view: ActorView) -> None:
pass
message = Message.new_announce(view.config.domain, view.message)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> forward: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_follow(view: ActorView) -> None:
async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn:
# reject if software used by actor is banned
if conn.get_software_ban(software):
view.app.push_message(
@ -103,6 +102,7 @@ async def handle_follow(view: ActorView) -> None:
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
else:
with conn.transaction():
view.instance = conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
@ -135,13 +135,13 @@ async def handle_follow(view: ActorView) -> None:
)
async def handle_undo(view: ActorView) -> None:
async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
await handle_forward(view)
await handle_forward(view, conn)
return
with view.database.connection() as conn:
with conn.transaction():
if not conn.del_inbox(view.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
@ -170,7 +170,7 @@ processors = {
}
async def run_processor(view: ActorView) -> None:
async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return
if view.instance:
with conn.transaction():
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view)
await processors[view.message.type](view, conn)

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import subprocess
import traceback
import typing
@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__
from . import logger as logging
from .database.connection import Connection
from .misc import Message, Response, View
from .processors import run_processor
@ -75,8 +75,7 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.connection() as conn:
async def get(self, request: Request, conn: Connection) -> Response:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
@ -103,7 +102,7 @@ class ActorView(View):
self.signer: Signer = None
async def get(self, request: Request) -> Response:
async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor(
host = self.config.domain,
pubkey = self.app.signer.pubkey
@ -112,11 +111,10 @@ class ActorView(View):
return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response:
async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data():
return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all()
@ -141,7 +139,7 @@ class ActorView(View):
logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self))
await run_processor(self, conn)
return Response.new(status = 202)
@ -232,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger')
class WebfingerView(View):
async def get(self, request: Request) -> Response:
async def get(self, request: Request, conn: Connection) -> Response:
try:
subject = request.query['resource']
@ -253,8 +251,8 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection() as conn:
# pylint: disable=no-self-use
async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = {
@ -274,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response:
async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')