diff --git a/docs/configuration.md b/docs/configuration.md
index bef0799..36f0463 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
 	database_type: sqlite
 
 
+### Cache type
+
+Cache backend to use. Valid values are `database` or `redis`.
+
+	cache_type: database
+
+
 ### Sqlite File Path
 
 Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
@@ -47,7 +54,7 @@ directory.
 
 In order to use the Postgresql backend, the user and database need to be created first.
 
-	sudo -u postgres psql -c "CREATE USER activityrelay"
+	sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD 'SomeSecurePassword'"
 	sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
 
@@ -84,3 +91,47 @@ User to use when logging into the server.
 Password for the specified user.
 
 	pass: null
+
+
+## Redis
+
+### Host
+
+Hostname, IP address, or unix socket the server is hosted on.
+
+	host: localhost
+
+
+### Port
+
+Port number the server is listening on.
+
+	port: 6379
+
+
+### Username
+
+User to use when logging into the server.
+
+	user: null
+
+
+### Password
+
+Password for the specified user.
+
+	pass: null
+
+
+### Database Number
+
+Number of the database to use.
+
+	database: 0
+
+
+### Prefix
+
+Text to prefix every key with. It cannot contain a `:` character.
+
+	prefix: activityrelay
diff --git a/relay.yaml.example b/relay.yaml.example
index 90b9e8f..823feaa 100644
--- a/relay.yaml.example
+++ b/relay.yaml.example
@@ -7,12 +7,15 @@ listen: 0.0.0.0
 # [integer] Port the relay will listen on
 port: 8080
 
-# [integer] Number of push workers to start (will get removed in a future update)
+# [integer] Number of push workers to start
 workers: 8
 
 # [string] Database backend to use. Valid values: sqlite, postgres
 database_type: sqlite
 
+# [string] Cache backend to use. Valid values: database, redis
+cache_type: database
+
 # [string] Path to the sqlite database file if the sqlite backend is in use
 sqlite_path: relay.sqlite3
 
@@ -33,3 +36,24 @@ postgres:
 
     # [string] name of the database to use
     name: activityrelay
+
+# settings for the redis caching backend
+redis:
+
+    # [string] hostname or unix socket to connect to
+    host: localhost
+
+    # [integer] port of the server
+    port: 6379
+
+    # [string] username to use when logging into the server
+    user: null
+
+    # [string] password for the server
+    pass: null
+
+    # [integer] database number to use
+    database: 0
+
+    # [string] prefix for keys
+    prefix: activityrelay
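Note: a minimal sketch of how the new `cache_type` setting picks a backend at startup. The `BACKENDS` registry and `get_cache()` come from `relay/cache.py` later in this diff; the YAML layout follows `relay.yaml.example` above, and the stand-in class names in the dict are illustrative only.

```python
import yaml

config = yaml.safe_load("""
cache_type: redis
redis:
  host: localhost
  port: 6379
  database: 0
  prefix: activityrelay
""")

# register_cache() fills a name -> class registry; get_cache() then does the
# equivalent of this lookup with the parsed config value:
BACKENDS = {'database': 'SqlCache', 'redis': 'RedisCache'}  # stand-in values
print(BACKENDS[config['cache_type']])  # -> RedisCache
```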
diff --git a/relay/application.py b/relay/application.py
index 9440098..80efed9 100644
--- a/relay/application.py
+++ b/relay/application.py
@@ -1,17 +1,20 @@
 from __future__ import annotations
 
 import asyncio
-import queue
+import os
 import signal
-import threading
-import traceback
+import subprocess
+import sys
+import time
 import typing
 
 from aiohttp import web
 from aputils.signer import Signer
 from datetime import datetime, timedelta
+from gunicorn.app.wsgiapp import WSGIApplication
 
 from . import logger as logging
+from .cache import get_cache
 from .config import Config
 from .database import get_database
 from .http_client import HttpClient
@@ -19,8 +22,10 @@ from .misc import check_open_port
 from .views import VIEWS
 
 if typing.TYPE_CHECKING:
-	from tinysql import Database
+	from collections.abc import Awaitable
+	from tinysql import Database, Row
 	from typing import Any
+	from .cache import Cache
 	from .misc import Message
 
 
@@ -29,25 +34,34 @@ if typing.TYPE_CHECKING:
 class Application(web.Application):
 	DEFAULT: Application = None
 
-	def __init__(self, cfgpath: str):
+	def __init__(self, cfgpath: str, gunicorn: bool = False):
 		web.Application.__init__(self)
 		Application.DEFAULT = self
 
+		self['proc'] = None
 		self['signer'] = None
+		self['start_time'] = None
+
 		self['config'] = Config(cfgpath, load = True)
 		self['database'] = get_database(self.config)
 		self['client'] = HttpClient()
+		self['cache'] = get_cache(self)
 
-		self['workers'] = []
-		self['last_worker'] = 0
-		self['start_time'] = None
-		self['running'] = False
+		if not gunicorn:
+			return
+
+		self.on_response_prepare.append(handle_access_log)
 
 		for path, view in VIEWS:
 			self.router.add_view(path, view)
 
 
+	@property
+	def cache(self) -> Cache:
+		return self['cache']
+
+
+	@property
 	def client(self) -> HttpClient:
 		return self['client']
 
@@ -87,18 +101,17 @@ class Application(web.Application):
 		return timedelta(seconds=uptime.seconds)
 
 
-	def push_message(self, inbox: str, message: Message) -> None:
-		if self.config.workers <= 0:
-			asyncio.ensure_future(self.client.post(inbox, message))
-			return
+	def push_message(self, inbox: str, message: Message, instance: Row) -> None:
+		asyncio.ensure_future(self.client.post(inbox, message, instance))
 
-		worker = self['workers'][self['last_worker']]
-		worker.queue.put((inbox, message))
-		self['last_worker'] += 1
 
-		if self['last_worker'] >= len(self['workers']):
-			self['last_worker'] = 0
+	def run(self, dev: bool = False) -> None:
+		self.start(dev)
+
+		while self['proc'] and self['proc'].poll() is None:
+			time.sleep(0.1)
+
+		self.stop()
 
 
 	def set_signal_handler(self, startup: bool) -> None:
@@ -111,91 +124,101 @@ class Application(web.Application):
 			pass
 
-	def run(self) -> None:
-		if not check_open_port(self.config.listen, self.config.port):
-			logging.error('A server is already running on port %i', self.config.port)
+	def start(self, dev: bool = False) -> None:
+		if self['proc']:
 			return
 
-		for view in VIEWS:
-			self.router.add_view(*view)
+		if not check_open_port(self.config.listen, self.config.port):
+			logging.error('Server already running on %s:%s', self.config.listen, self.config.port)
+			return
 
-		logging.info(
-			'Starting webserver at %s (%s:%i)',
-			self.config.domain,
-			self.config.listen,
-			self.config.port
-		)
+		cmd = [
+			sys.executable, '-m', 'gunicorn',
+			'relay.application:main_gunicorn',
+			'--bind', f'{self.config.listen}:{self.config.port}',
+			'--worker-class', 'aiohttp.GunicornWebWorker',
+			'--workers', str(self.config.workers),
+			'--env', f'CONFIG_FILE={self.config.path}'
+		]
 
-		asyncio.run(self.handle_run())
+		if dev:
+			cmd.append('--reload')
 
-
-	def stop(self, *_: Any) -> None:
-		self['running'] = False
-
-
-	async def handle_run(self) -> None:
-		self['running'] = True
 		self.set_signal_handler(True)
+		self['proc'] = subprocess.Popen(cmd)  # pylint: disable=consider-using-with
 
-		if self.config.workers > 0:
-			for _ in range(self.config.workers):
-				worker = PushWorker(self)
-				worker.start()
-
-				self['workers'].append(worker)
-
-		runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
-		await runner.setup()
-
-		site = web.TCPSite(
-			runner,
-			host = self.config.listen,
-			port = self.config.port,
-			reuse_address = True
-		)
-
-		await site.start()
-		self['start_time'] = datetime.now()
-
-		while self['running']:
-			await asyncio.sleep(0.25)
-
-		await site.stop()
-		await self.client.close()
-
-		self['start_time'] = None
-		self['running'] = False
-		self['workers'].clear()
+	def stop(self, *_: Any) -> None:
+		if not self['proc']:
+			return
+
+		self['proc'].terminate()
+		time_wait = 0.0
+
+		while self['proc'].poll() is None:
+			time.sleep(0.1)
+			time_wait += 0.1
+
+			if time_wait >= 5.0:
+				self['proc'].kill()
+				break
+
+		self.set_signal_handler(False)
+		self['proc'] = None
 
 
-class PushWorker(threading.Thread):
+# not used, but keeping just in case
+class GunicornRunner(WSGIApplication):
 	def __init__(self, app: Application):
-		threading.Thread.__init__(self)
 		self.app = app
-		self.queue = queue.Queue()
-		self.client = None
+		self.app_uri = 'relay.application:main_gunicorn'
+		self.options = {
+			'bind': f'{app.config.listen}:{app.config.port}',
+			'worker_class': 'aiohttp.GunicornWebWorker',
+			'workers': app.config.workers,
+			'raw_env': f'CONFIG_FILE={app.config.path}'
+		}
+
+		WSGIApplication.__init__(self)
 
 
-	def run(self) -> None:
-		asyncio.run(self.handle_queue())
+	def load_config(self):
+		for key, value in self.options.items():
+			self.cfg.set(key, value)
 
 
-	async def handle_queue(self) -> None:
-		self.client = HttpClient()
+	def run(self):
+		logging.info('Starting webserver for %s', self.app.config.domain)
+		WSGIApplication.run(self)
 
-		while self.app['running']:
-			try:
-				inbox, message = self.queue.get(block=True, timeout=0.25)
-				self.queue.task_done()
-				logging.verbose('New push from Thread-%i', threading.get_ident())
-				await self.client.post(inbox, message)
 
-			except queue.Empty:
-				pass
+async def handle_access_log(request: web.Request, response: web.Response) -> None:
+	address = request.headers.get(
+		'X-Forwarded-For',
+		request.headers.get(
+			'X-Real-Ip',
+			request.remote
+		)
+	)
 
-			## make sure an exception doesn't bring down the worker
-			except Exception:
-				traceback.print_exc()
+	logging.info(
+		'%s "%s %s" %i %i "%s"',
+		address,
+		request.method,
+		request.path,
+		response.status,
+		len(response.body),
+		request.headers.get('User-Agent', 'n/a')
+	)
 
-		await self.client.close()
+
+async def main_gunicorn():
+	try:
+		app = Application(os.environ['CONFIG_FILE'], gunicorn = True)
+
+	except KeyError:
+		logging.error('Failed to get the "CONFIG_FILE" environment variable. Are you trying to run without gunicorn?')
+		raise
+
+	return app
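Note: a toy sketch of the contract `start()` relies on. Gunicorn's aiohttp worker accepts an async factory that returns a `web.Application`, which is exactly the shape of `main_gunicorn()` above. Everything here besides that contract (the module, route, and handler names) is invented for illustration.

```python
from aiohttp import web

async def toy_factory() -> web.Application:
	app = web.Application()

	async def index(request: web.Request) -> web.Response:
		return web.Response(text='ok')

	app.router.add_get('/', index)
	return app

# launched the same way start() builds its command line:
#   python -m gunicorn module:toy_factory --worker-class aiohttp.GunicornWebWorker
```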
diff --git a/relay/cache.py b/relay/cache.py
new file mode 100644
index 0000000..9c28108
--- /dev/null
+++ b/relay/cache.py
@@ -0,0 +1,288 @@
+from __future__ import annotations
+
+import json
+import os
+import typing
+
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from datetime import datetime, timezone
+from redis import Redis
+
+from .database import get_database
+from .misc import Message, boolean
+
+if typing.TYPE_CHECKING:
+	from typing import Any
+	from collections.abc import Callable, Iterator
+	from tinysql import Database
+	from .application import Application
+
+
+# todo: implement more caching backends
+
+
+BACKENDS: dict[str, Cache] = {}
+CONVERTERS: dict[str, tuple[Callable, Callable]] = {
+	'str': (str, str),
+	'int': (str, int),
+	'bool': (str, boolean),
+	'json': (json.dumps, json.loads),
+	'message': (lambda x: x.to_json(), Message.parse)
+}
+
+
+def get_cache(app: Application) -> Cache:
+	return BACKENDS[app.config.ca_type](app)
+
+
+def register_cache(backend: type[Cache]) -> type[Cache]:
+	BACKENDS[backend.name] = backend
+	return backend
+
+
+def serialize_value(value: Any, value_type: str = 'str') -> str:
+	if isinstance(value, str):
+		return value
+
+	return CONVERTERS[value_type][0](value)
+
+
+def deserialize_value(value: str, value_type: str = 'str') -> Any:
+	return CONVERTERS[value_type][1](value)
+
+
+@dataclass
+class Item:
+	namespace: str
+	key: str
+	value: Any
+	value_type: str
+	updated: datetime
+
+
+	def __post_init__(self):
+		if isinstance(self.updated, str):
+			self.updated = datetime.fromisoformat(self.updated)
+
+
+	@classmethod
+	def from_data(cls: type[Item], *args) -> Item:
+		data = cls(*args)
+		data.value = deserialize_value(data.value, data.value_type)
+
+		if not isinstance(data.updated, datetime):
+			data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)
+
+		return data
+
+
+	def older_than(self, hours: int) -> bool:
+		delta = datetime.now(tz = timezone.utc) - self.updated
+		return delta.total_seconds() > hours * 3600
+
+
+	def to_dict(self) -> dict[str, Any]:
+		return asdict(self)
+
+
+class Cache(ABC):
+	name: str = 'null'
+
+
+	def __init__(self, app: Application):
+		self.app = app
+		self.setup()
+
+
+	@abstractmethod
+	def get(self, namespace: str, key: str) -> Item:
+		...
+
+
+	@abstractmethod
+	def get_keys(self, namespace: str) -> Iterator[str]:
+		...
+
+
+	@abstractmethod
+	def get_namespaces(self) -> Iterator[str]:
+		...
+
+
+	@abstractmethod
+	def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
+		...
+
+
+	@abstractmethod
+	def delete(self, namespace: str, key: str) -> None:
+		...
+
+
+	@abstractmethod
+	def setup(self) -> None:
+		...
+
+
+	def set_item(self, item: Item) -> Item:
+		return self.set(
+			item.namespace,
+			item.key,
+			item.value,
+			item.value_type
+		)
+
+
+	def delete_item(self, item: Item) -> None:
+		self.delete(item.namespace, item.key)
+
+
+@register_cache
+class SqlCache(Cache):
+	name: str = 'database'
+
+
+	def __init__(self, app: Application):
+		self._db = get_database(app.config)
+		Cache.__init__(self, app)
+
+
+	def get(self, namespace: str, key: str) -> Item:
+		params = {
+			'namespace': namespace,
+			'key': key
+		}
+
+		with self._db.connection() as conn:
+			with conn.exec_statement('get-cache-item', params) as cur:
+				if not (row := cur.one()):
+					raise KeyError(f'{namespace}:{key}')
+
+				row.pop('id', None)
+				return Item.from_data(*tuple(row.values()))
+
+
+	def get_keys(self, namespace: str) -> Iterator[str]:
+		with self._db.connection() as conn:
+			for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
+				yield row['key']
+
+
+	def get_namespaces(self) -> Iterator[str]:
+		with self._db.connection() as conn:
+			for row in conn.exec_statement('get-cache-namespaces', None):
+				yield row['namespace']
+
+
+	def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
+		params = {
+			'namespace': namespace,
+			'key': key,
+			'value': serialize_value(value, value_type),
+			'type': value_type,
+			'date': datetime.now(tz = timezone.utc)
+		}
+
+		with self._db.connection() as conn:
+			with conn.exec_statement('set-cache-item', params) as cur:
+				row = cur.one()
+				row.pop('id', None)
+				return Item.from_data(*tuple(row.values()))
+
+
+	def delete(self, namespace: str, key: str) -> None:
+		params = {
+			'namespace': namespace,
+			'key': key
+		}
+
+		with self._db.connection() as conn:
+			with conn.exec_statement('del-cache-item', params):
+				pass
+
+
+	def setup(self) -> None:
+		with self._db.connection() as conn:
+			with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
+				pass
+
+
+@register_cache
+class RedisCache(Cache):
+	name: str = 'redis'
+	_rd: Redis
+
+
+	@property
+	def prefix(self) -> str:
+		return self.app.config.rd_prefix
+
+
+	def get_key_name(self, namespace: str, key: str) -> str:
+		return f'{self.prefix}:{namespace}:{key}'
+
+
+	def get(self, namespace: str, key: str) -> Item:
+		key_name = self.get_key_name(namespace, key)
+
+		if not (raw_value := self._rd.get(key_name)):
+			raise KeyError(f'{namespace}:{key}')
+
+		value_type, updated, value = raw_value.split(':', 2)
+		return Item.from_data(
+			namespace,
+			key,
+			value,
+			value_type,
+			datetime.fromtimestamp(float(updated), tz = timezone.utc)
+		)
+
+
+	def get_keys(self, namespace: str) -> Iterator[str]:
+		for key in self._rd.keys(self.get_key_name(namespace, '*')):
+			*_, key_name = key.split(':', 2)
+			yield key_name
+
+
+	def get_namespaces(self) -> Iterator[str]:
+		namespaces = []
+
+		for key in self._rd.keys(f'{self.prefix}:*'):
+			_, namespace, _ = key.split(':', 2)
+
+			if namespace not in namespaces:
+				namespaces.append(namespace)
+				yield namespace
+
+
+	def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
+		date = datetime.now(tz = timezone.utc)
+		serialized = serialize_value(value, value_type)
+
+		self._rd.set(
+			self.get_key_name(namespace, key),
+			f'{value_type}:{date.timestamp()}:{serialized}'
+		)
+
+		return Item(namespace, key, value, value_type, date)
+
+
+	def delete(self, namespace: str, key: str) -> None:
+		self._rd.delete(self.get_key_name(namespace, key))
+
+
+	def setup(self) -> None:
+		options = {
+			'client_name': f'ActivityRelay_{self.app.config.domain}',
+			'decode_responses': True,
+			'username': self.app.config.rd_user,
+			'password': self.app.config.rd_pass,
+			'db': self.app.config.rd_database
+		}
+
+		if os.path.exists(self.app.config.rd_host):
+			options['unix_socket_path'] = self.app.config.rd_host
+
+		else:
+			options['host'] = self.app.config.rd_host
+			options['port'] = self.app.config.rd_port
+
+		self._rd = Redis(**options)
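Note: a short usage sketch of the API above, assuming an already-constructed `Application` named `app` (the relay itself wires this up in `Application.__init__` via `get_cache(self)`); the namespace and key are illustrative.

```python
from relay.cache import get_cache

cache = get_cache(app)  # resolves to SqlCache or RedisCache from cache_type

# values round-trip through the CONVERTERS table keyed by value_type
cache.set('nodeinfo', 'example.com', {'software': 'mastodon'}, 'json')

try:
	item = cache.get('nodeinfo', 'example.com')

	if item.older_than(48):
		...  # stale after 48 hours; refetch and cache.set() again

	print(item.value)  # deserialized back into a dict

except KeyError:
	...  # cache miss: backends raise instead of returning None
```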
diff --git a/relay/config.py b/relay/config.py
index 58c2eb8..eff6215 100644
--- a/relay/config.py
+++ b/relay/config.py
@@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = {
 	'domain': 'relay.example.com',
 	'workers': len(os.sched_getaffinity(0)),
 	'db_type': 'sqlite',
+	'ca_type': 'database',
 	'sq_path': 'relay.sqlite3',
+
 	'pg_host': '/var/run/postgresql',
 	'pg_port': 5432,
 	'pg_user': getpass.getuser(),
 	'pg_pass': None,
-	'pg_name': 'activityrelay'
+	'pg_name': 'activityrelay',
+
+	'rd_host': 'localhost',
+	'rd_port': 6379,
+	'rd_user': None,
+	'rd_pass': None,
+	'rd_database': 0,
+	'rd_prefix': 'activityrelay'
 }
 
 if IS_DOCKER:
@@ -40,13 +49,22 @@ class Config:
 		self.domain = None
 		self.workers = None
 		self.db_type = None
+		self.ca_type = None
 		self.sq_path = None
+
 		self.pg_host = None
 		self.pg_port = None
 		self.pg_user = None
 		self.pg_pass = None
 		self.pg_name = None
 
+		self.rd_host = None
+		self.rd_port = None
+		self.rd_user = None
+		self.rd_pass = None
+		self.rd_database = None
+		self.rd_prefix = None
+
 		if load:
 			try:
 				self.load()
@@ -92,6 +110,7 @@ class Config:
 		with self.path.open('r', encoding = 'UTF-8') as fd:
 			config = yaml.load(fd, **options)
 			pgcfg = config.get('postgresql', {})
+			rdcfg = config.get('redis', {})
 
 		if not config:
 			raise ValueError('Config is empty')
@@ -106,18 +125,25 @@ class Config:
 		self.set('port', config.get('port', DEFAULTS['port']))
 		self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
 
+		self.set('workers', config.get('workers', DEFAULTS['workers']))
 		self.set('domain', config.get('domain', DEFAULTS['domain']))
 		self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
+		self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
 
 		for key in DEFAULTS:
-			if not key.startswith('pg'):
-				continue
+			if key.startswith('pg'):
+				try:
+					self.set(key, pgcfg[key[3:]])
 
-			try:
-				self.set(key, pgcfg[key[3:]])
+				except KeyError:
+					continue
 
-			except KeyError:
-				continue
+			elif key.startswith('rd'):
+				try:
+					self.set(key, rdcfg[key[3:]])
+
+				except KeyError:
+					continue
 
 
 	def reset(self) -> None:
@@ -132,7 +158,9 @@ class Config:
 			'listen': self.listen,
 			'port': self.port,
 			'domain': self.domain,
+			'workers': self.workers,
 			'database_type': self.db_type,
+			'cache_type': self.ca_type,
 			'sqlite_path': self.sq_path,
 			'postgres': {
 				'host': self.pg_host,
@@ -140,6 +168,14 @@ class Config:
 				'user': self.pg_user,
 				'pass': self.pg_pass,
 				'name': self.pg_name
+			},
+			'redis': {
+				'host': self.rd_host,
+				'port': self.rd_port,
+				'user': self.rd_user,
+				'pass': self.rd_pass,
+				'database': self.rd_database,
+				'prefix': self.rd_prefix
 			}
 		}
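Note: the `key[3:]` slicing in `load()` is easy to misread, so here is the mapping it performs, with toy values. Only the prefix-stripping logic is taken from the diff; the sample values are invented.

```python
DEFAULTS = {'rd_host': 'localhost', 'rd_port': 6379, 'rd_prefix': 'activityrelay'}
rdcfg = {'host': 'redis.internal', 'port': 6380}  # the parsed `redis:` YAML section

settings = dict(DEFAULTS)

for key in DEFAULTS:
	if key.startswith('rd'):
		try:
			# 'rd_host'[3:] == 'host': config attributes keep the rd_ prefix
			# while the YAML section uses the bare option names
			settings[key] = rdcfg[key[3:]]

		except KeyError:
			continue

print(settings)  # {'rd_host': 'redis.internal', 'rd_port': 6380, 'rd_prefix': 'activityrelay'}
```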
diff --git a/relay/data/statements.sql b/relay/data/statements.sql
index a262feb..9bcea41 100644
--- a/relay/data/statements.sql
+++ b/relay/data/statements.sql
@@ -77,3 +77,64 @@ RETURNING *
 
 -- name: del-domain-whitelist
 DELETE FROM whitelist
 WHERE domain = :domain
+
+
+-- cache functions --
+
+-- name: create-cache-table-sqlite
+CREATE TABLE IF NOT EXISTS cache (
+	id INTEGER PRIMARY KEY UNIQUE,
+	namespace TEXT NOT NULL,
+	key TEXT NOT NULL,
+	"value" TEXT,
+	type TEXT DEFAULT 'str',
+	updated TIMESTAMP NOT NULL,
+	UNIQUE(namespace, key)
+)
+
+-- name: create-cache-table-postgres
+CREATE TABLE IF NOT EXISTS cache (
+	id SERIAL PRIMARY KEY,
+	namespace TEXT NOT NULL,
+	key TEXT NOT NULL,
+	"value" TEXT,
+	type TEXT DEFAULT 'str',
+	updated TIMESTAMP NOT NULL,
+	UNIQUE(namespace, key)
+)
+
+
+-- name: get-cache-item
+SELECT * FROM cache
+WHERE namespace = :namespace and key = :key
+
+
+-- name: get-cache-keys
+SELECT key FROM cache
+WHERE namespace = :namespace
+
+
+-- name: get-cache-namespaces
+SELECT DISTINCT namespace FROM cache
+
+
+-- name: set-cache-item
+INSERT INTO cache (namespace, key, value, type, updated)
+VALUES (:namespace, :key, :value, :type, :date)
+ON CONFLICT (namespace, key) DO
+UPDATE SET value = :value, type = :type, updated = :date
+RETURNING *
+
+
+-- name: del-cache-item
+DELETE FROM cache
+WHERE namespace = :namespace and key = :key
+
+
+-- name: del-cache-namespace
+DELETE FROM cache
+WHERE namespace = :namespace
+
+
+-- name: del-cache-all
+DELETE FROM cache
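Note: `set-cache-item` is an upsert, which is why `SqlCache.set()` can unconditionally read one row back. A runnable sketch against the sqlite schema above (requires SQLite >= 3.24 for `ON CONFLICT ... DO UPDATE` and >= 3.35 for `RETURNING`; the sample values are invented):

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute(
	'CREATE TABLE cache ('
	'id INTEGER PRIMARY KEY UNIQUE, namespace TEXT NOT NULL, key TEXT NOT NULL, '
	'"value" TEXT, type TEXT DEFAULT \'str\', updated TIMESTAMP NOT NULL, '
	'UNIQUE(namespace, key))'
)

upsert = (
	'INSERT INTO cache (namespace, key, value, type, updated) '
	'VALUES (:namespace, :key, :value, :type, :date) '
	'ON CONFLICT (namespace, key) DO '
	'UPDATE SET value = :value, type = :type, updated = :date '
	'RETURNING *'
)

params = {
	'namespace': 'request', 'key': 'https://example.com',
	'value': 'a', 'type': 'str', 'date': '2024-01-01 00:00:00'
}

print(conn.execute(upsert, params).fetchone())  # first call inserts

params['value'] = 'b'
print(conn.execute(upsert, params).fetchone())  # second call updates the same row
```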
diff --git a/relay/http_client.py b/relay/http_client.py
index 5040059..6bed4ae 100644
--- a/relay/http_client.py
+++ b/relay/http_client.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import traceback
 import typing
 
@@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
 from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
 from asyncio.exceptions import TimeoutError as AsyncTimeoutError
 from aputils.objects import Nodeinfo, WellKnownNodeinfo
-from cachetools import LRUCache
 from json.decoder import JSONDecodeError
 from urllib.parse import urlparse
 
@@ -16,7 +16,11 @@ from . import logger as logging
 from .misc import MIMETYPES, Message, get_app
 
 if typing.TYPE_CHECKING:
+	from aputils import Signer
+	from tinysql import Row
 	from typing import Any
+	from .application import Application
+	from .cache import Cache
 
 
 HEADERS = {
@@ -26,8 +30,7 @@ HEADERS = {
 
 
 class HttpClient:
-	def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
-		self.cache = LRUCache(cache_size)
+	def __init__(self, limit: int = 100, timeout: int = 10):
 		self.limit = limit
 		self.timeout = timeout
 		self._conn = None
@@ -43,6 +46,21 @@ class HttpClient:
 		await self.close()
 
 
+	@property
+	def app(self) -> Application:
+		return get_app()
+
+
+	@property
+	def cache(self) -> Cache:
+		return self.app.cache
+
+
+	@property
+	def signer(self) -> Signer:
+		return self.app.signer
+
+
 	async def open(self) -> None:
 		if self._session:
 			return
@@ -74,8 +92,8 @@ class HttpClient:
 	async def get(self, # pylint: disable=too-many-branches
 				url: str,
 				sign_headers: bool = False,
-				loads: callable | None = None,
-				force: bool = False) -> Message | dict | None:
+				loads: callable = json.loads,
+				force: bool = False) -> dict | None:
 
 		await self.open()
 
@@ -85,13 +103,20 @@ class HttpClient:
 		except ValueError:
 			pass
 
-		if not force and url in self.cache:
-			return self.cache[url]
+		if not force:
+			try:
+				item = self.cache.get('request', url)
+
+				if not item.older_than(48):
+					return loads(item.value)
+
+			except KeyError:
+				logging.verbose('No cached data for url: %s', url)
 
 		headers = {}
 
 		if sign_headers:
-			get_app().signer.sign_headers('GET', url, algorithm = 'original')
+			self.signer.sign_headers('GET', url, algorithm = 'original')
 
 		try:
 			logging.debug('Fetching resource: %s', url)
@@ -101,32 +126,22 @@ class HttpClient:
 				if resp.status == 202:
 					return None
 
-				if resp.status != 200:
-					logging.verbose('Received error when requesting %s: %i', url, resp.status)
-					logging.debug(await resp.read())
-					return None
+				data = await resp.read()
 
-				if loads:
-					message = await resp.json(loads=loads)
+				if resp.status != 200:
+					logging.verbose('Received error when requesting %s: %i', url, resp.status)
+					logging.debug(data)
+					return None
 
-				elif resp.content_type == MIMETYPES['activity']:
-					message = await resp.json(loads = Message.parse)
+				message = loads(data)
+				self.cache.set('request', url, data.decode('utf-8'), 'str')
+				logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
 
-				elif resp.content_type == MIMETYPES['json']:
-					message = await resp.json()
-
-				else:
-					logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
-					logging.debug('Response: %s', await resp.read())
-					return None
-
-				logging.debug('%s >> resp %s', url, message.to_json(4))
-
-				self.cache[url] = message
-				return message
+				return message
 
 		except JSONDecodeError:
 			logging.verbose('Failed to parse JSON')
+			return None
 
 		except ClientSSLError:
 			logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
@@ -140,13 +155,9 @@ class HttpClient:
 		return None
 
 
-	async def post(self, url: str, message: Message) -> None:
+	async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
 		await self.open()
 
-		# todo: cache inboxes to avoid opening a db connection
-		with get_app().database.connection() as conn:
-			instance = conn.get_inbox(url)
-
 		## Using the old algo by default is probably a better idea right now
 		# pylint: disable=consider-ternary-expression
 		if instance and instance['software'] in {'mastodon'}:
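Note: with the LRU cache gone, `get()` now decides freshness itself: anything stored in the `request` namespace is reused for up to 48 hours, and a miss or stale entry falls through to a real fetch. A toy version of just that decision, with a plain dict standing in for the Cache backend:

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone

cache = {}  # url -> (updated, body); stand-in for the 'request' namespace

def cached_body(url: str, max_age_hours: int = 48) -> str | None:
	try:
		updated, body = cache[url]

	except KeyError:
		return None  # miss: caller fetches and stores the response

	if datetime.now(tz = timezone.utc) - updated > timedelta(hours = max_age_hours):
		return None  # stale: treated the same as a miss

	return body

cache['https://example.com'] = (datetime.now(tz = timezone.utc), '{"ok": true}')
print(cached_body('https://example.com'))  # fresh, so the body is reused
```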
diff --git a/relay/logger.py b/relay/logger.py
index 1970cab..8aff62d 100644
--- a/relay/logger.py
+++ b/relay/logger.py
@@ -70,7 +70,6 @@ error: Callable = logging.error
 critical: Callable = logging.critical
 
 
-logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
 env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
 
 try:
@@ -79,22 +78,15 @@ try:
 except KeyError:
 	env_log_file = None
 
-
-try:
-	log_level = LogLevel[env_log_level]
-
-except KeyError:
-	print('Invalid log level:', env_log_level)
-	log_level = LogLevel['INFO']
-
-
 handlers = [logging.StreamHandler()]
 
 if env_log_file:
 	handlers.append(logging.FileHandler(env_log_file))
 
+logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
 logging.basicConfig(
-	level = log_level,
+	level = LogLevel.INFO,
 	format = '[%(asctime)s] %(levelname)s: %(message)s',
+	datefmt = '%Y-%m-%d %H:%M:%S',
 	handlers = handlers
 )
diff --git a/relay/manage.py b/relay/manage.py
index 69930dc..2a78b6f 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -18,7 +18,7 @@ from .application import Application
 from .compat import RelayConfig, RelayDatabase
 from .database import get_database
 from .database.connection import RELAY_SOFTWARE
-from .misc import IS_DOCKER, Message, check_open_port
+from .misc import IS_DOCKER, Message
 
 if typing.TYPE_CHECKING:
 	from tinysql import Row
@@ -51,6 +51,13 @@ SOFTWARE = (
 )
 
 
+def check_alphanumeric(text: str) -> str:
+	if not text.isalnum():
+		raise click.BadParameter('String not alphanumeric')
+
+	return text
+
+
 @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
 @click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
 @click.version_option(version=__version__, prog_name='ActivityRelay')
@@ -63,6 +70,11 @@ def cli(ctx: click.Context, config: str) -> None:
 			cli_setup.callback()
 
 		else:
+			click.echo(
+				'[DEPRECATED] Running the relay without the "run" command will be removed in the '
+				'future.'
+			)
+
 			cli_run.callback()
 
 
@@ -113,8 +125,8 @@ def cli_setup(ctx: click.Context) -> None:
 		)
 
 		ctx.obj.config.pg_host = click.prompt(
-			'What IP address or hostname does the server listen on?',
+			'What IP address, hostname, or unix socket does the server listen on?',
 			default = ctx.obj.config.pg_host
 		)
 
 		ctx.obj.config.pg_port = click.prompt(
@@ -135,6 +148,48 @@ def cli_setup(ctx: click.Context) -> None:
 			default = ctx.obj.config.pg_pass or ""
 		) or None
 
+	ctx.obj.config.ca_type = click.prompt(
+		'Which caching backend?',
+		default = ctx.obj.config.ca_type,
+		type = click.Choice(['database', 'redis'], case_sensitive = False)
+	)
+
+	if ctx.obj.config.ca_type == 'redis':
+		ctx.obj.config.rd_host = click.prompt(
+			'What IP address, hostname, or unix socket does the server listen on?',
+			default = ctx.obj.config.rd_host
+		)
+
+		ctx.obj.config.rd_port = click.prompt(
+			'What port does the server listen on?',
+			default = ctx.obj.config.rd_port,
+			type = int
+		)
+
+		ctx.obj.config.rd_user = click.prompt(
+			'Which user will authenticate with the server?',
+			default = ctx.obj.config.rd_user
+		)
+
+		ctx.obj.config.rd_pass = click.prompt(
+			'User password',
+			hide_input = True,
+			show_default = False,
+			default = ctx.obj.config.rd_pass or ""
+		) or None
+
+		ctx.obj.config.rd_database = click.prompt(
+			'Which database number to use?',
+			default = ctx.obj.config.rd_database,
+			type = int
+		)
+
+		ctx.obj.config.rd_prefix = click.prompt(
+			'What text should each cache key be prefixed with?',
+			default = ctx.obj.config.rd_prefix,
+			type = check_alphanumeric
+		)
+
 	ctx.obj.config.save()
 
 	config = {
@@ -150,8 +205,9 @@ def cli_setup(ctx: click.Context) -> None:
 
 
 @cli.command('run')
+@click.option('--dev', '-d', is_flag = True, help = 'Enable worker reloading on code change')
 @click.pass_context
-def cli_run(ctx: click.Context) -> None:
+def cli_run(ctx: click.Context, dev: bool = False) -> None:
 	'Run the relay'
 
 	if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
@@ -178,11 +234,7 @@ def cli_run(ctx: click.Context) -> None:
 		click.echo(pip_command)
 		return
 
-	if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
-		click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
-		return
-
-	ctx.obj.run()
+	ctx.obj.run(dev)
 
 
 @cli.command('convert')
@@ -364,7 +416,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
 		actor = actor
 	)
 
-	asyncio.run(http.post(inbox, message))
+	asyncio.run(http.post(inbox, message, inbox_data))
 
 	click.echo(f'Sent follow message to actor: {actor}')
 
@@ -405,7 +457,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
 		}
 	)
 
-	asyncio.run(http.post(inbox, message))
+	asyncio.run(http.post(inbox, message, inbox_data))
 
 	click.echo(f'Sent unfollow message to: {actor}')
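Note: `check_alphanumeric` works because click accepts any callable as a `type`: the input is passed through it, and a raised `BadParameter` makes the prompt re-ask. Standalone sketch (the option name and command are invented):

```python
import click

def check_alphanumeric(text: str) -> str:
	if not text.isalnum():
		raise click.BadParameter('String not alphanumeric')

	return text

@click.command()
@click.option('--prefix', prompt = 'Cache key prefix', type = check_alphanumeric)
def demo(prefix: str) -> None:
	click.echo(f'prefix: {prefix}')

if __name__ == '__main__':
	demo()
```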
diff --git a/relay/misc.py b/relay/misc.py
index d7e96d8..e71845d 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -14,9 +14,11 @@ from functools import cached_property
 from uuid import uuid4
 
 if typing.TYPE_CHECKING:
-	from collections.abc import Coroutine, Generator
+	from collections.abc import Awaitable, Coroutine, Generator
+	from tinysql import Connection
 	from typing import Any
 	from .application import Application
+	from .cache import Cache
 	from .config import Config
 	from .database import Database
 	from .http_client import HttpClient
@@ -234,13 +236,21 @@ class Response(AiohttpResponse):
 
 class View(AbstractView):
 	def __await__(self) -> Generator[Response]:
-		if (self.request.method) not in METHODS:
+		if self.request.method not in METHODS:
 			raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
 
 		if not (handler := self.handlers.get(self.request.method)):
 			raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
 
-		return handler(self.request, **self.request.match_info).__await__()
+		return self._run_handler(handler).__await__()
+
+
+	async def _run_handler(self, handler: Awaitable) -> Response:
+		with self.database.config.connection_class(self.database) as conn:
+			# todo: remove on next tinysql release
+			conn.open()
+
+			return await handler(self.request, conn, **self.request.match_info)
 
 
 	@cached_property
@@ -268,6 +278,11 @@ class View(AbstractView):
 		return self.request.app
 
 
+	@property
+	def cache(self) -> Cache:
+		return self.app.cache
+
+
 	@property
 	def client(self) -> HttpClient:
 		return self.app.client
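Note: with `_run_handler()` in place, every view method gains a `conn` parameter after `request`, opened and closed once per request. A hypothetical view written against the new convention (this class is not part of the diff; `Response.new` and `conn.execute(...).all()` are used the same way as in `relay/views.py` below):

```python
from aiohttp.web import Request

from relay.database.connection import Connection
from relay.misc import Response, View

class InboxCountView(View):
	async def get(self, request: Request, conn: Connection) -> Response:
		# one connection per request, opened by View._run_handler()
		inboxes = conn.execute('SELECT * FROM inboxes').all()
		return Response.new({'count': len(inboxes)}, ctype = 'json')
```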
diff --git a/relay/processors.py b/relay/processors.py
index 8cf7722..d9780d1 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -1,20 +1,15 @@
 from __future__ import annotations
 
-import tinysql
 import typing
 
-from cachetools import LRUCache
-
 from . import logger as logging
+from .database.connection import Connection
 from .misc import Message
 
 if typing.TYPE_CHECKING:
 	from .views import ActorView
 
 
-cache = LRUCache(1024)
-
-
 def person_check(actor: str, software: str) -> bool:
 	# pleroma and akkoma may use Person for the actor type for some reason
 	# akkoma changed this in 3.6.0
@@ -28,83 +23,87 @@ def person_check(actor: str, software: str) -> bool:
 	return False
 
 
-async def handle_relay(view: ActorView) -> None:
-	if view.message.object_id in cache:
+async def handle_relay(view: ActorView, conn: Connection) -> None:
+	try:
+		view.cache.get('handle-relay', view.message.object_id)
 		logging.verbose('already relayed %s', view.message.object_id)
 		return
 
+	except KeyError:
+		pass
+
 	message = Message.new_announce(view.config.domain, view.message.object_id)
-	cache[view.message.object_id] = message.id
 	logging.debug('>> relay: %s', message)
 
-	with view.database.connection() as conn:
-		for inbox in conn.distill_inboxes(view.message):
-			view.app.push_message(inbox, message)
+	for inbox in conn.distill_inboxes(view.message):
+		view.app.push_message(inbox, message, view.instance)
+
+	view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
 
 
-async def handle_forward(view: ActorView) -> None:
-	if view.message.id in cache:
+async def handle_forward(view: ActorView, conn: Connection) -> None:
+	try:
+		view.cache.get('handle-forward', view.message.id)
 		logging.verbose('already forwarded %s', view.message.id)
 		return
 
+	except KeyError:
+		pass
+
 	message = Message.new_announce(view.config.domain, view.message)
-	cache[view.message.id] = message.id
 	logging.debug('>> forward: %s', message)
 
-	with view.database.connection() as conn:
-		for inbox in conn.distill_inboxes(view.message):
-			view.app.push_message(inbox, message)
+	for inbox in conn.distill_inboxes(view.message):
+		view.app.push_message(inbox, message, view.instance)
+
+	view.cache.set('handle-forward', view.message.id, message.id, 'str')
 
 
-async def handle_follow(view: ActorView) -> None:
+async def handle_follow(view: ActorView, conn: Connection) -> None:
 	nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
 	software = nodeinfo.sw_name if nodeinfo else None
 
-	with view.database.connection() as conn:
-		# reject if software used by actor is banned
-		if conn.get_software_ban(software):
-			view.app.push_message(
-				view.actor.shared_inbox,
-				Message.new_response(
-					host = view.config.domain,
-					actor = view.actor.id,
-					followid = view.message.id,
-					accept = False
-				)
-			)
+	# reject if software used by actor is banned
+	if conn.get_software_ban(software):
+		view.app.push_message(
+			view.actor.shared_inbox,
+			Message.new_response(
+				host = view.config.domain,
+				actor = view.actor.id,
+				followid = view.message.id,
+				accept = False
+			)
+		)
 
-			logging.verbose(
-				'Rejected follow from actor for using specific software: actor=%s, software=%s',
-				view.actor.id,
-				software
-			)
+		logging.verbose(
+			'Rejected follow from actor for using specific software: actor=%s, software=%s',
+			view.actor.id,
+			software
+		)
 
-			return
+		return
 
-		## reject if the actor is not an instance actor
-		if person_check(view.actor, software):
-			view.app.push_message(
-				view.actor.shared_inbox,
-				Message.new_response(
-					host = view.config.domain,
-					actor = view.actor.id,
-					followid = view.message.id,
-					accept = False
-				)
-			)
+	## reject if the actor is not an instance actor
+	if person_check(view.actor, software):
+		view.app.push_message(
+			view.actor.shared_inbox,
+			Message.new_response(
+				host = view.config.domain,
+				actor = view.actor.id,
+				followid = view.message.id,
+				accept = False
+			)
+		)
 
-			logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
-			return
+		logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
+		return
 
-		if conn.get_inbox(view.actor.shared_inbox):
-			data = {'followid': view.message.id}
-			statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
+	if conn.get_inbox(view.actor.shared_inbox):
+		view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
 
-			with conn.query(statement):
-				pass
-
-		else:
-			conn.put_inbox(
+	else:
+		with conn.transaction():
+			view.instance = conn.put_inbox(
 				view.actor.domain,
 				view.actor.shared_inbox,
 				view.actor.id,
@@ -112,35 +111,37 @@
 				view.message.id,
 				software
 			)
 
-		view.app.push_message(
-			view.actor.shared_inbox,
-			Message.new_response(
-				host = view.config.domain,
-				actor = view.actor.id,
-				followid = view.message.id,
-				accept = True
-			)
-		)
+	view.app.push_message(
+		view.actor.shared_inbox,
+		Message.new_response(
+			host = view.config.domain,
+			actor = view.actor.id,
+			followid = view.message.id,
+			accept = True
+		),
+		view.instance
+	)
 
-		# Are Akkoma and Pleroma the only two that expect a follow back?
-		# Ignoring only Mastodon for now
-		if software != 'mastodon':
-			view.app.push_message(
-				view.actor.shared_inbox,
-				Message.new_follow(
-					host = view.config.domain,
-					actor = view.actor.id
-				)
-			)
+	# Are Akkoma and Pleroma the only two that expect a follow back?
+	# Ignoring only Mastodon for now
+	if software != 'mastodon':
+		view.app.push_message(
+			view.actor.shared_inbox,
+			Message.new_follow(
+				host = view.config.domain,
+				actor = view.actor.id
+			),
+			view.instance
+		)
 
 
-async def handle_undo(view: ActorView) -> None:
+async def handle_undo(view: ActorView, conn: Connection) -> None:
 	## If the object is not a Follow, forward it
 	if view.message.object['type'] != 'Follow':
-		await handle_forward(view)
+		await handle_forward(view, conn)
 		return
 
-	with view.database.connection() as conn:
+	with conn.transaction():
 		if not conn.del_inbox(view.actor.id):
 			logging.verbose(
 				'Failed to delete "%s" with follow ID "%s"',
@@ -154,7 +155,8 @@
 			host = view.config.domain,
 			actor = view.actor.id,
 			follow = view.message
-		)
+		),
+		view.instance
 	)
 
 
@@ -168,7 +170,7 @@ processors = {
 }
 
 
-async def run_processor(view: ActorView) -> None:
+async def run_processor(view: ActorView, conn: Connection) -> None:
 	if view.message.type not in processors:
 		logging.verbose(
 			'Message type "%s" from actor cannot be handled: %s',
@@ -179,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
 		return
 
 	if view.instance:
-		if not view.instance['software']:
-			if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
-				with view.database.connection() as conn:
+		with conn.transaction():
+			if not view.instance['software']:
+				if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
 					view.instance = conn.update_inbox(
 						view.instance['inbox'],
 						software = nodeinfo.sw_name
 					)
 
-		if not view.instance['actor']:
-			with view.database.connection() as conn:
+			if not view.instance['actor']:
 				view.instance = conn.update_inbox(
 					view.instance['inbox'],
 					actor = view.actor.id
 				)
 
 	logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
-	await processors[view.message.type](view)
+	await processors[view.message.type](view, conn)
diff --git a/relay/views.py b/relay/views.py
index 8b84c02..cb648a2 100644
--- a/relay/views.py
+++ b/relay/views.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import asyncio
 import subprocess
 import traceback
 import typing
@@ -12,6 +11,7 @@ from pathlib import Path
 
 from . import __version__
 from . import logger as logging
+from .database.connection import Connection
 from .misc import Message, Response, View
 from .processors import run_processor
 
@@ -19,6 +19,7 @@ if typing.TYPE_CHECKING:
 	from aiohttp.web import Request
 	from aputils.signer import Signer
 	from collections.abc import Callable
+	from tinysql import Row
 
 
 VIEWS = []
@@ -74,17 +75,16 @@ def register_route(*paths: str) -> Callable:
 
 @register_route('/')
 class HomeView(View):
-	async def get(self, request: Request) -> Response:
-		with self.database.connection() as conn:
-			config = conn.get_config_all()
-			inboxes = conn.execute('SELECT * FROM inboxes').all()
+	async def get(self, request: Request, conn: Connection) -> Response:
+		config = conn.get_config_all()
+		inboxes = conn.execute('SELECT * FROM inboxes').all()
 
-			text = HOME_TEMPLATE.format(
-				host = self.config.domain,
-				note = config['note'],
-				count = len(inboxes),
-				targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
-			)
+		text = HOME_TEMPLATE.format(
+			host = self.config.domain,
+			note = config['note'],
+			count = len(inboxes),
+			targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
+		)
 
 		return Response.new(text, ctype='html')
@@ -98,11 +98,11 @@ class ActorView(View):
 		self.signature: Signature = None
 		self.message: Message = None
 		self.actor: Message = None
-		self.instance: dict[str, str] = None
+		self.instance: Row = None
 		self.signer: Signer = None
 
 
-	async def get(self, request: Request) -> Response:
+	async def get(self, request: Request, conn: Connection) -> Response:
 		data = Message.new_actor(
 			host = self.config.domain,
 			pubkey = self.app.signer.pubkey
@@ -111,37 +111,36 @@ class ActorView(View):
 		return Response.new(data, ctype='activity')
 
 
-	async def post(self, request: Request) -> Response:
+	async def post(self, request: Request, conn: Connection) -> Response:
 		if response := await self.get_post_data():
 			return response
 
-		with self.database.connection() as conn:
-			self.instance = conn.get_inbox(self.actor.shared_inbox)
-			config = conn.get_config_all()
+		self.instance = conn.get_inbox(self.actor.shared_inbox)
+		config = conn.get_config_all()
 
-			## reject if the actor isn't whitelisted while the whiltelist is enabled
-			if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
-				logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
-				return Response.new_error(403, 'access denied', 'json')
+		## reject if the actor isn't whitelisted while the whitelist is enabled
+		if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
+			logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
+			return Response.new_error(403, 'access denied', 'json')
 
-			## reject if actor is banned
-			if conn.get_domain_ban(self.actor.domain):
-				logging.verbose('Ignored request from banned actor: %s', self.actor.id)
-				return Response.new_error(403, 'access denied', 'json')
+		## reject if actor is banned
+		if conn.get_domain_ban(self.actor.domain):
+			logging.verbose('Ignored request from banned actor: %s', self.actor.id)
+			return Response.new_error(403, 'access denied', 'json')
 
-			## reject if activity type isn't 'Follow' and the actor isn't following
-			if self.message.type != 'Follow' and not self.instance:
-				logging.verbose(
-					'Rejected actor for trying to post while not following: %s',
-					self.actor.id
-				)
+		## reject if activity type isn't 'Follow' and the actor isn't following
+		if self.message.type != 'Follow' and not self.instance:
+			logging.verbose(
+				'Rejected actor for trying to post while not following: %s',
+				self.actor.id
+			)
 
-				return Response.new_error(401, 'access denied', 'json')
+			return Response.new_error(401, 'access denied', 'json')
 
-			logging.debug('>> payload %s', self.message.to_json(4))
+		logging.debug('>> payload %s', self.message.to_json(4))
 
-			asyncio.ensure_future(run_processor(self))
-			return Response.new(status = 202)
+		await run_processor(self, conn)
+		return Response.new(status = 202)
@@ -168,7 +167,11 @@ class ActorView(View):
 			logging.verbose('actor not in message')
 			return Response.new_error(400, 'no actor in message', 'json')
 
-		self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
+		self.actor = await self.client.get(
+			self.signature.keyid,
+			sign_headers = True,
+			loads = Message.parse
+		)
 
 		if not self.actor:
 			# ld signatures aren't handled atm, so just ignore it
@@ -227,7 +230,7 @@ class ActorView(View):
 
 @register_route('/.well-known/webfinger')
 class WebfingerView(View):
-	async def get(self, request: Request) -> Response:
+	async def get(self, request: Request, conn: Connection) -> Response:
 		try:
 			subject = request.query['resource']
 
@@ -248,18 +251,18 @@ class WebfingerView(View):
 
 @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
 class NodeinfoView(View):
-	async def get(self, request: Request, niversion: str) -> Response:
-		with self.database.connection() as conn:
-			inboxes = conn.execute('SELECT * FROM inboxes').all()
+	# pylint: disable=no-self-use
+	async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
+		inboxes = conn.execute('SELECT * FROM inboxes').all()
 
-			data = {
-				'name': 'activityrelay',
-				'version': VERSION,
-				'protocols': ['activitypub'],
-				'open_regs': not conn.get_config('whitelist-enabled'),
-				'users': 1,
-				'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
-			}
+		data = {
+			'name': 'activityrelay',
+			'version': VERSION,
+			'protocols': ['activitypub'],
+			'open_regs': not conn.get_config('whitelist-enabled'),
+			'users': 1,
+			'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
+		}
 
 		if niversion == '2.1':
 			data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@@ -269,6 +272,6 @@ class NodeinfoView(View):
 
 @register_route('/.well-known/nodeinfo')
 class WellknownNodeinfoView(View):
-	async def get(self, request: Request) -> Response:
+	async def get(self, request: Request, conn: Connection) -> Response:
 		data = WellKnownNodeinfo.new_template(self.config.domain)
 
 		return Response.new(data, ctype = 'json')
diff --git a/requirements.txt b/requirements.txt
index 3fb5e40..5239cb8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,10 @@
 aiohttp>=3.9.1
 aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
-cachetools>=5.2.0
 click>=8.1.2
+gunicorn==21.1.0
+hiredis==2.3.2
 pyyaml>=6.0
-tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz
+redis==5.0.1
+tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz
 importlib_resources==6.1.1;python_version<'3.9'