mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-24 15:31:08 +00:00
create a new database connection for each request
This commit is contained in:
parent
e6f30ddf64
commit
2fcaea85ae
|
@ -53,7 +53,7 @@ SOFTWARE = (
|
|||
|
||||
def check_alphanumeric(text: str) -> str:
|
||||
if not text.isalnum():
|
||||
raise click.BadParameter(f'String not alphanumeric')
|
||||
raise click.BadParameter('String not alphanumeric')
|
||||
|
||||
return text
|
||||
|
||||
|
|
|
@ -15,7 +15,8 @@ from uuid import uuid4
|
|||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Coroutine, Generator
|
||||
from typing import Any
|
||||
from tinysql import Connection
|
||||
from typing import Any, Awaitable
|
||||
from .application import Application
|
||||
from .cache import Cache
|
||||
from .config import Config
|
||||
|
@ -234,6 +235,9 @@ class Response(AiohttpResponse):
|
|||
|
||||
|
||||
class View(AbstractView):
|
||||
conn: Connection
|
||||
|
||||
|
||||
def __await__(self) -> Generator[Response]:
|
||||
if (self.request.method) not in METHODS:
|
||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||
|
@ -241,7 +245,12 @@ class View(AbstractView):
|
|||
if not (handler := self.handlers.get(self.request.method)):
|
||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
|
||||
|
||||
return handler(self.request, **self.request.match_info).__await__()
|
||||
return self._run_handler(handler).__await__()
|
||||
|
||||
|
||||
async def _run_handler(self, handler: Awaitable) -> Response:
|
||||
with self.database.config.connection_class(self.database) as conn:
|
||||
return await handler(self.request,conn, **self.request.match_info)
|
||||
|
||||
|
||||
@cached_property
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import tinysql
|
||||
import typing
|
||||
|
||||
from . import logger as logging
|
||||
from .database.connection import Connection
|
||||
from .misc import Message
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
async def handle_relay(view: ActorView) -> None:
|
||||
async def handle_relay(view: ActorView, conn: Connection) -> None:
|
||||
try:
|
||||
view.cache.get('handle-relay', view.message.object_id)
|
||||
logging.verbose('already relayed %s', view.message.object_id)
|
||||
|
@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None:
|
|||
pass
|
||||
|
||||
message = Message.new_announce(view.config.domain, view.message.object_id)
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
logging.debug('>> relay: %s', message)
|
||||
|
||||
with view.database.connection() as conn:
|
||||
for inbox in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(inbox, message, view.instance)
|
||||
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
|
||||
async def handle_forward(view: ActorView) -> None:
|
||||
|
||||
async def handle_forward(view: ActorView, conn: Connection) -> None:
|
||||
try:
|
||||
view.cache.get('handle-relay', view.message.object_id)
|
||||
logging.verbose('already forwarded %s', view.message.object_id)
|
||||
|
@ -51,19 +51,18 @@ async def handle_forward(view: ActorView) -> None:
|
|||
pass
|
||||
|
||||
message = Message.new_announce(view.config.domain, view.message)
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
logging.debug('>> forward: %s', message)
|
||||
|
||||
with view.database.connection() as conn:
|
||||
for inbox in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(inbox, message, view.instance)
|
||||
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
|
||||
async def handle_follow(view: ActorView) -> None:
|
||||
|
||||
async def handle_follow(view: ActorView, conn: Connection) -> None:
|
||||
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
|
||||
software = nodeinfo.sw_name if nodeinfo else None
|
||||
|
||||
with view.database.connection() as conn:
|
||||
# reject if software used by actor is banned
|
||||
if conn.get_software_ban(software):
|
||||
view.app.push_message(
|
||||
|
@ -103,6 +102,7 @@ async def handle_follow(view: ActorView) -> None:
|
|||
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
|
||||
|
||||
else:
|
||||
with conn.transaction():
|
||||
view.instance = conn.put_inbox(
|
||||
view.actor.domain,
|
||||
view.actor.shared_inbox,
|
||||
|
@ -135,13 +135,13 @@ async def handle_follow(view: ActorView) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def handle_undo(view: ActorView) -> None:
|
||||
async def handle_undo(view: ActorView, conn: Connection) -> None:
|
||||
## If the object is not a Follow, forward it
|
||||
if view.message.object['type'] != 'Follow':
|
||||
await handle_forward(view)
|
||||
await handle_forward(view, conn)
|
||||
return
|
||||
|
||||
with view.database.connection() as conn:
|
||||
with conn.transaction():
|
||||
if not conn.del_inbox(view.actor.id):
|
||||
logging.verbose(
|
||||
'Failed to delete "%s" with follow ID "%s"',
|
||||
|
@ -170,7 +170,7 @@ processors = {
|
|||
}
|
||||
|
||||
|
||||
async def run_processor(view: ActorView) -> None:
|
||||
async def run_processor(view: ActorView, conn: Connection) -> None:
|
||||
if view.message.type not in processors:
|
||||
logging.verbose(
|
||||
'Message type "%s" from actor cannot be handled: %s',
|
||||
|
@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
|
|||
return
|
||||
|
||||
if view.instance:
|
||||
with conn.transaction():
|
||||
if not view.instance['software']:
|
||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
||||
with view.database.connection() as conn:
|
||||
view.instance = conn.update_inbox(
|
||||
view.instance['inbox'],
|
||||
software = nodeinfo.sw_name
|
||||
)
|
||||
|
||||
if not view.instance['actor']:
|
||||
with view.database.connection() as conn:
|
||||
view.instance = conn.update_inbox(
|
||||
view.instance['inbox'],
|
||||
actor = view.actor.id
|
||||
)
|
||||
|
||||
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
|
||||
await processors[view.message.type](view)
|
||||
await processors[view.message.type](view, conn)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import traceback
|
||||
import typing
|
||||
|
@ -12,6 +11,7 @@ from pathlib import Path
|
|||
|
||||
from . import __version__
|
||||
from . import logger as logging
|
||||
from .database.connection import Connection
|
||||
from .misc import Message, Response, View
|
||||
from .processors import run_processor
|
||||
|
||||
|
@ -75,8 +75,7 @@ def register_route(*paths: str) -> Callable:
|
|||
|
||||
@register_route('/')
|
||||
class HomeView(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.connection() as conn:
|
||||
async def get(self, request: Request, conn: Connection) -> Response:
|
||||
config = conn.get_config_all()
|
||||
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
||||
|
||||
|
@ -103,7 +102,7 @@ class ActorView(View):
|
|||
self.signer: Signer = None
|
||||
|
||||
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: Request, conn: Connection) -> Response:
|
||||
data = Message.new_actor(
|
||||
host = self.config.domain,
|
||||
pubkey = self.app.signer.pubkey
|
||||
|
@ -112,11 +111,10 @@ class ActorView(View):
|
|||
return Response.new(data, ctype='activity')
|
||||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
async def post(self, request: Request, conn: Connection) -> Response:
|
||||
if response := await self.get_post_data():
|
||||
return response
|
||||
|
||||
with self.database.connection() as conn:
|
||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
||||
config = conn.get_config_all()
|
||||
|
||||
|
@ -141,7 +139,7 @@ class ActorView(View):
|
|||
|
||||
logging.debug('>> payload %s', self.message.to_json(4))
|
||||
|
||||
asyncio.ensure_future(run_processor(self))
|
||||
await run_processor(self, conn)
|
||||
return Response.new(status = 202)
|
||||
|
||||
|
||||
|
@ -232,7 +230,7 @@ class ActorView(View):
|
|||
|
||||
@register_route('/.well-known/webfinger')
|
||||
class WebfingerView(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: Request, conn: Connection) -> Response:
|
||||
try:
|
||||
subject = request.query['resource']
|
||||
|
||||
|
@ -253,8 +251,8 @@ class WebfingerView(View):
|
|||
|
||||
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
|
||||
class NodeinfoView(View):
|
||||
async def get(self, request: Request, niversion: str) -> Response:
|
||||
with self.database.connection() as conn:
|
||||
# pylint: disable=no-self-use
|
||||
async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
|
||||
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
||||
|
||||
data = {
|
||||
|
@ -274,6 +272,6 @@ class NodeinfoView(View):
|
|||
|
||||
@register_route('/.well-known/nodeinfo')
|
||||
class WellknownNodeinfoView(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: Request, conn: Connection) -> Response:
|
||||
data = WellKnownNodeinfo.new_template(self.config.domain)
|
||||
return Response.new(data, ctype = 'json')
|
||||
|
|
Loading…
Reference in a new issue