Merge branch 'cache' into 'master'

Draft: caching

See merge request pleroma/relay!54
This commit is contained in:
Izalia Mae 2024-02-04 10:21:02 +00:00
commit ed5d45396f
12 changed files with 740 additions and 191 deletions

View file

@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
database_type: sqlite
### Cache type
Cache backend to use. Valid values are `database` or `redis`
cache_type: database
### Sqlite File Path
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
@ -47,7 +54,7 @@ directory.
In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay"
sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
@ -84,3 +91,47 @@ User to use when logging into the server.
Password for the specified user.
pass: null
## Redis
### Host
Hostname, IP address, or unix socket the server is hosted on.
host: /var/run/postgresql
### Port
Port number the server is listening on.
port: 5432
### Username
User to use when logging into the server.
user: null
### Password
Password for the specified user.
pass: null
### Database Number
Number of the database to use.
database: 0
### Prefix
Text to prefix every key with. It cannot contain a `:` character.
prefix: activityrelay

View file

@ -7,12 +7,15 @@ listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080
# [integer] Number of push workers to start (will get removed in a future update)
# [integer] Number of push workers to start
workers: 8
# [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite
# [string] Cache backend to use. Valid values: database, redis
cache_type: database
# [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3
@ -33,3 +36,24 @@ postgres:
# [string] name of the database to use
name: activityrelay
# settings for the redis caching backend
redis:
# [string] hostname or unix socket to connect to
host: localhost
# [integer] port of the server
port: 6379
# [string] username to use when logging into the server
user: null
# [string] password for the server
pass: null
# [integer] database number to use
database: 0
# [string] prefix for keys
prefix: activityrelay

View file

@ -12,6 +12,7 @@ from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
from .cache import get_cache
from .config import Config
from .database import get_database
from .http_client import HttpClient
@ -19,8 +20,9 @@ from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
from tinysql import Database
from tinysql import Database, Row
from typing import Any
from .cache import Cache
from .misc import Message
@ -38,6 +40,7 @@ class Application(web.Application):
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['cache'] = get_cache(self)
self['workers'] = []
self['last_worker'] = 0
@ -48,6 +51,11 @@ class Application(web.Application):
self.router.add_view(path, view)
@property
def cache(self) -> Cache:
return self['cache']
@property
def client(self) -> HttpClient:
return self['client']
@ -87,13 +95,13 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message) -> None:
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
if self.config.workers <= 0:
asyncio.ensure_future(self.client.post(inbox, message))
asyncio.ensure_future(self.client.post(inbox, message, instance))
return
worker = self['workers'][self['last_worker']]
worker.queue.put((inbox, message))
worker.queue.put((inbox, message, instance))
self['last_worker'] += 1
@ -186,10 +194,10 @@ class PushWorker(threading.Thread):
while self.app['running']:
try:
inbox, message = self.queue.get(block=True, timeout=0.25)
inbox, message, instance = self.queue.get(block=True, timeout=0.25)
self.queue.task_done()
logging.verbose('New push from Thread-%i', threading.get_ident())
await self.client.post(inbox, message)
await self.client.post(inbox, message, instance)
except queue.Empty:
pass

288
relay/cache.py Normal file
View file

