mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-24 15:31:08 +00:00
create a new database connection for each request
This commit is contained in:
parent
e6f30ddf64
commit
2fcaea85ae
|
@ -53,7 +53,7 @@ SOFTWARE = (
|
||||||
|
|
||||||
def check_alphanumeric(text: str) -> str:
|
def check_alphanumeric(text: str) -> str:
|
||||||
if not text.isalnum():
|
if not text.isalnum():
|
||||||
raise click.BadParameter(f'String not alphanumeric')
|
raise click.BadParameter('String not alphanumeric')
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,8 @@ from uuid import uuid4
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from collections.abc import Coroutine, Generator
|
from collections.abc import Coroutine, Generator
|
||||||
from typing import Any
|
from tinysql import Connection
|
||||||
|
from typing import Any, Awaitable
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .cache import Cache
|
from .cache import Cache
|
||||||
from .config import Config
|
from .config import Config
|
||||||
|
@ -234,6 +235,9 @@ class Response(AiohttpResponse):
|
||||||
|
|
||||||
|
|
||||||
class View(AbstractView):
|
class View(AbstractView):
|
||||||
|
conn: Connection
|
||||||
|
|
||||||
|
|
||||||
def __await__(self) -> Generator[Response]:
|
def __await__(self) -> Generator[Response]:
|
||||||
if (self.request.method) not in METHODS:
|
if (self.request.method) not in METHODS:
|
||||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||||
|
@ -241,7 +245,12 @@ class View(AbstractView):
|
||||||
if not (handler := self.handlers.get(self.request.method)):
|
if not (handler := self.handlers.get(self.request.method)):
|
||||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
|
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
|
||||||
|
|
||||||
return handler(self.request, **self.request.match_info).__await__()
|
return self._run_handler(handler).__await__()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_handler(self, handler: Awaitable) -> Response:
|
||||||
|
with self.database.config.connection_class(self.database) as conn:
|
||||||
|
return await handler(self.request,conn, **self.request.match_info)
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import tinysql
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
|
from .database.connection import Connection
|
||||||
from .misc import Message
|
from .misc import Message
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
|
@ -23,7 +23,7 @@ def person_check(actor: str, software: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def handle_relay(view: ActorView) -> None:
|
async def handle_relay(view: ActorView, conn: Connection) -> None:
|
||||||
try:
|
try:
|
||||||
view.cache.get('handle-relay', view.message.object_id)
|
view.cache.get('handle-relay', view.message.object_id)
|
||||||
logging.verbose('already relayed %s', view.message.object_id)
|
logging.verbose('already relayed %s', view.message.object_id)
|
||||||
|
@ -33,15 +33,15 @@ async def handle_relay(view: ActorView) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
message = Message.new_announce(view.config.domain, view.message.object_id)
|
message = Message.new_announce(view.config.domain, view.message.object_id)
|
||||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
|
||||||
logging.debug('>> relay: %s', message)
|
logging.debug('>> relay: %s', message)
|
||||||
|
|
||||||
with view.database.connection() as conn:
|
|
||||||
for inbox in conn.distill_inboxes(view.message):
|
for inbox in conn.distill_inboxes(view.message):
|
||||||
view.app.push_message(inbox, message, view.instance)
|
view.app.push_message(inbox, message, view.instance)
|
||||||
|
|
||||||
|
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||||
|
|
||||||
async def handle_forward(view: ActorView) -> None:
|
|
||||||
|
async def handle_forward(view: ActorView, conn: Connection) -> None:
|
||||||
try:
|
try:
|
||||||
view.cache.get('handle-relay', view.message.object_id)
|
view.cache.get('handle-relay', view.message.object_id)
|
||||||
logging.verbose('already forwarded %s', view.message.object_id)
|
logging.verbose('already forwarded %s', view.message.object_id)
|
||||||
|
@ -51,19 +51,18 @@ async def handle_forward(view: ActorView) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
message = Message.new_announce(view.config.domain, view.message)
|
message = Message.new_announce(view.config.domain, view.message)
|
||||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
|
||||||
logging.debug('>> forward: %s', message)
|
logging.debug('>> forward: %s', message)
|
||||||
|
|
||||||
with view.database.connection() as conn:
|
|
||||||
for inbox in conn.distill_inboxes(view.message):
|
for inbox in conn.distill_inboxes(view.message):
|
||||||
view.app.push_message(inbox, message, view.instance)
|
view.app.push_message(inbox, message, view.instance)
|
||||||
|
|
||||||
|
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||||
|
|
||||||
async def handle_follow(view: ActorView) -> None:
|
|
||||||
|
async def handle_follow(view: ActorView, conn: Connection) -> None:
|
||||||
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
|
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
|
||||||
software = nodeinfo.sw_name if nodeinfo else None
|
software = nodeinfo.sw_name if nodeinfo else None
|
||||||
|
|
||||||
with view.database.connection() as conn:
|
|
||||||
# reject if software used by actor is banned
|
# reject if software used by actor is banned
|
||||||
if conn.get_software_ban(software):
|
if conn.get_software_ban(software):
|
||||||
view.app.push_message(
|
view.app.push_message(
|
||||||
|
@ -103,6 +102,7 @@ async def handle_follow(view: ActorView) -> None:
|
||||||
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
|
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
with conn.transaction():
|
||||||
view.instance = conn.put_inbox(
|
view.instance = conn.put_inbox(
|
||||||
view.actor.domain,
|
view.actor.domain,
|
||||||
view.actor.shared_inbox,
|
view.actor.shared_inbox,
|
||||||
|
@ -135,13 +135,13 @@ async def handle_follow(view: ActorView) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_undo(view: ActorView) -> None:
|
async def handle_undo(view: ActorView, conn: Connection) -> None:
|
||||||
## If the object is not a Follow, forward it
|
## If the object is not a Follow, forward it
|
||||||
if view.message.object['type'] != 'Follow':
|
if view.message.object['type'] != 'Follow':
|
||||||
await handle_forward(view)
|
await handle_forward(view, conn)
|
||||||
return
|
return
|
||||||
|
|
||||||
with view.database.connection() as conn:
|
with conn.transaction():
|
||||||
if not conn.del_inbox(view.actor.id):
|
if not conn.del_inbox(view.actor.id):
|
||||||
logging.verbose(
|
logging.verbose(
|
||||||
'Failed to delete "%s" with follow ID "%s"',
|
'Failed to delete "%s" with follow ID "%s"',
|
||||||
|
@ -170,7 +170,7 @@ processors = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def run_processor(view: ActorView) -> None:
|
async def run_processor(view: ActorView, conn: Connection) -> None:
|
||||||
if view.message.type not in processors:
|
if view.message.type not in processors:
|
||||||
logging.verbose(
|
logging.verbose(
|
||||||
'Message type "%s" from actor cannot be handled: %s',
|
'Message type "%s" from actor cannot be handled: %s',
|
||||||
|
@ -181,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if view.instance:
|
if view.instance:
|
||||||
|
with conn.transaction():
|
||||||
if not view.instance['software']:
|
if not view.instance['software']:
|
||||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
||||||
with view.database.connection() as conn:
|
|
||||||
view.instance = conn.update_inbox(
|
view.instance = conn.update_inbox(
|
||||||
view.instance['inbox'],
|
view.instance['inbox'],
|
||||||
software = nodeinfo.sw_name
|
software = nodeinfo.sw_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if not view.instance['actor']:
|
if not view.instance['actor']:
|
||||||
with view.database.connection() as conn:
|
|
||||||
view.instance = conn.update_inbox(
|
view.instance = conn.update_inbox(
|
||||||
view.instance['inbox'],
|
view.instance['inbox'],
|
||||||
actor = view.actor.id
|
actor = view.actor.id
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
|
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
|
||||||
await processors[view.message.type](view)
|
await processors[view.message.type](view, conn)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
|
@ -12,6 +11,7 @@ from pathlib import Path
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
|
from .database.connection import Connection
|
||||||
from .misc import Message, Response, View
|
from .misc import Message, Response, View
|
||||||
from .processors import run_processor
|
from .processors import run_processor
|
||||||
|
|
||||||
|
@ -75,8 +75,7 @@ def register_route(*paths: str) -> Callable:
|
||||||
|
|
||||||
@register_route('/')
|
@register_route('/')
|
||||||
class HomeView(View):
|
class HomeView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request, conn: Connection) -> Response:
|
||||||
with self.database.connection() as conn:
|
|
||||||
config = conn.get_config_all()
|
config = conn.get_config_all()
|
||||||
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
||||||
|
|
||||||
|
@ -103,7 +102,7 @@ class ActorView(View):
|
||||||
self.signer: Signer = None
|
self.signer: Signer = None
|
||||||
|
|
||||||
|
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request, conn: Connection) -> Response:
|
||||||
data = Message.new_actor(
|
data = Message.new_actor(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
pubkey = self.app.signer.pubkey
|
pubkey = self.app.signer.pubkey
|
||||||
|
@ -112,11 +111,10 @@ class ActorView(View):
|
||||||
return Response.new(data, ctype='activity')
|
return Response.new(data, ctype='activity')
|
||||||
|
|
||||||
|
|
||||||
async def post(self, request: Request) -> Response:
|
async def post(self, request: Request, conn: Connection) -> Response:
|
||||||
if response := await self.get_post_data():
|
if response := await self.get_post_data():
|
||||||
return response
|
return response
|
||||||
|
|
||||||
with self.database.connection() as conn:
|
|
||||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
||||||
config = conn.get_config_all()
|
config = conn.get_config_all()
|
||||||
|
|
||||||
|
@ -141,7 +139,7 @@ class ActorView(View):
|
||||||
|
|
||||||
logging.debug('>> payload %s', self.message.to_json(4))
|
logging.debug('>> payload %s', self.message.to_json(4))
|
||||||
|
|
||||||
asyncio.ensure_future(run_processor(self))
|
await run_processor(self, conn)
|
||||||
return Response.new(status = 202)
|
return Response.new(status = 202)
|
||||||
|
|
||||||
|
|
||||||
|
@ -232,7 +230,7 @@ class ActorView(View):
|
||||||
|
|
||||||
@register_route('/.well-known/webfinger')
|
@register_route('/.well-known/webfinger')
|
||||||
class WebfingerView(View):
|
class WebfingerView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request, conn: Connection) -> Response:
|
||||||
try:
|
try:
|
||||||
subject = request.query['resource']
|
subject = request.query['resource']
|
||||||
|
|
||||||
|
@ -253,8 +251,8 @@ class WebfingerView(View):
|
||||||
|
|
||||||
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
|
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
|
||||||
class NodeinfoView(View):
|
class NodeinfoView(View):
|
||||||
async def get(self, request: Request, niversion: str) -> Response:
|
# pylint: disable=no-self-use
|
||||||
with self.database.connection() as conn:
|
async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
|
||||||
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
@ -274,6 +272,6 @@ class NodeinfoView(View):
|
||||||
|
|
||||||
@register_route('/.well-known/nodeinfo')
|
@register_route('/.well-known/nodeinfo')
|
||||||
class WellknownNodeinfoView(View):
|
class WellknownNodeinfoView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request, conn: Connection) -> Response:
|
||||||
data = WellKnownNodeinfo.new_template(self.config.domain)
|
data = WellKnownNodeinfo.new_template(self.config.domain)
|
||||||
return Response.new(data, ctype = 'json')
|
return Response.new(data, ctype = 'json')
|
||||||
|
|
Loading…
Reference in a new issue