diff --git a/docs/configuration.md b/docs/configuration.md index bef0799..36f0463 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`. database_type: sqlite +### Cache type + +Cache backend to use. Valid values are `database` or `redis` + + cache_type: database + + ### Sqlite File Path Path to the sqlite database file. If the path is not absolute, it is relative to the config file. @@ -47,7 +54,7 @@ directory. In order to use the Postgresql backend, the user and database need to be created first. - sudo -u postgres psql -c "CREATE USER activityrelay" + sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword" sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay" @@ -84,3 +91,47 @@ User to use when logging into the server. Password for the specified user. pass: null + + +## Redis + +### Host + +Hostname, IP address, or unix socket the server is hosted on. + + host: /var/run/postgresql + + +### Port + +Port number the server is listening on. + + port: 5432 + + +### Username + +User to use when logging into the server. + + user: null + + +### Password + +Password for the specified user. + + pass: null + + +### Database Number + +Number of the database to use. + + database: 0 + + +### Prefix + +Text to prefix every key with. It cannot contain a `:` character. + + prefix: activityrelay diff --git a/relay.yaml.example b/relay.yaml.example index 90b9e8f..823feaa 100644 --- a/relay.yaml.example +++ b/relay.yaml.example @@ -7,12 +7,15 @@ listen: 0.0.0.0 # [integer] Port the relay will listen on port: 8080 -# [integer] Number of push workers to start (will get removed in a future update) +# [integer] Number of push workers to start workers: 8 # [string] Database backend to use. Valid values: sqlite, postgres database_type: sqlite +# [string] Cache backend to use. Valid values: database, redis +cache_type: database + # [string] Path to the sqlite database file if the sqlite backend is in use sqlite_path: relay.sqlite3 @@ -33,3 +36,24 @@ postgres: # [string] name of the database to use name: activityrelay + +# settings for the redis caching backend +redis: + + # [string] hostname or unix socket to connect to + host: localhost + + # [integer] port of the server + port: 6379 + + # [string] username to use when logging into the server + user: null + + # [string] password for the server + pass: null + + # [integer] database number to use + database: 0 + + # [string] prefix for keys + prefix: activityrelay diff --git a/relay/application.py b/relay/application.py index 9440098..92bccb5 100644 --- a/relay/application.py +++ b/relay/application.py @@ -12,6 +12,7 @@ from aputils.signer import Signer from datetime import datetime, timedelta from . import logger as logging +from .cache import get_cache from .config import Config from .database import get_database from .http_client import HttpClient @@ -19,8 +20,9 @@ from .misc import check_open_port from .views import VIEWS if typing.TYPE_CHECKING: - from tinysql import Database + from tinysql import Database, Row from typing import Any + from .cache import Cache from .misc import Message @@ -38,6 +40,7 @@ class Application(web.Application): self['config'] = Config(cfgpath, load = True) self['database'] = get_database(self.config) self['client'] = HttpClient() + self['cache'] = get_cache(self) self['workers'] = [] self['last_worker'] = 0 @@ -48,6 +51,11 @@ class Application(web.Application): self.router.add_view(path, view) + @property + def cache(self) -> Cache: + return self['cache'] + + @property def client(self) -> HttpClient: return self['client'] @@ -87,13 +95,13 @@ class Application(web.Application): return timedelta(seconds=uptime.seconds) - def push_message(self, inbox: str, message: Message) -> None: + def push_message(self, inbox: str, message: Message, instance: Row) -> None: if self.config.workers <= 0: - asyncio.ensure_future(self.client.post(inbox, message)) + asyncio.ensure_future(self.client.post(inbox, message, instance)) return worker = self['workers'][self['last_worker']] - worker.queue.put((inbox, message)) + worker.queue.put((inbox, message, instance)) self['last_worker'] += 1 @@ -186,10 +194,10 @@ class PushWorker(threading.Thread): while self.app['running']: try: - inbox, message = self.queue.get(block=True, timeout=0.25) + inbox, message, instance = self.queue.get(block=True, timeout=0.25) self.queue.task_done() logging.verbose('New push from Thread-%i', threading.get_ident()) - await self.client.post(inbox, message) + await self.client.post(inbox, message, instance) except queue.Empty: pass diff --git a/relay/cache.py b/relay/cache.py new file mode 100644 index 0000000..9763e8c --- /dev/null +++ b/relay/cache.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import json +import os +import typing + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from redis import Redis + +from .misc import Message, boolean + +if typing.TYPE_CHECKING: + from typing import Any + from collections.abc import Callable, Iterator + from tinysql import Database + from .application import Application + + +# todo: implement more caching backends + + +BACKENDS: dict[str, Cache] = {} +CONVERTERS: dict[str, tuple[Callable, Callable]] = { + 'str': (str, str), + 'int': (str, int), + 'bool': (str, boolean), + 'json': (json.dumps, json.loads), + 'message': (lambda x: x.to_json(), Message.parse) +} + + +def get_cache(app: Application) -> Cache: + return BACKENDS[app.config.ca_type](app) + + +def register_cache(backend: type[Cache]) -> type[Cache]: + BACKENDS[backend.name] = backend + return backend + + +def serialize_value(value: Any, value_type: str = 'str') -> str: + if isinstance(value, str): + return value + + return CONVERTERS[value_type][0](value) + + +def deserialize_value(value: str, value_type: str = 'str') -> Any: + return CONVERTERS[value_type][1](value) + + +@dataclass +class Item: + namespace: str + key: str + value: Any + value_type: str + updated: datetime + + + def __post_init__(self): + if isinstance(self.updated, str): + self.updated = datetime.fromisoformat(self.updated) + + + @classmethod + def from_data(cls: type[Item], *args) -> Item: + data = cls(*args) + data.value = deserialize_value(data.value, data.value_type) + return data + + + def older_than(self, hours: int) -> bool: + delta = datetime.now(tz = timezone.utc) - self.updated + return (delta.total_seconds()) > hours * 3600 + + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +class Cache(ABC): + name: str = 'null' + + + def __init__(self, app: Application): + self.app = app + self.setup() + + @abstractmethod + def get(self, namespace: str, key: str) -> Item: + ... + + + @abstractmethod + def get_keys(self, namespace: str) -> Iterator[str]: + ... + + + @abstractmethod + def get_namespaces(self) -> Iterator[str]: + ... + + + @abstractmethod + def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: + ... + + + @abstractmethod + def delete(self, namespace: str, key: str) -> None: + ... + + + @abstractmethod + def setup(self) -> None: + ... + + + def set_item(self, item: Item) -> Item: + return self.set( + item.namespace, + item.key, + item.value, + item.type + ) + + + def delete_item(self, item: Item) -> None: + self.delete(item.namespace, item.key) + + +@register_cache +class SqlCache(Cache): + name: str = 'database' + + + @property + def _db(self) -> Database: + return self.app.database + + + def get(self, namespace: str, key: str) -> Item: + params = { + 'namespace': namespace, + 'key': key + } + + with self._db.connection() as conn: + with conn.exec_statement('get-cache-item', params) as cur: + if not (row := cur.one()): + raise KeyError(f'{namespace}:{key}') + + return Item.from_data(*tuple(row.values())) + + + def get_keys(self, namespace: str) -> Iterator[str]: + with self._db.connection() as conn: + for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}): + yield row['key'] + + + def get_namespaces(self) -> Iterator[str]: + with self._db.connection() as conn: + for row in conn.exec_statement('get-cache-namespaces', None): + yield row['namespace'] + + + def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item: + params = { + 'namespace': namespace, + 'key': key, + 'value': serialize_value(value, value_type), + 'type': value_type, + 'date': datetime.now(tz = timezone.utc) + } + + with self._db.connection() as conn: + for row in conn.exec_statement('set-cache-item', params): + return Item.from_data(*tuple(row.values())) + + + def delete(self, namespace: str, key: str) -> None: + params = { + 'namespace': namespace, + 'key': key + } + + with self._db.connection() as conn: + with conn.exec_statement('del-cache-item', params): + pass + + + def setup(self) -> None: + with self._db.connection() as conn: + with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None): + pass + + +@register_cache +class RedisCache(Cache): + name: str = 'redis' + _rd: Redis + + + @property + def prefix(self) -> str: + return self.app.config.rd_prefix + + + def get_key_name(self, namespace: str, key: str) -> str: + return f'{self.prefix}:{namespace}:{key}' + + + def get(self, namespace: str, key: str) -> Item: + key_name = self.get_key_name(namespace, key) + + if not (raw_value := self._rd.get(key_name)): + raise KeyError(f'{namespace}:{key}') + + value_type, updated, value = raw_value.split(':', 2) + return Item.from_data( + namespace, + key, + value, + value_type, + datetime.fromtimestamp(float(updated), tz = timezone.utc) + ) + + + def get_keys(self, namespace: str) -> Iterator[str]: + for key in self._rd.keys(self.get_key_name(namespace, '*')): + *_, key_name = key.split(':', 2) + yield key_name + + + def get_namespaces(self) -> Iterator[str]: + namespaces = [] + + for key in self._rd.keys(f'{self.prefix}:*'): + _, namespace, _ = key.split(':', 2) + + if namespace not in namespaces: + namespaces.append(namespace) + yield namespace + + + def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None: + date = datetime.now(tz = timezone.utc).timestamp() + value = serialize_value(value, value_type) + + self._rd.set( + self.get_key_name(namespace, key), + f'{value_type}:{date}:{value}' + ) + + + def delete(self, namespace: str, key: str) -> None: + self._rd.delete(self.get_key_name(namespace, key)) + + + def setup(self) -> None: + options = { + 'client_name': f'ActivityRelay_{self.app.config.domain}', + 'decode_responses': True, + 'username': self.app.config.rd_user, + 'password': self.app.config.rd_pass, + 'db': self.app.config.rd_database + } + + if os.path.exists(self.app.config.rd_host): + options['unix_socket_path'] = self.app.config.rd_host + + else: + options['host'] = self.app.config.rd_host + options['port'] = self.app.config.rd_port + + self._rd = Redis(**options) diff --git a/relay/config.py b/relay/config.py index 58c2eb8..eff6215 100644 --- a/relay/config.py +++ b/relay/config.py @@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = { 'domain': 'relay.example.com', 'workers': len(os.sched_getaffinity(0)), 'db_type': 'sqlite', + 'ca_type': 'database', 'sq_path': 'relay.sqlite3', + 'pg_host': '/var/run/postgresql', 'pg_port': 5432, 'pg_user': getpass.getuser(), 'pg_pass': None, - 'pg_name': 'activityrelay' + 'pg_name': 'activityrelay', + + 'rd_host': 'localhost', + 'rd_port': 6379, + 'rd_user': None, + 'rd_pass': None, + 'rd_database': 0, + 'rd_prefix': 'activityrelay' } if IS_DOCKER: @@ -40,13 +49,22 @@ class Config: self.domain = None self.workers = None self.db_type = None + self.ca_type = None self.sq_path = None + self.pg_host = None self.pg_port = None self.pg_user = None self.pg_pass = None self.pg_name = None + self.rd_host = None + self.rd_port = None + self.rd_user = None + self.rd_pass = None + self.rd_database = None + self.rd_prefix = None + if load: try: self.load() @@ -92,6 +110,7 @@ class Config: with self.path.open('r', encoding = 'UTF-8') as fd: config = yaml.load(fd, **options) pgcfg = config.get('postgresql', {}) + rdcfg = config.get('redis', {}) if not config: raise ValueError('Config is empty') @@ -106,18 +125,25 @@ class Config: self.set('port', config.get('port', DEFAULTS['port'])) self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path'])) + self.set('workers', config.get('workers', DEFAULTS['workers'])) self.set('domain', config.get('domain', DEFAULTS['domain'])) self.set('db_type', config.get('database_type', DEFAULTS['db_type'])) + self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type'])) for key in DEFAULTS: - if not key.startswith('pg'): - continue + if key.startswith('pg'): + try: + self.set(key, pgcfg[key[3:]]) - try: - self.set(key, pgcfg[key[3:]]) + except KeyError: + continue - except KeyError: - continue + elif key.startswith('rd'): + try: + self.set(key, rdcfg[key[3:]]) + + except KeyError: + continue def reset(self) -> None: @@ -132,7 +158,9 @@ class Config: 'listen': self.listen, 'port': self.port, 'domain': self.domain, + 'workers': self.workers, 'database_type': self.db_type, + 'cache_type': self.ca_type, 'sqlite_path': self.sq_path, 'postgres': { 'host': self.pg_host, @@ -140,6 +168,14 @@ class Config: 'user': self.pg_user, 'pass': self.pg_pass, 'name': self.pg_name + }, + 'redis': { + 'host': self.rd_host, + 'port': self.rd_port, + 'user': self.rd_user, + 'pass': self.rd_pass, + 'database': self.rd_database, + 'refix': self.rd_prefix } } diff --git a/relay/data/statements.sql b/relay/data/statements.sql index a262feb..9bcea41 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -77,3 +77,64 @@ RETURNING * -- name: del-domain-whitelist DELETE FROM whitelist WHERE domain = :domain + + +-- cache functions -- + +-- name: create-cache-table-sqlite +CREATE TABLE IF NOT EXISTS cache ( + id INTEGER PRIMARY KEY UNIQUE, + namespace TEXT NOT NULL, + key TEXT NOT NULL, + "value" TEXT, + type TEXT DEFAULT 'str', + updated TIMESTAMP NOT NULL, + UNIQUE(namespace, key) +) + +-- name: create-cache-table-postgres +CREATE TABLE IF NOT EXISTS cache ( + id SERIAL PRIMARY KEY, + namespace TEXT NOT NULL, + key TEXT NOT NULL, + "value" TEXT, + type TEXT DEFAULT 'str', + updated TIMESTAMP NOT NULL, + UNIQUE(namespace, key) +) + + +-- name: get-cache-item +SELECT * FROM cache +WHERE namespace = :namespace and key = :key + + +-- name: get-cache-keys +SELECT key FROM cache +WHERE namespace = :namespace + + +-- name: get-cache-namespaces +SELECT DISTINCT namespace FROM cache + + +-- name: set-cache-item +INSERT INTO cache (namespace, key, value, type, updated) +VALUES (:namespace, :key, :value, :type, :date) +ON CONFLICT (namespace, key) DO +UPDATE SET value = :value, type = :type, updated = :date +RETURNING * + + +-- name: del-cache-item +DELETE FROM cache +WHERE namespace = :namespace and key = :key + + +-- name: del-cache-namespace +DELETE FROM cache +WHERE namespace = :namespace + + +-- name: del-cache-all +DELETE FROM cache diff --git a/relay/http_client.py b/relay/http_client.py index 5040059..6bed4ae 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import traceback import typing @@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError from aputils.objects import Nodeinfo, WellKnownNodeinfo -from cachetools import LRUCache from json.decoder import JSONDecodeError from urllib.parse import urlparse @@ -16,7 +16,11 @@ from . import logger as logging from .misc import MIMETYPES, Message, get_app if typing.TYPE_CHECKING: + from aputils import Signer + from tinysql import Row from typing import Any + from .application import Application + from .cache import Cache HEADERS = { @@ -26,8 +30,7 @@ HEADERS = { class HttpClient: - def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024): - self.cache = LRUCache(cache_size) + def __init__(self, limit: int = 100, timeout: int = 10): self.limit = limit self.timeout = timeout self._conn = None @@ -43,6 +46,21 @@ class HttpClient: await self.close() + @property + def app(self) -> Application: + return get_app() + + + @property + def cache(self) -> Cache: + return self.app.cache + + + @property + def signer(self) -> Signer: + return self.app.signer + + async def open(self) -> None: if self._session: return @@ -74,8 +92,8 @@ class HttpClient: async def get(self, # pylint: disable=too-many-branches url: str, sign_headers: bool = False, - loads: callable | None = None, - force: bool = False) -> Message | dict | None: + loads: callable = json.loads, + force: bool = False) -> dict | None: await self.open() @@ -85,13 +103,20 @@ class HttpClient: except ValueError: pass - if not force and url in self.cache: - return self.cache[url] + if not force: + try: + item = self.cache.get('request', url) + + if not item.older_than(48): + return loads(item.value) + + except KeyError: + logging.verbose('Failed to fetch cached data for url: %s', url) headers = {} if sign_headers: - get_app().signer.sign_headers('GET', url, algorithm = 'original') + self.signer.sign_headers('GET', url, algorithm = 'original') try: logging.debug('Fetching resource: %s', url) @@ -101,32 +126,22 @@ class HttpClient: if resp.status == 202: return None - if resp.status != 200: - logging.verbose('Received error when requesting %s: %i', url, resp.status) - logging.debug(await resp.read()) - return None + data = await resp.read() - if loads: - message = await resp.json(loads=loads) + if resp.status != 200: + logging.verbose('Received error when requesting %s: %i', url, resp.status) + logging.debug(await resp.read()) + return None - elif resp.content_type == MIMETYPES['activity']: - message = await resp.json(loads = Message.parse) + message = loads(data) + self.cache.set('request', url, data.decode('utf-8'), 'str') + logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4)) - elif resp.content_type == MIMETYPES['json']: - message = await resp.json() - - else: - logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type) - logging.debug('Response: %s', await resp.read()) - return None - - logging.debug('%s >> resp %s', url, message.to_json(4)) - - self.cache[url] = message - return message + return message except JSONDecodeError: logging.verbose('Failed to parse JSON') + return None except ClientSSLError: logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) @@ -140,13 +155,9 @@ class HttpClient: return None - async def post(self, url: str, message: Message) -> None: + async def post(self, url: str, message: Message, instance: Row | None = None) -> None: await self.open() - # todo: cache inboxes to avoid opening a db connection - with get_app().database.connection() as conn: - instance = conn.get_inbox(url) - ## Using the old algo by default is probably a better idea right now # pylint: disable=consider-ternary-expression if instance and instance['software'] in {'mastodon'}: diff --git a/relay/misc.py b/relay/misc.py index d7e96d8..831db38 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -17,6 +17,7 @@ if typing.TYPE_CHECKING: from collections.abc import Coroutine, Generator from typing import Any from .application import Application + from .cache import Cache from .config import Config from .database import Database from .http_client import HttpClient @@ -268,6 +269,11 @@ class View(AbstractView): return self.request.app + @property + def cache(self) -> Cache: + return self.app.cache + + @property def client(self) -> HttpClient: return self.app.client diff --git a/relay/processors.py b/relay/processors.py index 8cf7722..671adb3 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -3,8 +3,6 @@ from __future__ import annotations import tinysql import typing -from cachetools import LRUCache - from . import logger as logging from .misc import Message @@ -12,9 +10,6 @@ if typing.TYPE_CHECKING: from .views import ActorView -cache = LRUCache(1024) - - def person_check(actor: str, software: str) -> bool: # pleroma and akkoma may use Person for the actor type for some reason # akkoma changed this in 3.6.0 @@ -29,31 +24,39 @@ def person_check(actor: str, software: str) -> bool: async def handle_relay(view: ActorView) -> None: - if view.message.object_id in cache: + try: + view.cache.get('handle-relay', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id) return + except KeyError: + pass + message = Message.new_announce(view.config.domain, view.message.object_id) - cache[view.message.object_id] = message.id + view.cache.set('handle-relay', view.message.object_id, message.id, 'str') logging.debug('>> relay: %s', message) with view.database.connection() as conn: for inbox in conn.distill_inboxes(view.message): - view.app.push_message(inbox, message) + view.app.push_message(inbox, message, view.instance) async def handle_forward(view: ActorView) -> None: - if view.message.id in cache: - logging.verbose('already forwarded %s', view.message.id) + try: + view.cache.get('handle-relay', view.message.object_id) + logging.verbose('already forwarded %s', view.message.object_id) return + except KeyError: + pass + message = Message.new_announce(view.config.domain, view.message) - cache[view.message.id] = message.id + view.cache.set('handle-relay', view.message.object_id, message.id, 'str') logging.debug('>> forward: %s', message) with view.database.connection() as conn: for inbox in conn.distill_inboxes(view.message): - view.app.push_message(inbox, message) + view.app.push_message(inbox, message, view.instance) async def handle_follow(view: ActorView) -> None: diff --git a/relay/views.py b/relay/views.py index 8b84c02..64c1d57 100644 --- a/relay/views.py +++ b/relay/views.py @@ -19,6 +19,7 @@ if typing.TYPE_CHECKING: from aiohttp.web import Request from aputils.signer import Signer from collections.abc import Callable + from tinysql import Row VIEWS = [] @@ -98,7 +99,7 @@ class ActorView(View): self.signature: Signature = None self.message: Message = None self.actor: Message = None - self.instance: dict[str, str] = None + self.instance: Row = None self.signer: Signer = None @@ -168,7 +169,11 @@ class ActorView(View): logging.verbose('actor not in message') return Response.new_error(400, 'no actor in message', 'json') - self.actor = await self.client.get(self.signature.keyid, sign_headers = True) + self.actor = await self.client.get( + self.signature.keyid, + sign_headers = True, + loads = Message.parse + ) if not self.actor: # ld signatures aren't handled atm, so just ignore it diff --git a/requirements.txt b/requirements.txt index 3fb5e40..0eb5add 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ aiohttp>=3.9.1 aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz -cachetools>=5.2.0 click>=8.1.2 +hiredis==2.3.2 pyyaml>=6.0 +redis==5.0.1 tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz importlib_resources==6.1.1;python_version<'3.9'