@ -0,0 +1,288 @@
from __future__ import annotations
import json
import os
import typing
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from redis import Redis
from .database import get_database
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from typing import Any
from collections.abc import Callable, Iterator
from tinysql import Database
from .application import Application
# todo: implement more caching backends
BACKENDS: dict[str, Cache] = {}
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse)
}
def get_cache(app: Application) -> Cache:
return BACKENDS[app.config.ca_type](app)
def register_cache(backend: type[Cache]) -> type[Cache]:
BACKENDS[backend.name] = backend
return backend
def serialize_value(value: Any, value_type: str = 'str') -> str:
if isinstance(value, str):
return value
return CONVERTERS[value_type][0](value)
def deserialize_value(value: str, value_type: str = 'str') -> Any:
return CONVERTERS[value_type][1](value)
@dataclass
class Item:
namespace: str
key: str
value: Any
value_type: str
updated: datetime
def __post_init__(self):
if isinstance(self.updated, str):
self.updated = datetime.fromisoformat(self.updated)
@classmethod
def from_data(cls: type[Item], *args) -> Item:
data = cls(*args)
data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)
return data
def older_than(self, hours: int) -> bool:
delta = datetime.now(tz = timezone.utc) - self.updated
return (delta.total_seconds()) > hours * 3600
def to_dict(self) -> dict[str, Any]:
return asdict(self)
class Cache(ABC):
name: str = 'null'
def __init__(self, app: Application):
self.app = app
self.setup()
@abstractmethod
def get(self, namespace: str, key: str) -> Item:
...
@abstractmethod
def get_keys(self, namespace: str) -> Iterator[str]:
...
@abstractmethod
def get_namespaces(self) -> Iterator[str]:
...
@abstractmethod
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
...
@abstractmethod
def delete(self, namespace: str, key: str) -> None:
...
@abstractmethod
def setup(self) -> None:
...
def set_item(self, item: Item) -> Item:
return self.set(
item.namespace,
item.key,
item.value,
item.type
)
def delete_item(self, item: Item) -> None:
self.delete(item.namespace, item.key)
@register_cache
class SqlCache(Cache):
name: str = 'database'
def __init__(self, app: Application):
self._db = get_database(app.config)
Cache.__init__(self, app)
def get(self, namespace: str, key: str) -> Item:
params = {
'namespace': namespace,
'key': key
}
with self._db.connection() as conn:
with conn.exec_statement('get-cache-item', params) as cur:
if not (row := cur.one()):
raise KeyError(f'{namespace}:{key}')
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
def get_keys(self, namespace: str) -> Iterator[str]:
with self._db.connection() as conn:
for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
yield row['key']
def get_namespaces(self) -> Iterator[str]:
with self._db.connection() as conn:
for row in conn.exec_statement('get-cache-namespaces', None):
yield row['namespace']
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
params = {
'namespace': namespace,
'key': key,
'value': serialize_value(value, value_type),
'type': value_type,
'date': datetime.now(tz = timezone.utc)
}
with self._db.connection() as conn:
with conn.exec_statement('set-cache-item', params) as conn:
row = conn.one()
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
def delete(self, namespace: str, key: str) -> None:
params = {
'namespace': namespace,
'key': key
}
with self._db.connection() as conn:
with conn.exec_statement('del-cache-item', params):
pass
def setup(self) -> None:
with self._db.connection() as conn:
with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
pass
@register_cache
class RedisCache(Cache):
name: str = 'redis'
_rd: Redis
@property
def prefix(self) -> str:
return self.app.config.rd_prefix
def get_key_name(self, namespace: str, key: str) -> str:
return f'{self.prefix}:{namespace}:{key}'
def get(self, namespace: str, key: str) -> Item:
key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2)
return Item.from_data(
namespace,
key,
value,
value_type,
datetime.fromtimestamp(float(updated), tz = timezone.utc)
)
def get_keys(self, namespace: str) -> Iterator[str]:
for key in self._rd.keys(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2)
yield key_name
def get_namespaces(self) -> Iterator[str]:
namespaces = []
for key in self._rd.keys(f'{self.prefix}:*'):
_, namespace, _ = key.split(':', 2)
if namespace not in namespaces:
namespaces.append(namespace)
yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None:
date = datetime.now(tz = timezone.utc).timestamp()
value = serialize_value(value, value_type)
self._rd.set(
self.get_key_name(namespace, key),
f'{value_type}:{date}:{value}'
)
def delete(self, namespace: str, key: str) -> None:
self._rd.delete(self.get_key_name(namespace, key))
def setup(self) -> None:
options = {
'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True,
'username': self.app.config.rd_user,
'password': self.app.config.rd_pass,
'db': self.app.config.rd_database
}
if os.path.exists(self.app.config.rd_host):
options['unix_socket_path'] = self.app.config.rd_host
else:
options['host'] = self.app.config.rd_host
options['port'] = self.app.config.rd_port
self._rd = Redis(**options)

View file

@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = {
'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
'ca_type': 'database',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
'pg_name': 'activityrelay',
'rd_host': 'localhost',
'rd_port': 6379,
'rd_user': None,
'rd_pass': None,
'rd_database': 0,
'rd_prefix': 'activityrelay'
}
if IS_DOCKER:
@ -40,13 +49,22 @@ class Config:
self.domain = None
self.workers = None
self.db_type = None
self.ca_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
self.rd_host = None
self.rd_port = None
self.rd_user = None
self.rd_pass = None
self.rd_database = None
self.rd_prefix = None
if load:
try:
self.load()
@ -92,6 +110,7 @@ class Config:
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
rdcfg = config.get('redis', {})
if not config:
raise ValueError('Config is empty')
@ -106,19 +125,26 @@ class Config:
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('workers', config.get('workers', DEFAULTS['workers']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue
if key.startswith('pg'):
try:
self.set(key, pgcfg[key[3:]])
except KeyError:
continue
elif key.startswith('rd'):
try:
self.set(key, rdcfg[key[3:]])
except KeyError:
continue
def reset(self) -> None:
for key, value in DEFAULTS.items():
@ -132,7 +158,9 @@ class Config:
'listen': self.listen,
'port': self.port,
'domain': self.domain,
'workers': self.workers,
'database_type': self.db_type,
'cache_type': self.ca_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
@ -140,6 +168,14 @@ class Config:
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
},
'redis': {
'host': self.rd_host,
'port': self.rd_port,
'user': self.rd_user,
'pass': self.rd_pass,
'database': self.rd_database,
'refix': self.rd_prefix
}
}

View file

@ -77,3 +77,64 @@ RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain
-- cache functions --
-- name: create-cache-table-sqlite
CREATE TABLE IF NOT EXISTS cache (
id INTEGER PRIMARY KEY UNIQUE,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
"value" TEXT,
type TEXT DEFAULT 'str',
updated TIMESTAMP NOT NULL,
UNIQUE(namespace, key)
)
-- name: create-cache-table-postgres
CREATE TABLE IF NOT EXISTS cache (
id SERIAL PRIMARY KEY,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
"value" TEXT,
type TEXT DEFAULT 'str',
updated TIMESTAMP NOT NULL,
UNIQUE(namespace, key)
)
-- name: get-cache-item
SELECT * FROM cache
WHERE namespace = :namespace and key = :key
-- name: get-cache-keys
SELECT key FROM cache
WHERE namespace = :namespace
-- name: get-cache-namespaces
SELECT DISTINCT namespace FROM cache
-- name: set-cache-item
INSERT INTO cache (namespace, key, value, type, updated)
VALUES (:namespace, :key, :value, :type, :date)
ON CONFLICT (namespace, key) DO
UPDATE SET value = :value, type = :type, updated = :date
RETURNING *
-- name: del-cache-item
DELETE FROM cache
WHERE namespace = :namespace and key = :key
-- name: del-cache-namespace
DELETE FROM cache
WHERE namespace = :namespace
-- name: del-cache-all
DELETE FROM cache

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import json
import traceback
import typing
@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils.objects import Nodeinfo, WellKnownNodeinfo
from cachetools import LRUCache
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
@ -16,7 +16,11 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from aputils import Signer
from tinysql import Row
from typing import Any
from .application import Application
from .cache import Cache
HEADERS = {
@ -26,8 +30,7 @@ HEADERS = {
class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
self.cache = LRUCache(cache_size)
def __init__(self, limit: int = 100, timeout: int = 10):
self.limit = limit
self.timeout = timeout
self._conn = None
@ -43,6 +46,21 @@ class HttpClient:
await self.close()
@property
def app(self) -> Application:
return get_app()
@property
def cache(self) -> Cache:
return self.app.cache
@property
def signer(self) -> Signer:
return self.app.signer
async def open(self) -> None:
if self._session:
return
@ -74,8 +92,8 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: bool = False,
loads: callable | None = None,
force: bool = False) -> Message | dict | None:
loads: callable = json.loads,
force: bool = False) -> dict | None:
await self.open()
@ -85,13 +103,20 @@ class HttpClient:
except ValueError:
pass
if not force and url in self.cache:
return self.cache[url]
if not force:
try:
item = self.cache.get('request', url)
if not item.older_than(48):
return loads(item.value)
except KeyError:
logging.verbose('Failed to fetch cached data for url: %s', url)
headers = {}
if sign_headers:
get_app().signer.sign_headers('GET', url, algorithm = 'original')
self.signer.sign_headers('GET', url, algorithm = 'original')
try:
logging.debug('Fetching resource: %s', url)
@ -101,32 +126,22 @@ class HttpClient:
if resp.status == 202:
return None
data = await resp.read()
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(await resp.read())
return None
if loads:
message = await resp.json(loads=loads)
message = loads(data)
self.cache.set('request', url, data.decode('utf-8'), 'str')
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
elif resp.content_type == MIMETYPES['activity']:
message = await resp.json(loads = Message.parse)
elif resp.content_type == MIMETYPES['json']:
message = await resp.json()
else:
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
logging.debug('Response: %s', await resp.read())
return None
logging.debug('%s >> resp %s', url, message.to_json(4))
self.cache[url] = message
return message
except JSONDecodeError:
logging.verbose('Failed to parse JSON')
return None
except ClientSSLError:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
@ -140,13 +155,9 @@ class HttpClient:
return None
async def post(self, url: str, message: Message) -> None:
async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
await self.open()
# todo: cache inboxes to avoid opening a db connection
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}:

View file

@ -51,6 +51,13 @@ SOFTWARE = (
)
def check_alphanumeric(text: str) -> str:
if not text.isalnum():
raise click.BadParameter('String not alphanumeric')
return text
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
@click.version_option(version=__version__, prog_name='ActivityRelay')
@ -113,8 +120,9 @@ def cli_setup(ctx: click.Context) -> None:
)
ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?',
default = ctx.obj.config.pg_host
'What IP address, hostname, or unix socket does the server listen on?',
default = ctx.obj.config.pg_host,
type = int
)
ctx.obj.config.pg_port = click.prompt(
@ -135,6 +143,48 @@ def cli_setup(ctx: click.Context) -> None:
default = ctx.obj.config.pg_pass or ""
) or None
ctx.obj.config.ca_type = click.prompt(
'Which caching backend?',
default = ctx.obj.config.ca_type,
type = click.Choice(['database', 'redis'], case_sensitive = False)
)
if ctx.obj.config.ca_type == 'redis':
ctx.obj.config.rd_host = click.prompt(
'What IP address, hostname, or unix socket does the server listen on?',
default = ctx.obj.config.rd_host
)
ctx.obj.config.rd_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.rd_port,
type = int
)
ctx.obj.config.rd_user = click.prompt(
'Which user will authenticate with the server',
default = ctx.obj.config.rd_user
)
ctx.obj.config.rd_pass = click.prompt(
'User password',
hide_input = True,
show_default = False,
default = ctx.obj.config.rd_pass or ""
) or None
ctx.obj.config.rd_database = click.prompt(
'Which database number to use?',
default = ctx.obj.config.rd_database,
type = int
)
ctx.obj.config.rd_prefix = click.prompt(
'What text should each cache key be prefixed with?',
default = ctx.obj.config.rd_database,
type = check_alphanumeric
)
ctx.obj.config.save()
config = {
@ -364,7 +414,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor
)
asyncio.run(http.post(inbox, message))
asyncio.run(http.post(inbox, message, None, inbox_data))
click.echo(f'Sent follow message to actor: {actor}')
@ -405,7 +455,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
}
)
asyncio.run(http.post(inbox, message))
asyncio.run(http.post(inbox, message, inbox_data))
click.echo(f'Sent unfollow message to: {actor}')

View file

@ -15,8 +15,10 @@ from uuid import uuid4
if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
from typing import Any
from tinysql import Connection
from typing import Any, Awaitable
from .application import Application
from .cache import Cache
from .config import Config
from .database import Database
from .http_client import HttpClient
@ -240,7 +242,15 @@ class View(AbstractView):
if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__()
return self._run_handler(handler).__await__()
async def _run_handler(self, handler: Awaitable) -> Response:
with self.database.config.connection_class(self.database) as conn:
# todo: remove on next tinysql release
conn.open()
return await handler(self.request, conn, **self.request.match_info)
@cached_property
@ -268,6 +278,11 @@ class View(AbstractView):
return self.request.app
@property
def cache(self) -> Cache:
return self.app.cache
@property
def client(self) -> HttpClient:
return self.app.client

View file

@ -1,20 +1,15 @@
from __future__ import annotations
import tinysql
import typing
from cachetools import LRUCache
from . import logger as logging
from .database.connection import Connection
from .misc import Message
if typing.TYPE_CHECKING:
from .views import ActorView
cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
@ -28,39 +23,46 @@ def person_check(actor: str, software: str) -> bool:
return False
async def handle_relay(view: ActorView) -> None:
if view.message.object_id in cache:
async def handle_relay(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id)
return
except KeyError:
pass
message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_forward(view: ActorView) -> None:
if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id)
async def handle_forward(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id)
return
except KeyError:
pass
message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_follow(view: ActorView) -> None:
async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn:
# reject if software used by actor is banned
if conn.get_software_ban(software):
view.app.push_message(
@ -97,14 +99,11 @@ async def handle_follow(view: ActorView) -> None:
return
if conn.get_inbox(view.actor.shared_inbox):
data = {'followid': view.message.id}
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
else:
conn.put_inbox(
with conn.transaction():
view.instance = conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
@ -119,7 +118,8 @@ async def handle_follow(view: ActorView) -> None:
actor = view.actor.id,
followid = view.message.id,
accept = True
)
),
view.instance
)
# Are Akkoma and Pleroma the only two that expect a follow back?
@ -130,17 +130,18 @@ async def handle_follow(view: ActorView) -> None:
Message.new_follow(
host = view.config.domain,
actor = view.actor.id
)
),
view.instance
)
async def handle_undo(view: ActorView) -> None:
async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
await handle_forward(view)
await handle_forward(view, conn)
return
with view.database.connection() as conn:
with conn.transaction():
if not conn.del_inbox(view.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
@ -154,7 +155,8 @@ async def handle_undo(view: ActorView) -> None:
host = view.config.domain,
actor = view.actor.id,
follow = view.message
)
),
view.instance
)
@ -168,7 +170,7 @@ processors = {
}
async def run_processor(view: ActorView) -> None:
async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@ -179,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return
if view.instance:
with conn.transaction():
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view)
await processors[view.message.type](view, conn)

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import subprocess
import traceback
import typing
@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__
from . import logger as logging
from .database.connection import Connection
from .misc import Message, Response, View
from .processors import run_processor
@ -19,6 +19,7 @@ if typing.TYPE_CHECKING:
from aiohttp.web import Request
from aputils.signer import Signer
from collections.abc import Callable
from tinysql import Row
VIEWS = []
@ -74,8 +75,7 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.connection() as conn:
async def get(self, request: Request, conn: Connection) -> Response:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
@ -98,11 +98,11 @@ class ActorView(View):
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.instance: Row = None
self.signer: Signer = None
async def get(self, request: Request) -> Response:
async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor(
host = self.config.domain,
pubkey = self.app.signer.pubkey
@ -111,11 +111,10 @@ class ActorView(View):
return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response:
async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data():
return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all()
@ -140,7 +139,7 @@ class ActorView(View):
logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self))
await run_processor(self, conn)
return Response.new(status = 202)
@ -168,7 +167,11 @@ class ActorView(View):
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
self.actor = await self.client.get(
self.signature.keyid,
sign_headers = True,
loads = Message.parse
)
if not self.actor:
# ld signatures aren't handled atm, so just ignore it
@ -227,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger')
class WebfingerView(View):
async def get(self, request: Request) -> Response:
async def get(self, request: Request, conn: Connection) -> Response:
try:
subject = request.query['resource']
@ -248,8 +251,8 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection() as conn:
# pylint: disable=no-self-use
async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = {
@ -269,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response:
async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')

View file

@ -1,8 +1,9 @@
aiohttp>=3.9.1
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2
hiredis==2.3.2
pyyaml>=6.0
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz
redis==5.0.1
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz
importlib_resources==6.1.1;python_version<'3.9'